diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..47e16e0b68cbe074529a69e90cb1f570212c1f44 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,37 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.so filter=lfs diff=lfs merge=lfs -text +*.metallib filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8dbea71b35c91f8f7f7b235ddcc25c60b7234f32 --- 
/dev/null +++ b/README.md @@ -0,0 +1,12 @@ +--- +license: apache-2.0 +tags: + - kernels +--- + +![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/paged-attention) + +## attention + +Paged attention kernels from [vLLM](https://github.com/vllm-project/) and [mistral.rs](https://github.com/EricLBuehler/mistral.rs). + diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..21bae1407630376967edda094e374a63505a43b7 --- /dev/null +++ b/benchmarks/benchmark.py @@ -0,0 +1,263 @@ +import torch + +from kernels.benchmark import Benchmark + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, +) -> torch.Tensor: + # query: (q, h, d), key: (k, h, d), value: (k, h, d) + # Transpose to (h, q, d) and (h, k, d) for batched matmul + q = query.transpose(0, 1) # (h, q, d) + k = key.transpose(0, 1) # (h, k, d) + v = value.transpose(0, 1) # (h, k, d) + + # Compute attention scores: (h, q, d) @ (h, d, k) -> (h, q, k) + attn_weights = (scale * torch.matmul(q, k.transpose(-1, -2))).float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + + # Compute output: (h, q, k) @ (h, k, d) -> (h, q, d) + out = torch.matmul(attn_weights, v) + + # Transpose back to (q, h, d) + return out.transpose(0, 1) + + +def ref_paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + scale: float, +) -> torch.Tensor: + num_seqs = query.shape[0] + num_heads = query.shape[1] + head_size = query.shape[2] + block_size = value_cache.shape[3] + max_seq_len = int(seq_lens.max().item()) + + # Create position indices for all sequences up to max_seq_len + positions = torch.arange(max_seq_len, device=query.device) + block_indices = positions // block_size # (max_seq_len,) + block_offsets = positions % block_size # (max_seq_len,) + + # Gather block numbers 
for all sequences: (num_seqs, max_seq_len) + block_numbers = block_tables[:, block_indices.long()] + + # Flatten for gathering: (num_seqs * max_seq_len,) + flat_block_numbers = block_numbers.reshape(-1) + flat_offsets = block_offsets.repeat(num_seqs) + + # Gather keys: key_cache is (num_blocks, num_heads, head_size // x, block_size, x) + # Index into [block_number, :, :, offset, :] and reshape + keys = key_cache[flat_block_numbers, :, :, flat_offsets, :] + keys = keys.reshape(num_seqs, max_seq_len, num_heads, head_size) + keys = keys.transpose(1, 2) # (num_seqs, num_heads, max_seq_len, head_size) + + # Gather values: value_cache is (num_blocks, num_heads, head_size, block_size) + values = value_cache[flat_block_numbers, :, :, flat_offsets] + values = values.reshape(num_seqs, max_seq_len, num_heads, head_size) + values = values.transpose(1, 2) # (num_seqs, num_heads, max_seq_len, head_size) + + # Query: (num_seqs, num_heads, head_size) -> (num_seqs, num_heads, 1, head_size) + q = query.unsqueeze(2) + + # Compute attention scores: (num_seqs, num_heads, 1, head_size) @ (num_seqs, num_heads, head_size, max_seq_len) + attn_weights = (scale * torch.matmul(q, keys.transpose(-1, -2))).float() + + # Create causal mask for variable sequence lengths + # Mask out positions beyond seq_len for each sequence + seq_mask = positions.unsqueeze(0) >= seq_lens.unsqueeze( + 1 + ) # (num_seqs, max_seq_len) + seq_mask = seq_mask.unsqueeze(1).unsqueeze(2) # (num_seqs, 1, 1, max_seq_len) + attn_weights = attn_weights.masked_fill(seq_mask, float("-inf")) + + attn_weights = torch.softmax(attn_weights, dim=-1).to(values.dtype) + + # Compute output: (num_seqs, num_heads, 1, max_seq_len) @ (num_seqs, num_heads, max_seq_len, head_size) + out = torch.matmul(attn_weights, values) + + return out.squeeze(2) # (num_seqs, num_heads, head_size) + + +class PagedAttentionBenchmark(Benchmark): + seed: int = 42 + + def setup(self): + num_seqs = 4 + num_heads = 8 + head_size = 64 + block_size = 16 + 
max_seq_len = 128 + num_blocks = 64 + dtype = torch.float16 + + self.num_heads = num_heads + self.block_size = block_size + self.max_seq_len = max_seq_len + self.scale = 1.0 / (head_size**0.5) + + # Query tensor (current token) + self.query = torch.randn( + num_seqs, num_heads, head_size, device=self.device, dtype=dtype + ) + + # KV cache with proper layout for the kernel + # x = 16 // element_size, for float16 x = 8 + x = 16 // torch.tensor([], dtype=dtype).element_size() + self.key_cache = torch.randn( + num_blocks, + num_heads, + head_size // x, + block_size, + x, + device=self.device, + dtype=dtype, + ) + self.value_cache = torch.randn( + num_blocks, + num_heads, + head_size, + block_size, + device=self.device, + dtype=dtype, + ) + + # Block tables: mapping from sequences to memory blocks + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + self.block_tables = torch.randint( + 0, + num_blocks, + (num_seqs, max_num_blocks_per_seq), + device=self.device, + dtype=torch.int32, + ) + + # Sequence lengths + self.seq_lens = torch.tensor( + [64, 96, 48, 128], device=self.device, dtype=torch.int32 + ) + + # KV scales + self.k_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device) + self.v_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device) + + # Output tensor + self.out = torch.empty_like(self.query) + + def benchmark_base(self): + self.kernel.paged_attention_v1( + self.out, + self.query, + self.key_cache, + self.value_cache, + num_kv_heads=self.num_heads, + scale=self.scale, + block_tables=self.block_tables, + seq_lens=self.seq_lens, + block_size=self.block_size, + max_seq_len=self.max_seq_len, + alibi_slopes=None, + kv_cache_dtype="auto", + k_scale=self.k_scale, + v_scale=self.v_scale, + ) + + def verify_base(self) -> torch.Tensor: + return ref_paged_attention( + self.query, + self.key_cache, + self.value_cache, + self.block_tables, + self.seq_lens, + self.scale, + ) + + def setup_large(self): + num_seqs = 16 + num_heads 
= 32 + head_size = 128 + block_size = 16 + max_seq_len = 512 + num_blocks = 256 + dtype = torch.float16 + + self.num_heads = num_heads + self.block_size = block_size + self.max_seq_len = max_seq_len + self.scale = 1.0 / (head_size**0.5) + + self.query = torch.randn( + num_seqs, num_heads, head_size, device=self.device, dtype=dtype + ) + + x = 16 // torch.tensor([], dtype=dtype).element_size() + self.key_cache = torch.randn( + num_blocks, + num_heads, + head_size // x, + block_size, + x, + device=self.device, + dtype=dtype, + ) + self.value_cache = torch.randn( + num_blocks, + num_heads, + head_size, + block_size, + device=self.device, + dtype=dtype, + ) + + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + self.block_tables = torch.randint( + 0, + num_blocks, + (num_seqs, max_num_blocks_per_seq), + device=self.device, + dtype=torch.int32, + ) + + # Variable sequence lengths + self.seq_lens = torch.randint( + 64, max_seq_len + 1, (num_seqs,), device=self.device, dtype=torch.int32 + ) + + self.k_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device) + self.v_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device) + + self.out = torch.empty_like(self.query) + + def benchmark_large(self): + self.kernel.paged_attention_v1( + self.out, + self.query, + self.key_cache, + self.value_cache, + num_kv_heads=self.num_heads, + scale=self.scale, + block_tables=self.block_tables, + seq_lens=self.seq_lens, + block_size=self.block_size, + max_seq_len=self.max_seq_len, + alibi_slopes=None, + kv_cache_dtype="auto", + k_scale=self.k_scale, + v_scale=self.v_scale, + ) + + def verify_large(self) -> torch.Tensor: + return ref_paged_attention( + self.query, + self.key_cache, + self.value_cache, + self.block_tables, + self.seq_lens, + self.scale, + ) diff --git a/build.toml b/build.toml new file mode 100644 index 0000000000000000000000000000000000000000..601ee921f6f821552fb2525eca5f5a682ca10286 --- /dev/null +++ b/build.toml @@ -0,0 +1,110 @@ 
+[general] +name = "paged_attention" +universal = false + +[torch] +src = [ + "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding.h" +] + +[kernel.cuda_utils] +backend = "cuda" +src = [ + "cuda-utils/cuda_utils.h", + "cuda-utils/cuda_utils_kernels.cu", +] +depends = [] + +[kernel.cuda_utils_rocm] +backend = "rocm" +rocm-archs = [ + "gfx906", + "gfx908", + "gfx90a", + "gfx940", + "gfx941", + "gfx942", + "gfx1030", + "gfx1100", + "gfx1101", +] +src = [ + "cuda-utils/cuda_utils.h", + "cuda-utils/cuda_utils_kernels.cu", +] +depends = ["torch"] + +[kernel.paged_attention] +backend = "cuda" +src = [ + "cuda-utils/cuda_utils.h", + "paged-attention/attention/attention_dtypes.h", + "paged-attention/attention/attention_generic.cuh", + "paged-attention/attention/attention_kernels.cuh", + "paged-attention/attention/attention_utils.cuh", + "paged-attention/attention/dtype_bfloat16.cuh", + "paged-attention/attention/dtype_float16.cuh", + "paged-attention/attention/dtype_float32.cuh", + "paged-attention/attention/dtype_fp8.cuh", + "paged-attention/attention/paged_attention_v1.cu", + "paged-attention/attention/paged_attention_v2.cu", + "paged-attention/cache_kernels.cu", + "paged-attention/cuda_compat.h", + "paged-attention/dispatch_utils.h", + "paged-attention/quantization/fp8/amd/quant_utils.cuh", + "paged-attention/quantization/fp8/nvidia/quant_utils.cuh", +] +include = [ "cuda-utils", "paged-attention" ] +depends = [ "torch" ] + +[kernel.paged_attention_rocm] +backend = "rocm" +rocm-archs = [ + "gfx906", + "gfx908", + "gfx90a", + "gfx940", + "gfx941", + "gfx942", + "gfx1030", + "gfx1100", + "gfx1101", +] +src = [ + "cuda-utils/cuda_utils.h", + "paged-attention/attention/attention_dtypes.h", + "paged-attention/attention/attention_generic.cuh", + "paged-attention/attention/attention_kernels.cuh", + "paged-attention/attention/attention_utils.cuh", + "paged-attention/attention/dtype_bfloat16.cuh", + "paged-attention/attention/dtype_float16.cuh", + 
"paged-attention/attention/dtype_float32.cuh", + "paged-attention/attention/dtype_fp8.cuh", + "paged-attention/attention/paged_attention_v1.cu", + "paged-attention/attention/paged_attention_v2.cu", + "paged-attention/cache_kernels.cu", + "paged-attention/cuda_compat.h", + "paged-attention/dispatch_utils.h", + "paged-attention/quantization/fp8/amd/quant_utils.cuh", + "paged-attention/quantization/fp8/nvidia/quant_utils.cuh", +] +include = [ "cuda-utils", "paged-attention" ] +depends = [ "torch" ] + +[kernel.paged_attention_metal] +backend = "metal" +src = [ + "paged-attention-metal/attention/paged_attention.metal", + "paged-attention-metal/cache/copy_blocks.metal", + "paged-attention-metal/cache/reshape_and_cache.metal", + "paged-attention-metal/convert_fp8.metal", + "paged-attention-metal/float8.metal", + "paged-attention-metal/utils.metal", + "paged-attention-metal/paged_attention.mm", + "paged-attention-metal/cache.mm", + "paged-attention-metal/convert_fp8.mm", + "paged-attention-metal/device.mm", +] +include = [ "." 
] +depends = [ "torch" ] diff --git a/build/torch210-cxx11-cu126-aarch64-linux/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_custom_ops.py b/build/torch210-cxx11-cu126-aarch64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + 
blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: 
torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_ops.py b/build/torch210-cxx11-cu126-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_cuda_83cf4a3 +ops = torch.ops._paged_attention_cuda_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch210-cxx11-cu126-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch210-cxx11-cu126-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..47b0655e6e7f3403e78a6766e76b088885bc031a --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33d5f8b98a2a171fee0e0106dfd9174438e40cbea4d13f0f53105a0c0d49695b +size 140013424 diff --git a/build/torch210-cxx11-cu126-aarch64-linux/metadata.json b/build/torch210-cxx11-cu126-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..f5902b55ab0b2b561c0cf97567c9806c60839c7f --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/metadata.json @@ -0,0 +1,18 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + 
"8.9", + "9.0+PTX" + ] + } +} diff --git a/build/torch210-cxx11-cu126-aarch64-linux/paged_attention/__init__.py b/build/torch210-cxx11-cu126-aarch64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu126-aarch64-linux/platforms.py b/build/torch210-cxx11-cu126-aarch64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch210-cxx11-cu126-aarch64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = 
torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ 
b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_custom_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + 
block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) 
+ + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_cuda_83cf4a3 +ops = torch.ops._paged_attention_cuda_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..4e5743eeb6fd4727e8010f46ef07bee65d3ec3be --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f84331fb1023844b101c03c8f12818bb3b09c273a9442b631cc2efe87b1eee2f +size 140162704 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/metadata.json b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..f5902b55ab0b2b561c0cf97567c9806c60839c7f --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1,18 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0+PTX" + ] + } +} diff --git a/build/torch210-cxx11-cu126-x86_64-linux/paged_attention/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- 
/dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu126-x86_64-linux/platforms.py b/build/torch210-cxx11-cu126-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. 
+ + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch210-cxx11-cu128-aarch64-linux/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + 
swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_custom_ops.py b/build/torch210-cxx11-cu128-aarch64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + 
blocksparse_local_blocks: int = 0,
+    blocksparse_vert_stride: int = 0,
+    blocksparse_block_size: int = 64,
+    blocksparse_head_sliding_step: int = 0,
+) -> None:
+    ops.paged_attention_v2(
+        out,
+        exp_sum,
+        max_logits,
+        tmp_out,
+        query,
+        key_cache,
+        value_cache,
+        num_kv_heads,
+        scale,
+        block_tables,
+        seq_lens,
+        block_size,
+        max_seq_len,
+        alibi_slopes,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+        tp_rank,
+        blocksparse_local_blocks,
+        blocksparse_vert_stride,
+        blocksparse_block_size,
+        blocksparse_head_sliding_step,
+    )
+
+
+def reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    """Pass every argument, unchanged and in order, to `ops.reshape_and_cache`."""
+    ops.reshape_and_cache(
+        key,
+        value,
+        key_cache,
+        value_cache,
+        slot_mapping,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+    )
+
+
+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+) -> None:
+    """Pass every argument, unchanged and in order, to `ops.reshape_and_cache_flash`."""
+    ops.reshape_and_cache_flash(
+        key,
+        value,
+        key_cache,
+        value_cache,
+        slot_mapping,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+    )
+
+
+def copy_blocks(
+    key_caches: List[torch.Tensor],
+    value_caches: List[torch.Tensor],
+    block_mapping: torch.Tensor,
+) -> None:
+    """Pass the cache lists and `block_mapping` straight to `ops.copy_blocks`."""
+    ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+def swap_blocks(
+    src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
+) -> None:
+    """Pass `src`, `dst` and `block_mapping` straight to `ops.swap_blocks`."""
+    ops.swap_blocks(src, dst, block_mapping)
+
+
+def convert_fp8(
+    output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
+) -> None:
+    """Pass the arguments straight to `ops.convert_fp8`; `scale` defaults to 1.0 and `kv_dtype` to "fp8"."""
+    ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+# Fixed: `reshape_and_cache_flash` and `swap_blocks` are defined in this module
+# but were missing from the export list.
+__all__ = [
+    "convert_fp8",
+    "paged_attention_v1",
+    "paged_attention_v2",
+    "reshape_and_cache",
+    "copy_blocks",
+    "reshape_and_cache_flash",
+    "swap_blocks",
+]
diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_ops.py
b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_cuda_83cf4a3 +ops = torch.ops._paged_attention_cuda_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch210-cxx11-cu128-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..0de6424a644af7d827892316d3375c80b14544d7 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6dd2622118e8d4a9e7d952da74ffdb90627c4bb7a76a3be349847427b43db1dd +size 167603936 diff --git a/build/torch210-cxx11-cu128-aarch64-linux/metadata.json b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8b796af185fbbd8594fcd846949aa5fadc0ccdda --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu128-aarch64-linux/paged_attention/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import 
sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-aarch64-linux/platforms.py b/build/torch210-cxx11-cu128-aarch64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. 
+
+        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+        """
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+
+    @abstractmethod
+    def get_device_name(self, device_id: int = 0) -> str: ...
+
+    @abstractmethod
+    def is_cuda(self) -> bool: ...
+
+    @abstractmethod
+    def is_rocm(self) -> bool: ...
+
+    @abstractmethod
+    def is_mps(self) -> bool: ...
+
+
+class CudaPlatform(Platform):
+    """Platform chosen by `current_platform` when CUDA is available and neither ROCm nor MPS is."""
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(cls, device_id: int = 0) -> str:
+        """Return `torch.cuda.get_device_name(device_id)`; results are memoised by `lru_cache`."""
+        # Fixed: this used to query device 0 unconditionally, so every
+        # `device_id` was cached with device 0's name.
+        return torch.cuda.get_device_name(device_id)
+
+    def is_cuda(self) -> bool:
+        return True
+
+    def is_rocm(self) -> bool:
+        return False
+
+    def is_mps(self) -> bool:
+        return False
+
+
+class RocmPlatform(Platform):
+    """Platform chosen by `current_platform` when `IS_ROCM` is true."""
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(cls, device_id: int = 0) -> str:
+        """Return `torch.cuda.get_device_name(device_id)`; results are memoised by `lru_cache`."""
+        return torch.cuda.get_device_name(device_id)
+
+    def is_cuda(self) -> bool:
+        return False
+
+    def is_rocm(self) -> bool:
+        return True
+
+    def is_mps(self) -> bool:
+        return False
+
+
+class MpsPlatform(Platform):
+    """Platform chosen by `current_platform` when `IS_MPS` is true and `IS_ROCM` is not."""
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(cls, device_id: int = 0) -> str:
+        """Return a fixed label for the MPS device; `device_id` is not used."""
+        # Fixed: this used to call `torch.cuda.get_device_name`, i.e. it asked
+        # the CUDA backend for a name on the MPS platform.
+        # NOTE(review): "mps" is a placeholder label -- confirm what callers expect.
+        return "mps"
+
+    def is_cuda(self) -> bool:
+        return False
+
+    def is_rocm(self) -> bool:
+        return False
+
+    def is_mps(self) -> bool:
+        return True
+
+# Backend priority: ROCm, then MPS, then CUDA; `None` when none of them applies.
+current_platform = (
+    RocmPlatform() if IS_ROCM else
+    MpsPlatform() if IS_MPS else
+    CudaPlatform() if torch.cuda.is_available() else
+    None
+)
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py
@@ -0,0 +1,21 @@
+from ._custom_ops import (
+    convert_fp8,
+    copy_blocks,
+    paged_attention_v1,
+    paged_attention_v2,
+    reshape_and_cache,
+    reshape_and_cache_flash,
+    swap_blocks,
+) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_custom_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int 
= 0,
+    blocksparse_vert_stride: int = 0,
+    blocksparse_block_size: int = 64,
+    blocksparse_head_sliding_step: int = 0,
+) -> None:
+    ops.paged_attention_v2(
+        out,
+        exp_sum,
+        max_logits,
+        tmp_out,
+        query,
+        key_cache,
+        value_cache,
+        num_kv_heads,
+        scale,
+        block_tables,
+        seq_lens,
+        block_size,
+        max_seq_len,
+        alibi_slopes,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+        tp_rank,
+        blocksparse_local_blocks,
+        blocksparse_vert_stride,
+        blocksparse_block_size,
+        blocksparse_head_sliding_step,
+    )
+
+
+def reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    """Pass every argument, unchanged and in order, to `ops.reshape_and_cache`."""
+    ops.reshape_and_cache(
+        key,
+        value,
+        key_cache,
+        value_cache,
+        slot_mapping,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+    )
+
+
+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+) -> None:
+    """Pass every argument, unchanged and in order, to `ops.reshape_and_cache_flash`."""
+    ops.reshape_and_cache_flash(
+        key,
+        value,
+        key_cache,
+        value_cache,
+        slot_mapping,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+    )
+
+
+def copy_blocks(
+    key_caches: List[torch.Tensor],
+    value_caches: List[torch.Tensor],
+    block_mapping: torch.Tensor,
+) -> None:
+    """Pass the cache lists and `block_mapping` straight to `ops.copy_blocks`."""
+    ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+def swap_blocks(
+    src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
+) -> None:
+    """Pass `src`, `dst` and `block_mapping` straight to `ops.swap_blocks`."""
+    ops.swap_blocks(src, dst, block_mapping)
+
+
+def convert_fp8(
+    output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
+) -> None:
+    """Pass the arguments straight to `ops.convert_fp8`; `scale` defaults to 1.0 and `kv_dtype` to "fp8"."""
+    ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+# Fixed: `reshape_and_cache_flash` and `swap_blocks` are defined in this module
+# but were missing from the export list.
+__all__ = [
+    "convert_fp8",
+    "paged_attention_v1",
+    "paged_attention_v2",
+    "reshape_and_cache",
+    "copy_blocks",
+    "reshape_and_cache_flash",
+    "swap_blocks",
+]
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py
new
file mode 100644 index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_cuda_83cf4a3 +ops = torch.ops._paged_attention_cuda_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..8f40d03d0f588863948c95624c7c4d11de369f13 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37c0783a4a3628ffc43d64b65090cb4fa8b2f5cc2fe913a51901378f518d11af +size 167726096 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8b796af185fbbd8594fcd846949aa5fadc0ccdda --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu128-x86_64-linux/paged_attention/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + 
+def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/platforms.py b/build/torch210-cxx11-cu128-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... 
+
+    @abstractmethod
+    def is_cuda(self) -> bool: ...
+
+    @abstractmethod
+    def is_rocm(self) -> bool: ...
+
+    @abstractmethod
+    def is_mps(self) -> bool: ...
+
+
+class CudaPlatform(Platform):
+    """Platform chosen by `current_platform` when CUDA is available and neither ROCm nor MPS is."""
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(cls, device_id: int = 0) -> str:
+        """Return `torch.cuda.get_device_name(device_id)`; results are memoised by `lru_cache`."""
+        # Fixed: this used to query device 0 unconditionally, so every
+        # `device_id` was cached with device 0's name.
+        return torch.cuda.get_device_name(device_id)
+
+    def is_cuda(self) -> bool:
+        return True
+
+    def is_rocm(self) -> bool:
+        return False
+
+    def is_mps(self) -> bool:
+        return False
+
+
+class RocmPlatform(Platform):
+    """Platform chosen by `current_platform` when `IS_ROCM` is true."""
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(cls, device_id: int = 0) -> str:
+        """Return `torch.cuda.get_device_name(device_id)`; results are memoised by `lru_cache`."""
+        return torch.cuda.get_device_name(device_id)
+
+    def is_cuda(self) -> bool:
+        return False
+
+    def is_rocm(self) -> bool:
+        return True
+
+    def is_mps(self) -> bool:
+        return False
+
+
+class MpsPlatform(Platform):
+    """Platform chosen by `current_platform` when `IS_MPS` is true and `IS_ROCM` is not."""
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(cls, device_id: int = 0) -> str:
+        """Return a fixed label for the MPS device; `device_id` is not used."""
+        # Fixed: this used to call `torch.cuda.get_device_name`, i.e. it asked
+        # the CUDA backend for a name on the MPS platform.
+        # NOTE(review): "mps" is a placeholder label -- confirm what callers expect.
+        return "mps"
+
+    def is_cuda(self) -> bool:
+        return False
+
+    def is_rocm(self) -> bool:
+        return False
+
+    def is_mps(self) -> bool:
+        return True
+
+# Backend priority: ROCm, then MPS, then CUDA; `None` when none of them applies.
+current_platform = (
+    RocmPlatform() if IS_ROCM else
+    MpsPlatform() if IS_MPS else
+    CudaPlatform() if torch.cuda.is_available() else
+    None
+)
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py
@@ -0,0 +1,21 @@
+from ._custom_ops import (
+    convert_fp8,
+    copy_blocks,
+    paged_attention_v1,
+    paged_attention_v2,
+    reshape_and_cache,
+    reshape_and_cache_flash,
+    swap_blocks,
+)
+from ._ops import ops
+
+__all__ = [
+    "convert_fp8",
+    "copy_blocks",
+    "ops",
+    "paged_attention_v1",
+    "paged_attention_v2",
+    "reshape_and_cache",
+    "reshape_and_cache_flash",
+    "swap_blocks",
+]
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_custom_ops.py
b/build/torch210-cxx11-cu130-aarch64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + 
block_tables,
+        seq_lens,
+        block_size,
+        max_seq_len,
+        alibi_slopes,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+        tp_rank,
+        blocksparse_local_blocks,
+        blocksparse_vert_stride,
+        blocksparse_block_size,
+        blocksparse_head_sliding_step,
+    )
+
+
+def reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    """Pass every argument, unchanged and in order, to `ops.reshape_and_cache`."""
+    ops.reshape_and_cache(
+        key,
+        value,
+        key_cache,
+        value_cache,
+        slot_mapping,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+    )
+
+
+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+) -> None:
+    """Pass every argument, unchanged and in order, to `ops.reshape_and_cache_flash`."""
+    ops.reshape_and_cache_flash(
+        key,
+        value,
+        key_cache,
+        value_cache,
+        slot_mapping,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+    )
+
+
+def copy_blocks(
+    key_caches: List[torch.Tensor],
+    value_caches: List[torch.Tensor],
+    block_mapping: torch.Tensor,
+) -> None:
+    """Pass the cache lists and `block_mapping` straight to `ops.copy_blocks`."""
+    ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+def swap_blocks(
+    src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
+) -> None:
+    """Pass `src`, `dst` and `block_mapping` straight to `ops.swap_blocks`."""
+    ops.swap_blocks(src, dst, block_mapping)
+
+
+def convert_fp8(
+    output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
+) -> None:
+    """Pass the arguments straight to `ops.convert_fp8`; `scale` defaults to 1.0 and `kv_dtype` to "fp8"."""
+    ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+# Fixed: `reshape_and_cache_flash` and `swap_blocks` are defined in this module
+# but were missing from the export list.
+__all__ = [
+    "convert_fp8",
+    "paged_attention_v1",
+    "paged_attention_v2",
+    "reshape_and_cache",
+    "copy_blocks",
+    "reshape_and_cache_flash",
+    "swap_blocks",
+]
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_ops.py b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . 
import _paged_attention_cuda_83cf4a3 +ops = torch.ops._paged_attention_cuda_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch210-cxx11-cu130-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..673701a2eb963f3cda151219afa807616ba73ab0 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31b7d92afaaffa6d335dad007ca97f76c66a5470e6a380e03a93fca6ff2232dc +size 86068816 diff --git a/build/torch210-cxx11-cu130-aarch64-linux/metadata.json b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..66651b7d3f95ac9e5ce5fc2a641b6f0f50788f87 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,19 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "11.0", + "12.0+PTX", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu130-aarch64-linux/paged_attention/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. 
So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-aarch64-linux/platforms.py b/build/torch210-cxx11-cu130-aarch64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+
+
+class CudaPlatform(Platform):
+    """Platform chosen by `current_platform` when CUDA is available and neither ROCm nor MPS is."""
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(cls, device_id: int = 0) -> str:
+        """Return `torch.cuda.get_device_name(device_id)`; results are memoised by `lru_cache`."""
+        # Fixed: this used to query device 0 unconditionally, so every
+        # `device_id` was cached with device 0's name.
+        return torch.cuda.get_device_name(device_id)
+
+    def is_cuda(self) -> bool:
+        return True
+
+    def is_rocm(self) -> bool:
+        return False
+
+    def is_mps(self) -> bool:
+        return False
+
+
+class RocmPlatform(Platform):
+    """Platform chosen by `current_platform` when `IS_ROCM` is true."""
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(cls, device_id: int = 0) -> str:
+        """Return `torch.cuda.get_device_name(device_id)`; results are memoised by `lru_cache`."""
+        return torch.cuda.get_device_name(device_id)
+
+    def is_cuda(self) -> bool:
+        return False
+
+    def is_rocm(self) -> bool:
+        return True
+
+    def is_mps(self) -> bool:
+        return False
+
+
+class MpsPlatform(Platform):
+    """Platform chosen by `current_platform` when `IS_MPS` is true and `IS_ROCM` is not."""
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    def get_device_name(cls, device_id: int = 0) -> str:
+        """Return a fixed label for the MPS device; `device_id` is not used."""
+        # Fixed: this used to call `torch.cuda.get_device_name`, i.e. it asked
+        # the CUDA backend for a name on the MPS platform.
+        # NOTE(review): "mps" is a placeholder label -- confirm what callers expect.
+        return "mps"
+
+    def is_cuda(self) -> bool:
+        return False
+
+    def is_rocm(self) -> bool:
+        return False
+
+    def is_mps(self) -> bool:
+        return True
+
+# Backend priority: ROCm, then MPS, then CUDA; `None` when none of them applies.
+current_platform = (
+    RocmPlatform() if IS_ROCM else
+    MpsPlatform() if IS_MPS else
+    CudaPlatform() if torch.cuda.is_available() else
+    None
+)
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py
@@ -0,0 +1,21 @@
+from ._custom_ops import (
+    convert_fp8,
+    copy_blocks,
+    paged_attention_v1,
+    paged_attention_v2,
+    reshape_and_cache,
+    reshape_and_cache_flash,
+    swap_blocks,
+)
+from ._ops import ops
+
+__all__ = [
+    "convert_fp8",
+    "copy_blocks",
+    "ops",
+    "paged_attention_v1",
+    "paged_attention_v2",
+    "reshape_and_cache",
+    "reshape_and_cache_flash",
+    "swap_blocks",
+]
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_custom_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_custom_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2
--- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + 
blocksparse_vert_stride,
+        blocksparse_block_size,
+        blocksparse_head_sliding_step,
+    )
+
+
+def reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    """Pass every argument, unchanged and in order, to `ops.reshape_and_cache`."""
+    ops.reshape_and_cache(
+        key,
+        value,
+        key_cache,
+        value_cache,
+        slot_mapping,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+    )
+
+
+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+) -> None:
+    """Pass every argument, unchanged and in order, to `ops.reshape_and_cache_flash`."""
+    ops.reshape_and_cache_flash(
+        key,
+        value,
+        key_cache,
+        value_cache,
+        slot_mapping,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+    )
+
+
+def copy_blocks(
+    key_caches: List[torch.Tensor],
+    value_caches: List[torch.Tensor],
+    block_mapping: torch.Tensor,
+) -> None:
+    """Pass the cache lists and `block_mapping` straight to `ops.copy_blocks`."""
+    ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+def swap_blocks(
+    src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
+) -> None:
+    """Pass `src`, `dst` and `block_mapping` straight to `ops.swap_blocks`."""
+    ops.swap_blocks(src, dst, block_mapping)
+
+
+def convert_fp8(
+    output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
+) -> None:
+    """Pass the arguments straight to `ops.convert_fp8`; `scale` defaults to 1.0 and `kv_dtype` to "fp8"."""
+    ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+# Fixed: `reshape_and_cache_flash` and `swap_blocks` are defined in this module
+# but were missing from the export list.
+__all__ = [
+    "convert_fp8",
+    "paged_attention_v1",
+    "paged_attention_v2",
+    "reshape_and_cache",
+    "copy_blocks",
+    "reshape_and_cache_flash",
+    "swap_blocks",
+]
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _paged_attention_cuda_83cf4a3
+ops = torch.ops._paged_attention_cuda_83cf4a3
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+ """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..dde5933b230b38b341af6c607fb6f46b0d7a362d --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7fc05b440e24ece432bd009e23dbf721d191d03cfa3d020c2d52d3eaface9992 +size 86563792 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..66651b7d3f95ac9e5ce5fc2a641b6f0f50788f87 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,19 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "11.0", + "12.0+PTX", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch210-cxx11-cu130-x86_64-linux/paged_attention/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/platforms.py b/build/torch210-cxx11-cu130-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py b/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_custom_ops.py b/build/torch210-cxx11-rocm70-x86_64-linux/_custom_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + 
k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c1268908ac846f37bef7874170ff1f06c3eb33c9 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . 
import _paged_attention_rocm_83cf4a3 +ops = torch.ops._paged_attention_rocm_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_rocm_83cf4a3::{op_name}" diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..9f143adeaeb851f7d7146e44477410ceb11f17f6 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c715078de15626c6dc53b2bb321828478a33952ed5bac5e6f5730a984445b321 +size 58992416 diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json b/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..3e8d811f1dc42febd33121b2627f809447622baf --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/metadata.json @@ -0,0 +1,17 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "rocm", + "archs": [ + "gfx1030", + "gfx1100", + "gfx1101", + "gfx906", + "gfx908", + "gfx90a", + "gfx942" + ] + } +} diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/paged_attention/__init__.py b/build/torch210-cxx11-rocm70-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. 
So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/platforms.py b/build/torch210-cxx11-rocm70-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py b/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_custom_ops.py b/build/torch210-cxx11-rocm71-x86_64-linux/_custom_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + 
k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c1268908ac846f37bef7874170ff1f06c3eb33c9 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . 
import _paged_attention_rocm_83cf4a3 +ops = torch.ops._paged_attention_rocm_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_rocm_83cf4a3::{op_name}" diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..f44ee883d362fc9df58844496faa3cc079670ddb --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e531858b1c7c996812b84a6801d240f9fa650ccc078f74b706d72414bae7965 +size 58971840 diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json b/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..3e8d811f1dc42febd33121b2627f809447622baf --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/metadata.json @@ -0,0 +1,17 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "rocm", + "archs": [ + "gfx1030", + "gfx1100", + "gfx1101", + "gfx906", + "gfx908", + "gfx90a", + "gfx942" + ] + } +} diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/paged_attention/__init__.py b/build/torch210-cxx11-rocm71-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. 
So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/platforms.py b/build/torch210-cxx11-rocm71-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch210-metal-aarch64-darwin/__init__.py b/build/torch210-metal-aarch64-darwin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch210-metal-aarch64-darwin/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch210-metal-aarch64-darwin/_custom_ops.py b/build/torch210-metal-aarch64-darwin/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ 
b/build/torch210-metal-aarch64-darwin/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + 
blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch210-metal-aarch64-darwin/_ops.py b/build/torch210-metal-aarch64-darwin/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b70aaa4f8859a5ed30a12fdf876196d99bea499a --- /dev/null +++ b/build/torch210-metal-aarch64-darwin/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_metal_c6404f1 +ops = torch.ops._paged_attention_metal_c6404f1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_metal_c6404f1::{op_name}" diff --git a/build/torch210-metal-aarch64-darwin/_paged_attention_metal_c6404f1.abi3.so b/build/torch210-metal-aarch64-darwin/_paged_attention_metal_c6404f1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..db144491c787d08536b4636ad13376ea08f4848a --- /dev/null +++ b/build/torch210-metal-aarch64-darwin/_paged_attention_metal_c6404f1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63b99454be1d50260beb973bd80b5fb2af61ac288505840e164461ce9239b0ed +size 14893560 diff --git a/build/torch210-metal-aarch64-darwin/metadata.json b/build/torch210-metal-aarch64-darwin/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a5381dd80836f863378b9f33a559815688de9287 --- /dev/null +++ b/build/torch210-metal-aarch64-darwin/metadata.json @@ -0,0 +1,5 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch210-metal-aarch64-darwin/paged_attention/__init__.py b/build/torch210-metal-aarch64-darwin/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch210-metal-aarch64-darwin/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-metal-aarch64-darwin/platforms.py b/build/torch210-metal-aarch64-darwin/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch210-metal-aarch64-darwin/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch211-cxx11-cu126-aarch64-linux/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_custom_ops.py b/build/torch211-cxx11-cu126-aarch64-linux/_custom_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + 
k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_ops.py b/build/torch211-cxx11-cu126-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . 
import _paged_attention_cuda_83cf4a3 +ops = torch.ops._paged_attention_cuda_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch211-cxx11-cu126-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch211-cxx11-cu126-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..b9fa2e021fd0d633cc5bacf298696363cd6662c0 --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a07d7840081c430ba902d82c3cd1f2c6164022003a096c97f57b6559121d53e +size 140009304 diff --git a/build/torch211-cxx11-cu126-aarch64-linux/metadata.json b/build/torch211-cxx11-cu126-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..f5902b55ab0b2b561c0cf97567c9806c60839c7f --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/metadata.json @@ -0,0 +1,18 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0+PTX" + ] + } +} diff --git a/build/torch211-cxx11-cu126-aarch64-linux/paged_attention/__init__.py b/build/torch211-cxx11-cu126-aarch64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. 
So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu126-aarch64-linux/platforms.py b/build/torch211-cxx11-cu126-aarch64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch211-cxx11-cu126-aarch64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch211-cxx11-cu126-x86_64-linux/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_custom_ops.py b/build/torch211-cxx11-cu126-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 
--- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + 
blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_ops.py b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_cuda_83cf4a3 +ops = torch.ops._paged_attention_cuda_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch211-cxx11-cu126-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch211-cxx11-cu126-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..b229c2d2dbaf0ef18a2642c5b50ae146274537b2 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b092432ec63eed6ffba55b648a0d484c3dfe58b6b712cea9dfb527dc38066dc +size 140147224 diff --git a/build/torch211-cxx11-cu126-x86_64-linux/metadata.json b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..f5902b55ab0b2b561c0cf97567c9806c60839c7f --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1,18 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0+PTX" + ] + } +} diff --git a/build/torch211-cxx11-cu126-x86_64-linux/paged_attention/__init__.py b/build/torch211-cxx11-cu126-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu126-x86_64-linux/platforms.py b/build/torch211-cxx11-cu126-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch211-cxx11-cu126-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch211-cxx11-cu128-aarch64-linux/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_custom_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/_custom_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + 
k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . 
import _paged_attention_cuda_83cf4a3 +ops = torch.ops._paged_attention_cuda_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch211-cxx11-cu128-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..ebd8737ef56102010043dee9d4ba76835505df26 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa20b2f6c5490a230b408034a1f6a93ace0208efdb412366ab1b106c49a1e14c +size 167599784 diff --git a/build/torch211-cxx11-cu128-aarch64-linux/metadata.json b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8b796af185fbbd8594fcd846949aa5fadc0ccdda --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu128-aarch64-linux/paged_attention/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. 
So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu128-aarch64-linux/platforms.py b/build/torch211-cxx11-cu128-aarch64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_custom_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 
--- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + 
blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_cuda_83cf4a3 +ops = torch.ops._paged_attention_cuda_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch211-cxx11-cu128-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..4c85ca28c445e3c99c6ae70566284805bdaf39c1 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:669f48c97e2970d820919e972762bf0cc64bc657a53a6b7452de5dc9fd1e4844 +size 167710616 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/metadata.json b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8b796af185fbbd8594fcd846949aa5fadc0ccdda --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu128-x86_64-linux/paged_attention/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/platforms.py b/build/torch211-cxx11-cu128-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch211-cxx11-cu130-aarch64-linux/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_custom_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/_custom_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + 
k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . 
import _paged_attention_cuda_83cf4a3 +ops = torch.ops._paged_attention_cuda_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch211-cxx11-cu130-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..4d0421362ad75a7eb452b49d8b94862019bfc0a9 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:82512281eb292163d884cee057ad3bd32dc3ea4eebb676acbb964107d5ad1c5e +size 86064664 diff --git a/build/torch211-cxx11-cu130-aarch64-linux/metadata.json b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..66651b7d3f95ac9e5ce5fc2a641b6f0f50788f87 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,19 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "11.0", + "12.0+PTX", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu130-aarch64-linux/paged_attention/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. 
So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu130-aarch64-linux/platforms.py b/build/torch211-cxx11-cu130-aarch64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_custom_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 
--- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + 
blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_cuda_83cf4a3 +ops = torch.ops._paged_attention_cuda_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch211-cxx11-cu130-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..86312eb60e5b05ca1a676bfd68ad28abc31489c7 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df2b4931f36d5314b23cf3dd3a718665c9c84ad055ca2d844adb8e34be4aac24 +size 86548312 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/metadata.json b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..66651b7d3f95ac9e5ce5fc2a641b6f0f50788f87 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,19 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "11.0", + "12.0+PTX", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch211-cxx11-cu130-x86_64-linux/paged_attention/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/platforms.py b/build/torch211-cxx11-cu130-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc 
b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a42f6c23989d05e0c4180a8e07ce104a7ecd99c Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c698ce4a5f2b3db77ed9861efc143265073f6a8 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4cbb5b24a571485fb2d7712b46631a1eedf0b41 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/_custom_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + 
k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: 
torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a883d5cb5d351eb18b40e1d4b621a7d1544385ab --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_0041e3f +ops = torch.ops._paged_attention_0041e3f + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_0041e3f::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/_paged_attention_0041e3f.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/_paged_attention_0041e3f.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..71626cdbf291f8b39d80a6b511439c2148f25112 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/_paged_attention_0041e3f.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b899f376425b8d7a213b8d26909d84ffb1213e03d3e5b33675e9408426747501 +size 113844912 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/platforms.py b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/paged_attention/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc 
b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..428ec0cbce6fbd264193a66297408d47d9f9882b Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4257fb3f872d1ec4e346d6f3c5ec2046ab82f36 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..987598889f8379c07efa94bad6a3b6f17dec9860 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/_custom_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + 
k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: 
torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a883d5cb5d351eb18b40e1d4b621a7d1544385ab --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_0041e3f +ops = torch.ops._paged_attention_0041e3f + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_0041e3f::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/_paged_attention_0041e3f.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/_paged_attention_0041e3f.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d68462effa6cb3bc865b2cdca6137261d51c2edc --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/_paged_attention_0041e3f.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f98a11d55d75ee841515253dfd16b5dd17a811925a445ef765b90b4b56a10e35 +size 110732296 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/platforms.py b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/paged_attention/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/__init__.py b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc 
b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4426d89b40d2875c6242745d3a3404dc105dc340 Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76058206a88f8f66740efd7b611f125a5f3e1eee Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0056674c98c61fb9850920022abc1bbebebdb2e2 Binary files /dev/null and b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/_custom_ops.py b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + 
kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: 
torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/_ops.py b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e9f56ff02bdcf047fd97ef2a46b10d6fb5e5eb --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_b4c51e9 +ops = torch.ops._paged_attention_b4c51e9 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_b4c51e9::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/_paged_attention_b4c51e9.abi3.so b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/_paged_attention_b4c51e9.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..31455c6c5a159469575f362d673d272a341e4b19 --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/_paged_attention_b4c51e9.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9cba9e7b3cf9de722c5dc8f56a533542521fdef05a290a2b4db13948b0f1eca4 +size 138172880 diff --git a/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/platforms.py b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/paged_attention/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc 
b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c13b4dd0a5c62749c4fffaed83061e327465cfa Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb10ad97588c556a857e62d03549f1c72efeced9 Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbe695e89609fad8076312195af5a04690a744ee Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/_custom_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + 
k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: 
torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a883d5cb5d351eb18b40e1d4b621a7d1544385ab --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_0041e3f +ops = torch.ops._paged_attention_0041e3f + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_0041e3f::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/_paged_attention_0041e3f.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/_paged_attention_0041e3f.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e2fbd54b5c6e0379de056949d3d7991b31b546e6 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/_paged_attention_0041e3f.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15f5767adca82992a20bb3b0ff047e5cd076c386f1bdff70f63c3052c47e1ac3 +size 138291040 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/platforms.py b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/paged_attention/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc 
b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a487f6dd62db7a85db734773b9718a43478a689 Binary files /dev/null and b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8ae1d42b0f5f3fb681398076bef151d34cc9325 Binary files /dev/null and b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a4ea49b96acc139399eedbd529c4af03da4944b Binary files /dev/null and b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/_custom_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + 
kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: 
torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a883d5cb5d351eb18b40e1d4b621a7d1544385ab --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_0041e3f +ops = torch.ops._paged_attention_0041e3f + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_0041e3f::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/_paged_attention_0041e3f.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/_paged_attention_0041e3f.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..f8326ce54482175c1763ff12b7789f71c61dc9ca --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/_paged_attention_0041e3f.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5955ede80b62f77cbcf76ffc3961b20af2d5a17ceb95c5c2fc06e7b12e9d3bec +size 120178304 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/platforms.py b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/paged_attention/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) # honour device_id instead of always reporting device 0 + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) # NOTE(review): MPS platform queries the CUDA device-name API - confirm this works on Metal-only builds + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch27-metal-aarch64-darwin/paged_attention/__init__.py b/build/torch27-metal-aarch64-darwin/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch27-metal-aarch64-darwin/paged_attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch27-metal-aarch64-darwin/paged_attention/_custom_ops.py b/build/torch27-metal-aarch64-darwin/paged_attention/_custom_ops.py new file mode 100644 index 
0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch27-metal-aarch64-darwin/paged_attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + 
kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch27-metal-aarch64-darwin/paged_attention/_ops.py b/build/torch27-metal-aarch64-darwin/paged_attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e6a82eda58b197517abf459ac6adaefaf817431b --- /dev/null +++ b/build/torch27-metal-aarch64-darwin/paged_attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . 
import _paged_attention_9678b89 +ops = torch.ops._paged_attention_9678b89 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_9678b89::{op_name}" \ No newline at end of file diff --git a/build/torch27-metal-aarch64-darwin/paged_attention/_paged_attention_9678b89.abi3.so b/build/torch27-metal-aarch64-darwin/paged_attention/_paged_attention_9678b89.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..d6c43e8f7b7b269de282330ba9c877243937a713 --- /dev/null +++ b/build/torch27-metal-aarch64-darwin/paged_attention/_paged_attention_9678b89.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a94cee9e553d2bdf8d47d0d9461c871b3e57a33cf6cb259807377f0d1b03c7d +size 214800 diff --git a/build/torch27-metal-aarch64-darwin/paged_attention/_paged_attention_9678b89.metallib b/build/torch27-metal-aarch64-darwin/paged_attention/_paged_attention_9678b89.metallib new file mode 100644 index 0000000000000000000000000000000000000000..076fc55c592b25de05c25232c560b617986561cc --- /dev/null +++ b/build/torch27-metal-aarch64-darwin/paged_attention/_paged_attention_9678b89.metallib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c46eaf21c96da70c5227b2566308a8ef73ae09abf303278f40070dd4326ba0be +size 4999876 diff --git a/build/torch27-metal-aarch64-darwin/paged_attention/platforms.py b/build/torch27-metal-aarch64-darwin/paged_attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch27-metal-aarch64-darwin/paged_attention/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + 
@classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) # honour device_id instead of always reporting device 0 + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) # NOTE(review): MPS platform queries the CUDA device-name API - confirm this works on Metal-only builds + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from 
._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_custom_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: 
int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + 
"reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..db595747ff801b0cf8021fcbb30693389cbebb94 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_7c06fa7 +ops = torch.ops._paged_attention_7c06fa7 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_7c06fa7::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_paged_attention_7c06fa7.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_paged_attention_7c06fa7.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..d846ed81c7f6d3218dee88b09cf9a283a304b8f6 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_paged_attention_7c06fa7.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3e8736d3b0efd5435d028d3f4429ebcc8fb95bf9dcaa3d7b191ebcea5ad20c2 +size 140143504 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/metadata.json b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/paged_attention/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def 
_import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/platforms.py b/build/torch28-cxx11-cu126-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... 
+ + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... + + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) # honour device_id instead of always reporting device 0 + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) # NOTE(review): MPS platform queries the CUDA device-name API - confirm this works on Metal-only builds + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_custom_ops.py 
b/build/torch28-cxx11-cu128-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + 
block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..db595747ff801b0cf8021fcbb30693389cbebb94 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . 
import _paged_attention_7c06fa7 +ops = torch.ops._paged_attention_7c06fa7 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_paged_attention_7c06fa7::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_paged_attention_7c06fa7.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_paged_attention_7c06fa7.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..ab267a807aba88bf8eed5bc0adbec038e4b8c42e --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_paged_attention_7c06fa7.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba3a28b63cf7a420d5f1622beec6b8d916254484f8e8d1629f33ee2c5ae4a7a0 +size 167710768 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/metadata.json b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/paged_attention/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/platforms.py b/build/torch28-cxx11-cu128-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/__init__.py b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc 
b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2029d1fa1af1cacf90cbe0e8a2a59b907940e8d Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f54c4683b6c5a5d05d435d40f546ac8cbf62e76 Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/__pycache__/_custom_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..095bc0f18dddbaa40b04468fa6cb9dde32b9285a Binary files /dev/null and b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/_custom_ops.py b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + 
kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: 
torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/_ops.py b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e9f56ff02bdcf047fd97ef2a46b10d6fb5e5eb --- /dev/null +++ b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_b4c51e9 +ops = torch.ops._paged_attention_b4c51e9 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_b4c51e9::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/_paged_attention_b4c51e9.abi3.so b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/_paged_attention_b4c51e9.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..f4f21766fa2c547675a1373926d281b4ec7a28fa --- /dev/null +++ b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/_paged_attention_b4c51e9.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02238a0a4dacdbbf60eb9bb73a95332b448d5127f71149003c764f40595d9d06 +size 149841048 diff --git a/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/platforms.py b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch28-cxx11-cu129-aarch64-linux/paged_attention/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_custom_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- 
/dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + 
blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..db595747ff801b0cf8021fcbb30693389cbebb94 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_7c06fa7 +ops = torch.ops._paged_attention_7c06fa7 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_7c06fa7::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_paged_attention_7c06fa7.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_paged_attention_7c06fa7.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..f7fbade89d1c107240d641411d712cf7829dbbe8 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_paged_attention_7c06fa7.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b20a5cd5515e0a521c241bd4c7d659b82e6393217ebb86b7e8360d8b19949994 +size 182084600 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/metadata.json b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/paged_attention/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/platforms.py b/build/torch28-cxx11-cu129-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_custom_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 
--- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + 
blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..db595747ff801b0cf8021fcbb30693389cbebb94 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_7c06fa7 +ops = torch.ops._paged_attention_7c06fa7 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_7c06fa7::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_paged_attention_7c06fa7.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_paged_attention_7c06fa7.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..0b4395723360f06cb1c5968a4c4818ab6085a2f5 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/_paged_attention_7c06fa7.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e61cf1b792a2479d48a2c886b30abf364efa435ced959dbfca03a619dbdd47e6 +size 120192848 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json b/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/paged_attention/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/platforms.py b/build/torch28-cxx11-rocm63-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_custom_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 
--- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + 
blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..db595747ff801b0cf8021fcbb30693389cbebb94 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_7c06fa7 +ops = torch.ops._paged_attention_7c06fa7 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_7c06fa7::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_paged_attention_7c06fa7.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_paged_attention_7c06fa7.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..8a198fa60c375cd9a63533504b065f74bd6667c9 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_paged_attention_7c06fa7.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39d897931a7f06c168722ff3ec836eb923d528f2648cf801e55473bc726a0e5b +size 121030544 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json b/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/paged_attention/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/platforms.py b/build/torch28-cxx11-rocm64-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch28-metal-aarch64-darwin/__init__.py b/build/torch28-metal-aarch64-darwin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch28-metal-aarch64-darwin/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch28-metal-aarch64-darwin/_custom_ops.py b/build/torch28-metal-aarch64-darwin/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- /dev/null +++ 
b/build/torch28-metal-aarch64-darwin/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + 
blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch28-metal-aarch64-darwin/_ops.py b/build/torch28-metal-aarch64-darwin/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..db595747ff801b0cf8021fcbb30693389cbebb94 --- /dev/null +++ b/build/torch28-metal-aarch64-darwin/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_7c06fa7 +ops = torch.ops._paged_attention_7c06fa7 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_7c06fa7::{op_name}" \ No newline at end of file diff --git a/build/torch28-metal-aarch64-darwin/_paged_attention_7c06fa7.abi3.so b/build/torch28-metal-aarch64-darwin/_paged_attention_7c06fa7.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..410ba8a85bdcc005471f889cbb19cf47e9ab3270 --- /dev/null +++ b/build/torch28-metal-aarch64-darwin/_paged_attention_7c06fa7.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0626edd9c58bef7c74b32cbdb73219d4e49fb650a0098726c4856dab33888306 +size 14892816 diff --git a/build/torch28-metal-aarch64-darwin/metadata.json b/build/torch28-metal-aarch64-darwin/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..9cf5deed9898dce769f4cc73913d3530b92a0bd8 --- /dev/null +++ b/build/torch28-metal-aarch64-darwin/metadata.json @@ -0,0 +1,4 @@ +{ + "version": 1, + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch28-metal-aarch64-darwin/paged_attention/__init__.py b/build/torch28-metal-aarch64-darwin/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch28-metal-aarch64-darwin/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch28-metal-aarch64-darwin/platforms.py b/build/torch28-metal-aarch64-darwin/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch28-metal-aarch64-darwin/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_custom_ops.py b/build/torch29-cxx11-cu126-aarch64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 
--- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + 
blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_ops.py b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5415b95a504475517ce0b451729d8edaaf24042a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_cuda_c6404f1 +ops = torch.ops._paged_attention_cuda_c6404f1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_cuda_c6404f1::{op_name}" diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_paged_attention_cuda_c6404f1.abi3.so b/build/torch29-cxx11-cu126-aarch64-linux/_paged_attention_cuda_c6404f1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..1760b475fbbb0d8a67982aaccbd8d83fc598157c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/_paged_attention_cuda_c6404f1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:beaa2bb81fd77f1a1bd29447e875044d741aa5069d88beff50234a45a0dee32e +size 140009952 diff --git a/build/torch29-cxx11-cu126-aarch64-linux/metadata.json b/build/torch29-cxx11-cu126-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..f5902b55ab0b2b561c0cf97567c9806c60839c7f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/metadata.json @@ -0,0 +1,18 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0+PTX" + ] + } +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/paged_attention/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/platforms.py b/build/torch29-cxx11-cu126-aarch64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_custom_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- 
/dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + 
# Names exported by ``from ._custom_ops import *``.
# Every public wrapper defined in this module is listed; the previous list left
# out ``reshape_and_cache_flash`` and ``swap_blocks`` even though the package
# ``__init__`` re-exports both.
__all__ = [
    "convert_fp8",
    "copy_blocks",
    "paged_attention_v1",
    "paged_attention_v2",
    "reshape_and_cache",
    "reshape_and_cache_flash",
    "swap_blocks",
]
def _import_from_path(file_path: Path) -> ModuleType:
    """Load the Python source file at ``file_path`` as a uniquely named module.

    The module is registered in ``sys.modules`` under a name derived from the
    hash of the absolute path, executed, and returned.

    Raises:
        ImportError: if no import spec (or no loader) can be built for the path.
        Exception: whatever the module's own top-level code raises; in that
            case the partially initialised module is removed from
            ``sys.modules`` again before the exception propagates.
    """
    # Function-scope import: the file only does ``import importlib``, which
    # does not guarantee that the ``importlib.util`` submodule is loaded.
    import importlib.util

    # The file's own name cannot be used as the module name: once it is in
    # `sys.modules` it would also satisfy unrelated imports of that name. Use
    # the hex-encoded hash of the absolute path instead so the name is unique.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # `spec.loader` is Optional; check it here rather than failing later with
    # an AttributeError on `None.exec_module`.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    # `module_from_spec` never returns None, so no further check is needed.
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can resolve imports of itself
    # (e.g. relative imports when `file_path` is a package `__init__.py`).
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind in `sys.modules`.
        sys.modules.pop(module_name, None)
        raise
    return module
class CudaPlatform(Platform):
    """``Platform`` implementation that reports itself as CUDA."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name torch reports for CUDA device ``device_id``.

        Results are memoised per ``(cls, device_id)`` by ``lru_cache``.
        """
        # Forward the requested index. The previous hard-coded 0 made every
        # `device_id` resolve to -- and be cached as -- device 0's name.
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return True

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return False
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward all arguments, unchanged and in order, to ``ops.paged_attention_v1``.

    The wrapper performs no validation or conversion of its own and always
    returns ``None``. What the compiled op does with ``out`` is not visible
    from this module (presumably it is filled in place -- confirm against the
    kernel implementation).
    """
    # Positional order must match the registered op's schema exactly.
    forwarded = (
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
    ops.paged_attention_v1(*forwarded)
# Names exported by ``from ._custom_ops import *``.
# Every public wrapper defined in this module is listed; the previous list left
# out ``reshape_and_cache_flash`` and ``swap_blocks`` even though the package
# ``__init__`` re-exports both.
__all__ = [
    "convert_fp8",
    "copy_blocks",
    "paged_attention_v1",
    "paged_attention_v2",
    "reshape_and_cache",
    "reshape_and_cache_flash",
    "swap_blocks",
]
def _import_from_path(file_path: Path) -> ModuleType:
    """Load the Python source file at ``file_path`` as a uniquely named module.

    The module is registered in ``sys.modules`` under a name derived from the
    hash of the absolute path, executed, and returned.

    Raises:
        ImportError: if no import spec (or no loader) can be built for the path.
        Exception: whatever the module's own top-level code raises; in that
            case the partially initialised module is removed from
            ``sys.modules`` again before the exception propagates.
    """
    # Function-scope import: the file only does ``import importlib``, which
    # does not guarantee that the ``importlib.util`` submodule is loaded.
    import importlib.util

    # The file's own name cannot be used as the module name: once it is in
    # `sys.modules` it would also satisfy unrelated imports of that name. Use
    # the hex-encoded hash of the absolute path instead so the name is unique.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # `spec.loader` is Optional; check it here rather than failing later with
    # an AttributeError on `None.exec_module`.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    # `module_from_spec` never returns None, so no further check is needed.
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can resolve imports of itself
    # (e.g. relative imports when `file_path` is a package `__init__.py`).
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind in `sys.modules`.
        sys.modules.pop(module_name, None)
        raise
    return module
class CudaPlatform(Platform):
    """``Platform`` implementation that reports itself as CUDA."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name torch reports for CUDA device ``device_id``.

        Results are memoised per ``(cls, device_id)`` by ``lru_cache``.
        """
        # Forward the requested index. The previous hard-coded 0 made every
        # `device_id` resolve to -- and be cached as -- device 0's name.
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return True

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return False
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward all arguments, unchanged and in order, to ``ops.paged_attention_v1``.

    The wrapper performs no validation or conversion of its own and always
    returns ``None``. What the compiled op does with ``out`` is not visible
    from this module (presumably it is filled in place -- confirm against the
    kernel implementation).
    """
    # Positional order must match the registered op's schema exactly.
    forwarded = (
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
    ops.paged_attention_v1(*forwarded)
# Names exported by ``from ._custom_ops import *``.
# Every public wrapper defined in this module is listed; the previous list left
# out ``reshape_and_cache_flash`` and ``swap_blocks`` even though the package
# ``__init__`` re-exports both.
__all__ = [
    "convert_fp8",
    "copy_blocks",
    "paged_attention_v1",
    "paged_attention_v2",
    "reshape_and_cache",
    "reshape_and_cache_flash",
    "swap_blocks",
]
def _import_from_path(file_path: Path) -> ModuleType:
    """Load the Python source file at ``file_path`` as a uniquely named module.

    The module is registered in ``sys.modules`` under a name derived from the
    hash of the absolute path, executed, and returned.

    Raises:
        ImportError: if no import spec (or no loader) can be built for the path.
        Exception: whatever the module's own top-level code raises; in that
            case the partially initialised module is removed from
            ``sys.modules`` again before the exception propagates.
    """
    # Function-scope import: the file only does ``import importlib``, which
    # does not guarantee that the ``importlib.util`` submodule is loaded.
    import importlib.util

    # The file's own name cannot be used as the module name: once it is in
    # `sys.modules` it would also satisfy unrelated imports of that name. Use
    # the hex-encoded hash of the absolute path instead so the name is unique.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # `spec.loader` is Optional; check it here rather than failing later with
    # an AttributeError on `None.exec_module`.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    # `module_from_spec` never returns None, so no further check is needed.
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can resolve imports of itself
    # (e.g. relative imports when `file_path` is a package `__init__.py`).
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind in `sys.modules`.
        sys.modules.pop(module_name, None)
        raise
    return module
class CudaPlatform(Platform):
    """``Platform`` implementation that reports itself as CUDA."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name torch reports for CUDA device ``device_id``.

        Results are memoised per ``(cls, device_id)`` by ``lru_cache``.
        """
        # Forward the requested index. The previous hard-coded 0 made every
        # `device_id` resolve to -- and be cached as -- device 0's name.
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return True

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return False
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward all arguments, unchanged and in order, to ``ops.paged_attention_v1``.

    The wrapper performs no validation or conversion of its own and always
    returns ``None``. What the compiled op does with ``out`` is not visible
    from this module (presumably it is filled in place -- confirm against the
    kernel implementation).
    """
    # Positional order must match the registered op's schema exactly.
    forwarded = (
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
    ops.paged_attention_v1(*forwarded)
# Names exported by ``from ._custom_ops import *``.
# Every public wrapper defined in this module is listed; the previous list left
# out ``reshape_and_cache_flash`` and ``swap_blocks`` even though the package
# ``__init__`` re-exports both.
__all__ = [
    "convert_fp8",
    "copy_blocks",
    "paged_attention_v1",
    "paged_attention_v2",
    "reshape_and_cache",
    "reshape_and_cache_flash",
    "swap_blocks",
]
def _import_from_path(file_path: Path) -> ModuleType:
    """Load the Python source file at ``file_path`` as a uniquely named module.

    The module is registered in ``sys.modules`` under a name derived from the
    hash of the absolute path, executed, and returned.

    Raises:
        ImportError: if no import spec (or no loader) can be built for the path.
        Exception: whatever the module's own top-level code raises; in that
            case the partially initialised module is removed from
            ``sys.modules`` again before the exception propagates.
    """
    # The file's own name cannot be used as the module name: once it is in
    # `sys.modules` it would also satisfy unrelated imports of that name. Use
    # the hex-encoded hash of the absolute path instead so the name is unique.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # `spec.loader` is Optional; check it here rather than failing later with
    # an AttributeError on `None.exec_module`.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    # `module_from_spec` never returns None, so no further check is needed.
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can resolve imports of itself
    # (e.g. relative imports when `file_path` is a package `__init__.py`).
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind in `sys.modules`.
        sys.modules.pop(module_name, None)
        raise
    return module
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_custom_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- 
/dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + 
blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c7088500e6944a3995cfb559a03b4dfd120133fe --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_cuda_83cf4a3 +ops = torch.ops._paged_attention_cuda_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_cuda_83cf4a3::{op_name}" diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so b/build/torch29-cxx11-cu129-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..f0e08ecae2016f3fd6ee2b3c1fd5b301bd7fac91 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/_paged_attention_cuda_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:914b49c90182aad733f2f3060f3738aac9faf4baafe7fb308a3f1d7aacea3cf2 +size 182088696 diff --git a/build/torch29-cxx11-cu129-x86_64-linux/metadata.json b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8b796af185fbbd8594fcd846949aa5fadc0ccdda --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json @@ -0,0 +1,21 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "10.1", + "12.0+PTX", + "7.0", + "7.2", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/paged_attention/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/platforms.py b/build/torch29-cxx11-cu129-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_custom_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 
--- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + 
blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5415b95a504475517ce0b451729d8edaaf24042a --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_cuda_c6404f1 +ops = torch.ops._paged_attention_cuda_c6404f1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_cuda_c6404f1::{op_name}" diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_paged_attention_cuda_c6404f1.abi3.so b/build/torch29-cxx11-cu130-aarch64-linux/_paged_attention_cuda_c6404f1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..841115a9b76b86e9d9e274b6fe42de557182c2c9 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/_paged_attention_cuda_c6404f1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:debaeb20cb67b2c4f707a7f2370400d34175255f07dc548449fe924ead7477d7 +size 86064784 diff --git a/build/torch29-cxx11-cu130-aarch64-linux/metadata.json b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..66651b7d3f95ac9e5ce5fc2a641b6f0f50788f87 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,19 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "11.0", + "12.0+PTX", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/paged_attention/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/platforms.py b/build/torch29-cxx11-cu130-aarch64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_custom_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 --- 
/dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + 
blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5415b95a504475517ce0b451729d8edaaf24042a --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_cuda_c6404f1 +ops = torch.ops._paged_attention_cuda_c6404f1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_cuda_c6404f1::{op_name}" diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_paged_attention_cuda_c6404f1.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_paged_attention_cuda_c6404f1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..fd6bfea111d9609c0fadb2079e517edc916dcf3c --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_paged_attention_cuda_c6404f1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8b109c24c9a2f74e945b6d5bd8634b53b791fbc3ad60bbe7304e2f52097f51e +size 86544216 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..66651b7d3f95ac9e5ce5fc2a641b6f0f50788f87 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,19 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0", + "11.0", + "12.0+PTX", + "7.5", + "8.0", + "8.6", + "8.7", + "8.9", + "9.0" + ] + } +} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/paged_attention/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/platforms.py b/build/torch29-cxx11-cu130-x86_64-linux/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
+ + +class CudaPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_cuda(self) -> bool: + return True + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return True + + def is_mps(self) -> bool: + return False + + +class MpsPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_cuda(self) -> bool: + return False + + def is_rocm(self) -> bool: + return False + + def is_mps(self) -> bool: + return True + +current_platform = ( + RocmPlatform() if IS_ROCM else + MpsPlatform() if IS_MPS else + CudaPlatform() if torch.cuda.is_available() else + None +) diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py b/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de56043369487facc1f163df6bd319c9806e5ca --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py @@ -0,0 +1,21 @@ +from ._custom_ops import ( + convert_fp8, + copy_blocks, + paged_attention_v1, + paged_attention_v2, + reshape_and_cache, + reshape_and_cache_flash, + swap_blocks, +) +from ._ops import ops + +__all__ = [ + "convert_fp8", + "copy_blocks", + "ops", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "reshape_and_cache_flash", + "swap_blocks", +] diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_custom_ops.py b/build/torch29-cxx11-rocm63-x86_64-linux/_custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c0b8db085468dee5100c98d14106a9ee917bf2 
--- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + 
blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c1268908ac846f37bef7874170ff1f06c3eb33c9 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_rocm_83cf4a3 +ops = torch.ops._paged_attention_rocm_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_rocm_83cf4a3::{op_name}" diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..639195aac43e3055b6a25ba6e0692c453571e719 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef3e8ae335bb6a87e6f97b0c472cc3aa39c6176240eec858fdcac23b55653118 +size 58641744 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json b/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..3e8d811f1dc42febd33121b2627f809447622baf --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/metadata.json @@ -0,0 +1,17 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "rocm", + "archs": [ + "gfx1030", + "gfx1100", + "gfx1101", + "gfx906", + "gfx908", + "gfx90a", + "gfx942" + ] + } +} diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/paged_attention/__init__.py b/build/torch29-cxx11-rocm63-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
class Platform(ABC):
    """Abstract description of the accelerator backend in use.

    Subclasses answer which backend they represent (`is_cuda`, `is_rocm`,
    `is_mps`) and resolve device names.
    """

    @classmethod
    def seed_everything(cls, seed: int) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Declared as an abstract *classmethod* because every concrete platform in
    # this module implements it as a classmethod; the base signature now
    # matches how the method is actually defined and called.
    @classmethod
    @abstractmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of the device with index `device_id`."""
        ...

    @abstractmethod
    def is_cuda(self) -> bool:
        """Return True when this platform is the CUDA backend."""
        ...

    @abstractmethod
    def is_rocm(self) -> bool:
        """Return True when this platform is the ROCm backend."""
        ...

    @abstractmethod
    def is_mps(self) -> bool:
        """Return True when this platform is the MPS backend."""
        ...
class CudaPlatform(Platform):
    """Platform implementation that reports itself as CUDA."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the CUDA device name for `device_id` (memoised by `lru_cache`)."""
        # Forward the requested index. The earlier code always asked for
        # device 0, so every device reported the name of the first one.
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return True

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return False


class RocmPlatform(Platform):
    """Platform implementation that reports itself as ROCm."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name for `device_id` via `torch.cuda`."""
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return False

    def is_rocm(self) -> bool:
        return True

    def is_mps(self) -> bool:
        return False


class MpsPlatform(Platform):
    """Platform implementation that reports itself as MPS."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a name for the MPS device; `device_id` is accepted for API parity."""
        # The earlier code called torch.cuda.get_device_name() here, which
        # queries the CUDA backend rather than MPS. Return the torch device
        # type string instead.
        # NOTE(review): confirm callers do not need a hardware-specific name.
        return "mps"

    def is_cuda(self) -> bool:
        return False

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return True


# Backend selection, checked in this order: ROCm, MPS, CUDA.
# `None` when none of them is detected.
current_platform = (
    RocmPlatform() if IS_ROCM else
    MpsPlatform() if IS_MPS else
    CudaPlatform() if torch.cuda.is_available() else
    None
)
--- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + 
blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c1268908ac846f37bef7874170ff1f06c3eb33c9 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_rocm_83cf4a3 +ops = torch.ops._paged_attention_rocm_83cf4a3 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_rocm_83cf4a3::{op_name}" diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..9c83dc7ea0dfcf90e995788d714b579c2dc1cdc8 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_paged_attention_rocm_83cf4a3.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63144f72f1d8e1d225446f57f89fcdd0e6cce615cf5f6d4c37aa6f33772f292e +size 57972360 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/metadata.json b/build/torch29-cxx11-rocm64-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..3e8d811f1dc42febd33121b2627f809447622baf --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/metadata.json @@ -0,0 +1,17 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "rocm", + "archs": [ + "gfx1030", + "gfx1100", + "gfx1101", + "gfx906", + "gfx908", + "gfx90a", + "gfx942" + ] + } +} diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/paged_attention/__init__.py b/build/torch29-cxx11-rocm64-x86_64-linux/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
class Platform(ABC):
    """Abstract description of the accelerator backend in use.

    Subclasses answer which backend they represent (`is_cuda`, `is_rocm`,
    `is_mps`) and resolve device names.
    """

    @classmethod
    def seed_everything(cls, seed: int) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Declared as an abstract *classmethod* because every concrete platform in
    # this module implements it as a classmethod; the base signature now
    # matches how the method is actually defined and called.
    @classmethod
    @abstractmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of the device with index `device_id`."""
        ...

    @abstractmethod
    def is_cuda(self) -> bool:
        """Return True when this platform is the CUDA backend."""
        ...

    @abstractmethod
    def is_rocm(self) -> bool:
        """Return True when this platform is the ROCm backend."""
        ...

    @abstractmethod
    def is_mps(self) -> bool:
        """Return True when this platform is the MPS backend."""
        ...
class CudaPlatform(Platform):
    """Platform implementation that reports itself as CUDA."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the CUDA device name for `device_id` (memoised by `lru_cache`)."""
        # Forward the requested index. The earlier code always asked for
        # device 0, so every device reported the name of the first one.
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return True

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return False


class RocmPlatform(Platform):
    """Platform implementation that reports itself as ROCm."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name for `device_id` via `torch.cuda`."""
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return False

    def is_rocm(self) -> bool:
        return True

    def is_mps(self) -> bool:
        return False


class MpsPlatform(Platform):
    """Platform implementation that reports itself as MPS."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a name for the MPS device; `device_id` is accepted for API parity."""
        # The earlier code called torch.cuda.get_device_name() here, which
        # queries the CUDA backend rather than MPS. Return the torch device
        # type string instead.
        # NOTE(review): confirm callers do not need a hardware-specific name.
        return "mps"

    def is_cuda(self) -> bool:
        return False

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return True


# Backend selection, checked in this order: ROCm, MPS, CUDA.
# `None` when none of them is detected.
current_platform = (
    RocmPlatform() if IS_ROCM else
    MpsPlatform() if IS_MPS else
    CudaPlatform() if torch.cuda.is_available() else
    None
)
b/build/torch29-metal-aarch64-darwin/_custom_ops.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch + +from ._ops import ops + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + 
blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + +def copy_blocks( + key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor, +) -> None: + ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: + ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: + ops.convert_fp8(output, input, scale, kv_dtype) + + +__all__ = [ + "convert_fp8", + "paged_attention_v1", + "paged_attention_v2", + "reshape_and_cache", + "copy_blocks", +] diff --git a/build/torch29-metal-aarch64-darwin/_ops.py b/build/torch29-metal-aarch64-darwin/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b70aaa4f8859a5ed30a12fdf876196d99bea499a --- /dev/null +++ b/build/torch29-metal-aarch64-darwin/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _paged_attention_metal_c6404f1 +ops = torch.ops._paged_attention_metal_c6404f1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_paged_attention_metal_c6404f1::{op_name}" diff --git a/build/torch29-metal-aarch64-darwin/_paged_attention_metal_c6404f1.abi3.so b/build/torch29-metal-aarch64-darwin/_paged_attention_metal_c6404f1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..ed6528f2606d1e91eb6e13f693d22c7c357a962e --- /dev/null +++ b/build/torch29-metal-aarch64-darwin/_paged_attention_metal_c6404f1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:92e72671d13cdb84ccabc5bbc95e4b6e11772708e4e948e6a15974b86e80e26e +size 14893128 diff --git a/build/torch29-metal-aarch64-darwin/metadata.json b/build/torch29-metal-aarch64-darwin/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a5381dd80836f863378b9f33a559815688de9287 --- /dev/null +++ b/build/torch29-metal-aarch64-darwin/metadata.json @@ -0,0 +1,5 @@ +{ + "version": 1, + "license": "Apache-2.0", + "python-depends": [] +} \ No newline at end of file diff --git a/build/torch29-metal-aarch64-darwin/paged_attention/__init__.py b/build/torch29-metal-aarch64-darwin/paged_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-metal-aarch64-darwin/paged_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-metal-aarch64-darwin/platforms.py b/build/torch29-metal-aarch64-darwin/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..6277d5f50ff3ddc265bb39fa1c4d17e0341b7767 --- /dev/null +++ b/build/torch29-metal-aarch64-darwin/platforms.py @@ -0,0 +1,92 @@ +import os +import random +from abc import ABC, abstractmethod +from functools import lru_cache, wraps +from typing import Callable, ParamSpec, TypeVar + +import numpy as np +import torch + +IS_ROCM = torch.version.hip is not None +IS_MPS = torch.backends.mps.is_available() + + +class Platform(ABC): + @classmethod + def seed_everything(cls, seed: int) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + @abstractmethod + def get_device_name(self, device_id: int = 0) -> str: ... + + @abstractmethod + def is_cuda(self) -> bool: ... + + @abstractmethod + def is_rocm(self) -> bool: ... + + @abstractmethod + def is_mps(self) -> bool: ... 
class CudaPlatform(Platform):
    """Platform implementation that reports itself as CUDA."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the CUDA device name for `device_id` (memoised by `lru_cache`)."""
        # Forward the requested index. The earlier code always asked for
        # device 0, so every device reported the name of the first one.
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return True

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return False


class RocmPlatform(Platform):
    """Platform implementation that reports itself as ROCm."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name for `device_id` via `torch.cuda`."""
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return False

    def is_rocm(self) -> bool:
        return True

    def is_mps(self) -> bool:
        return False


class MpsPlatform(Platform):
    """Platform implementation that reports itself as MPS."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a name for the MPS device; `device_id` is accepted for API parity."""
        # The earlier code called torch.cuda.get_device_name() here, which
        # queries the CUDA backend rather than MPS. Return the torch device
        # type string instead.
        # NOTE(review): confirm callers do not need a hardware-specific name.
        return "mps"

    def is_cuda(self) -> bool:
        return False

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return True


# Backend selection, checked in this order: ROCm, MPS, CUDA.
# `None` when none of them is detected.
current_platform = (
    RocmPlatform() if IS_ROCM else
    MpsPlatform() if IS_MPS else
    CudaPlatform() if torch.cuda.is_available() else
    None
)
exit(EXIT_FAILURE); \ + } \ + } while (0) + +int64_t get_device_attribute(int64_t attribute, int64_t device_id); + +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); + +namespace cuda_utils { + +template +HOST_DEVICE_INLINE constexpr std::enable_if_t, T> +ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +}; // namespace cuda_utils \ No newline at end of file diff --git a/cuda-utils/cuda_utils_kernels.cu b/cuda-utils/cuda_utils_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..0627a42675b524ae5f8d73ad3d180899e777c5c0 --- /dev/null +++ b/cuda-utils/cuda_utils_kernels.cu @@ -0,0 +1,35 @@ +#include "cuda_utils.h" +#ifdef USE_ROCM + #include + #include +#endif + +int64_t get_device_attribute(int64_t attribute, int64_t device_id) { + // Return the cached value on subsequent calls + static int value = [=]() { + int device = static_cast(device_id); + if (device < 0) { + CUDA_CHECK(cudaGetDevice(&device)); + } + int value; + CUDA_CHECK(cudaDeviceGetAttribute( + &value, static_cast(attribute), device)); + return static_cast(value); + }(); + + return value; +} + +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) { + int64_t attribute; + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html + // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 + +#ifdef USE_ROCM + attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; +#else + attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; +#endif + + return get_device_attribute(attribute, device_id); +} diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000000000000000000000000000000000000..aa644f4f01e9421697ee8f13d337279a27cd6f93 --- /dev/null +++ b/flake.lock @@ -0,0 +1,168 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": 
"9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1747046372, + "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_2", + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1759493343, + "narHash": "sha256-8fhl0gwMAnOkQbogPIVq+Fha+Yeq52FaRXfwF+F9Q+k=", + "owner": "huggingface", + "repo": "hf-nix", + "rev": "b1fc3a18b52447a0f24bc6884418edc5e66082b9", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "hf-nix": "hf-nix", + "nixpkgs": [ + "kernel-builder", + "hf-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1759823856, + "narHash": 
"sha256-wxnoxY8Whem8NmE5UDOkv74puV68CoAMIRpVMCuKjJ8=", + "owner": "huggingface", + "repo": "kernel-builder", + "rev": "62a9652083262f49c11a1eb1aa2cbbbcf2170c93", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernel-builder", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1755963616, + "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000000000000000000000000000000000..e723f3b3f6a02ec23345a353b4352ff3727735f1 --- /dev/null +++ b/flake.nix @@ -0,0 +1,17 @@ +{ + description = "Flake for attention kernels"; + + inputs = { + kernel-builder.url = "github:huggingface/kernel-builder"; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genFlakeOutputs { + inherit self; + path = ./.; + }; +} diff --git a/paged-attention-metal/attention/paged_attention.metal 
b/paged-attention-metal/attention/paged_attention.metal new file mode 100644 index 0000000000000000000000000000000000000000..22d972d1811af6ec77e10ef0058e7ac98d0446f5 --- /dev/null +++ b/paged-attention-metal/attention/paged_attention.metal @@ -0,0 +1,1401 @@ +// Updated from MLX commit hash f70764a + +#include "../utils.metal" +#include "../float8.metal" +#include +#include + +using namespace metal; + +// ========================================== Generic vector types + +// A vector type to store Q, K, V elements. +template struct Vec {}; + +// A vector type to store FP32 accumulators. +template struct FloatVec {}; + +// Template vector operations. +template inline Acc mul(A a, B b); + +template inline float sum(T v); + +template inline float dot(T a, T b) { + return sum(mul(a, b)); +} + +template inline float dot(T a, T b) { + return sum(mul(a, b)); +} + +// FP32 vector data types. +struct Float8_ { + float4 x; + float4 y; +}; + +template <> struct Vec { + using Type = float; +}; +template <> struct Vec { + using Type = float2; +}; +template <> struct Vec { + using Type = float4; +}; +template <> struct Vec { + using Type = Float8_; +}; + +template <> struct FloatVec { + using Type = float; +}; +template <> struct FloatVec { + using Type = float2; +}; +template <> struct FloatVec { + using Type = float4; +}; +template <> struct FloatVec { + using Type = Float8_; +}; + +template <> inline float mul(float a, float b) { return a * b; } + +template <> inline float2 mul(float2 a, float2 b) { return a * b; } + +template <> inline float4 mul(float4 a, float4 b) { return a * b; } + +template <> inline Float8_ mul(Float8_ a, Float8_ b) { + Float8_ c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> inline float sum(float a) { return a; } + +template <> inline float sum(float2 a) { return a.x + a.y; } + +template <> inline float sum(float4 a) { return a.x + a.y + a.z + a.w; } + +template <> inline float sum(Float8_ a) { return sum(a.x) + sum(a.y); } + 
+inline Float8_ fma(Float8_ a, Float8_ b, Float8_ c) { + Float8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline void from_float(thread float &dst, float src) { dst = src; } +inline void from_float(thread float2 &dst, float2 src) { dst = src; } +inline void from_float(thread float4 &dst, float4 src) { dst = src; } +inline void from_float(thread Float8_ &dst, Float8_ src) { dst = src; } + +// BF16 vector data types. +// #if defined(__HAVE_BFLOAT__) + +// struct Bfloat8_ { +// bfloat4 x; +// bfloat4 y; +// }; + +// template<> +// struct Vec { +// using Type = bfloat; +// }; +// template<> +// struct Vec { +// using Type = bfloat2; +// }; +// template<> +// struct Vec { +// using Type = bfloat4; +// }; +// template<> +// struct Vec { +// using Type = Bfloat8_; +// }; + +// template<> +// struct FloatVec { +// using Type = float; +// }; +// template<> +// struct FloatVec { +// using Type = float2; +// }; +// template<> +// struct FloatVec { +// using Type = float4; +// }; +// template<> +// struct FloatVec { +// using Type = Float8_; +// }; + +// template<> +// inline float mul(bfloat a, bfloat b) { +// return (float)a * (float)b; +// } +// template<> +// inline bfloat mul(bfloat a, bfloat b) { +// return a*b; +// } + +// template<> +// inline float2 mul(bfloat2 a, bfloat2 b) { +// return (float2)a * (float2)b; +// } +// template<> +// inline bfloat2 mul(bfloat2 a, bfloat2 b) { +// return a * b; +// } + +// template<> +// inline float4 mul(bfloat4 a, bfloat4 b) { +// return (float4)a * (float4)b; +// } +// template<> +// inline bfloat4 mul(bfloat4 a, bfloat4 b) { +// return a * b; +// } + +// template<> +// inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) { +// Float8_ c; +// c.x = mul(a.x, b.x); +// c.y = mul(a.y, b.y); +// return c; +// } +// template<> +// inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) { +// Bfloat8_ c; +// c.x = mul(a.x, b.x); +// c.y = mul(a.y, b.y); +// return c; +// } + +// template<> +// inline float 
sum(bfloat a) { +// return (float)a; +// } + +// template<> +// inline float sum(bfloat2 a) { +// return (float)a.x + (float)a.y; +// } + +// template<> +// inline float sum(bfloat4 a) { +// return sum(a.x) + sum(a.y); +// } + +// template<> +// inline float sum(Bfloat8_ a) { +// return sum(a.x) + sum(a.y); +// } + +// inline float fma(bfloat a, bfloat b, float c) { +// return (float)a * (float)b + c; +// } + +// inline float2 fma(bfloat2 a, bfloat2 b, float2 c) { +// return (float2)a * (float2)b + c; +// } + +// inline float4 fma(bfloat4 a, bfloat4 b, float4 c) { +// return (float4)a * (float4)b + c; +// } + +// inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) { +// Float8_ res; +// res.x = fma((float4)a.x, (float4)b.x, (float4)c.x); +// res.y = fma((float4)a.y, (float4)b.y, (float4)c.y); +// return res; +// } +// inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) { +// Bfloat8_ res; +// res.x = (bfloat4)fma((float4)a.x, (float4)b.x, (float4)c.x); +// res.y = (bfloat4)fma((float4)a.y, (float4)b.x, (float4)c.y); +// return c; +// } + +// inline void from_float(thread bfloat& dst, float src) { +// dst = static_cast(src); +// } +// inline void from_float(thread bfloat2& dst, float2 src) { +// dst.x = static_cast(src.x); +// dst.y = static_cast(src.y); +// } +// inline void from_float(thread bfloat4& dst, float4 src) { +// dst.x = static_cast(src.x); +// dst.y = static_cast(src.y); +// dst.z = static_cast(src.z); +// dst.w = static_cast(src.w); +// } +// inline void from_float(thread Bfloat8_& dst, Float8_ src) { +// bfloat4 x; +// bfloat4 y; +// from_float(x, src.x); +// from_float(y, src.y); +// dst.x = x; +// dst.y = y; +// } + +// #else + +struct Bfloat2_ { + bfloat16_t x; + bfloat16_t y; +}; + +struct Bfloat4_ { + Bfloat2_ x; + Bfloat2_ y; +}; + +struct Bfloat8_ { + Bfloat4_ x; + Bfloat4_ y; +}; + +template <> struct Vec { + using Type = bfloat16_t; +}; +template <> struct Vec { + using Type = Bfloat2_; +}; +template <> struct Vec { + using Type = 
Bfloat4_; +}; +template <> struct Vec { + using Type = Bfloat8_; +}; + +template <> struct FloatVec { + using Type = float; +}; +template <> struct FloatVec { + using Type = float2; +}; +template <> struct FloatVec { + using Type = float4; +}; +template <> struct FloatVec { + using Type = Float8_; +}; + +template <> inline float mul(bfloat16_t a, bfloat16_t b) { + return (float)a * (float)b; +} +template <> inline bfloat16_t mul(bfloat16_t a, bfloat16_t b) { return a * b; } + +template <> inline float2 mul(Bfloat2_ a, Bfloat2_ b) { + float2 a_f((float)a.x, (float)a.y); + float2 b_f((float)b.x, (float)b.y); + return a_f * b_f; +} +template <> inline Bfloat2_ mul(Bfloat2_ a, Bfloat2_ b) { + Bfloat2_ c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> inline float4 mul(Bfloat4_ a, Bfloat4_ b) { + float2 x = mul(a.x, b.x); + float2 y = mul(a.y, b.y); + float4 c; + c.x = x.x; + c.y = x.y; + c.z = y.x; + c.w = y.y; + return c; +} +template <> inline Bfloat4_ mul(Bfloat4_ a, Bfloat4_ b) { + Bfloat4_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) { + Float8_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} +template <> inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) { + Bfloat8_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> inline float sum(bfloat16_t a) { return (float)a; } + +template <> inline float sum(Bfloat2_ a) { return (float)a.x + (float)a.y; } + +template <> inline float sum(Bfloat4_ a) { return sum(a.x) + sum(a.y); } + +template <> inline float sum(Bfloat8_ a) { return sum(a.x) + sum(a.y); } + +inline float fma(bfloat16_t a, bfloat16_t b, float c) { + return (float)a * (float)b + c; +} +inline bfloat16_t fma(bfloat16_t a, bfloat16_t b, bfloat16_t c) { + return a * b + c; +} + +inline float2 fma(Bfloat2_ a, Bfloat2_ b, float2 c) { + float2 a_f((float)a.x, (float)a.y); + float2 b_f((float)b.x, (float)b.y); + return a_f * b_f + 
c; +} +inline Bfloat2_ fma(Bfloat2_ a, Bfloat2_ b, Bfloat2_ c) { + Bfloat2_ res; + res.x = a.x * b.x + c.x; + res.y = a.y * b.y + c.y; + return res; +} + +inline float4 fma(Bfloat4_ a, Bfloat4_ b, float4 c) { + float4 res; + res.x = fma(a.x.x, b.x.x, c.x); + res.y = fma(a.x.y, b.x.y, c.y); + res.z = fma(a.y.x, b.y.x, c.z); + res.w = fma(a.y.y, b.y.y, c.w); + return res; +} +inline Bfloat4_ fma(Bfloat4_ a, Bfloat4_ b, Bfloat4_ c) { + Bfloat4_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) { + float4 x = fma(a.x, b.x, c.x); + float4 y = fma(a.y, b.y, c.y); + Float8_ res; + res.x = x; + res.y = y; + return res; +} +inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) { + Bfloat8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline void from_float(thread bfloat16_t &dst, float src) { + dst = static_cast(src); +} +inline void from_float(thread Bfloat2_ &dst, float2 src) { + dst.x = static_cast(src.x); + dst.y = static_cast(src.y); +} +inline void from_float(thread Bfloat4_ &dst, float4 src) { + dst.x.x = static_cast(src.x); + dst.x.y = static_cast(src.y); + dst.y.x = static_cast(src.z); + dst.y.y = static_cast(src.w); +} +inline void from_float(thread Bfloat8_ &dst, Float8_ src) { + Bfloat4_ x; + Bfloat4_ y; + from_float(x, src.x); + from_float(y, src.y); + dst.x = x; + dst.y = y; +} + +// #endif + +// FP16 vector data types. 
+struct Half8_ { + half4 x; + half4 y; +}; + +template <> struct Vec { + using Type = half; +}; +template <> struct Vec { + using Type = half2; +}; +template <> struct Vec { + using Type = half4; +}; +template <> struct Vec { + using Type = Half8_; +}; + +template <> struct FloatVec { + using Type = float; +}; +template <> struct FloatVec { + using Type = float2; +}; +template <> struct FloatVec { + using Type = float4; +}; +template <> struct FloatVec { + using Type = Float8_; +}; + +template <> inline float mul(half a, half b) { return (float)a * (float)b; } +template <> inline half mul(half a, half b) { return a * b; } + +template <> inline float2 mul(half2 a, half2 b) { + return (float2)a * (float2)b; +} +template <> inline half2 mul(half2 a, half2 b) { return a * b; } + +template <> inline float4 mul(half4 a, half4 b) { + return (float4)a * (float4)b; +} +template <> inline half4 mul(half4 a, half4 b) { return a * b; } + +template <> inline Float8_ mul(Half8_ a, Half8_ b) { + float4 x = mul(a.x, b.x); + float4 y = mul(a.y, b.y); + Float8_ c; + c.x = x; + c.y = y; + return c; +} +template <> inline Half8_ mul(Half8_ a, Half8_ b) { + Half8_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> inline float sum(half a) { return (float)a; } + +template <> inline float sum(half2 a) { return (float)a.x + (float)a.y; } + +template <> inline float sum(half4 a) { return a.x + a.y + a.z + a.w; } + +template <> inline float sum(Half8_ a) { return sum(a.x) + sum(a.y); } + +inline float fma(half a, half b, float c) { return (float)a * (float)b + c; } + +inline float2 fma(half2 a, half2 b, float2 c) { + return (float2)a * (float2)b + c; +} + +inline float4 fma(half4 a, half4 b, float4 c) { + return (float4)a * (float4)b + c; +} + +inline Float8_ fma(Half8_ a, Half8_ b, Float8_ c) { + float4 x = fma(a.x, b.x, c.x); + float4 y = fma(a.y, b.y, c.y); + Float8_ res; + res.x = x; + res.y = y; + return res; +} +inline Half8_ fma(Half8_ a, Half8_ b, Half8_ 
c) { + Half8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline void from_float(thread half &dst, float src) { + dst = static_cast(src); +} +inline void from_float(thread half2 &dst, float2 src) { + dst.x = static_cast(src.x); + dst.y = static_cast(src.y); +} +inline void from_float(thread half4 &dst, float4 src) { + dst.x = static_cast(src.x); + dst.y = static_cast(src.y); + dst.z = static_cast(src.z); + dst.w = static_cast(src.w); +} +inline void from_float(thread Half8_ &dst, Float8_ src) { + half4 x; + half4 y; + from_float(x, src.x); + from_float(y, src.y); + dst.x = x; + dst.y = y; +} + +// ========================================== FP8 (uchar) vector data types. + +// 8‑lane uchar vector – Metal only provides up to uchar4, so build our own. +struct Uchar8_ { + uchar4 x; + uchar4 y; +}; + +// Vec specialisations so Vec::Type resolves correctly. +template <> struct Vec { + using Type = uchar; +}; +template <> struct Vec { + using Type = uchar2; +}; +template <> struct Vec { + using Type = uchar4; +}; +template <> struct Vec { + using Type = Uchar8_; +}; + +// General case: not uchar +template inline constexpr bool is_uchar() { return false; } + +// Specialization: T is uchar +template <> inline constexpr bool is_uchar() { return true; } + +// Generic fallback – will fail to compile if a required specialisation is +// missing. 
+template +inline Vec fp8_convert(const thread Quant_vec &, float scale) { + static_assert(sizeof(Vec) == 0, "Missing fp8_convert specialisation"); +} + +// ========================================== FP8 → float/half/bfloat +inline float __dequant_single(uchar v, float scale) { + return fp8_e4m3_to_float(v) * scale; +} + +// ---- 1‑lane ---- +template <> +inline float fp8_convert(const thread uchar &in, float scale) { + return __dequant_single(in, scale); +} +template <> +inline half fp8_convert(const thread uchar &in, float scale) { + return half(__dequant_single(in, scale)); +} +template <> +inline bfloat16_t fp8_convert(const thread uchar &in, + float scale) { + return bfloat16_t(__dequant_single(in, scale)); +} + +// ---- 2‑lane ---- +template <> +inline float2 fp8_convert(const thread uchar2 &in, + float scale) { + return float2(__dequant_single(in.x, scale), __dequant_single(in.y, scale)); +} +template <> +inline half2 fp8_convert(const thread uchar2 &in, float scale) { + half2 out; + out.x = half(__dequant_single(in.x, scale)); + out.y = half(__dequant_single(in.y, scale)); + return out; +} +template <> +inline Bfloat2_ fp8_convert(const thread uchar2 &in, + float scale) { + Bfloat2_ out; + out.x = bfloat16_t(__dequant_single(in.x, scale)); + out.y = bfloat16_t(__dequant_single(in.y, scale)); + return out; +} + +// ---- 4‑lane ---- +template <> +inline float4 fp8_convert(const thread uchar4 &in, + float scale) { + return float4(__dequant_single(in.x, scale), __dequant_single(in.y, scale), + __dequant_single(in.z, scale), __dequant_single(in.w, scale)); +} +template <> +inline half4 fp8_convert(const thread uchar4 &in, float scale) { + half4 out; + out.x = half(__dequant_single(in.x, scale)); + out.y = half(__dequant_single(in.y, scale)); + out.z = half(__dequant_single(in.z, scale)); + out.w = half(__dequant_single(in.w, scale)); + return out; +} +template <> +inline Bfloat4_ fp8_convert(const thread uchar4 &in, + float scale) { + Bfloat4_ out; + out.x.x = 
bfloat16_t(__dequant_single(in.x, scale)); + out.x.y = bfloat16_t(__dequant_single(in.y, scale)); + out.y.x = bfloat16_t(__dequant_single(in.z, scale)); + out.y.y = bfloat16_t(__dequant_single(in.w, scale)); + return out; +} + +// ---- 8‑lane ---- +template <> +inline Float8_ fp8_convert(const thread Uchar8_ &in, + float scale) { + Float8_ out; + out.x = + float4(__dequant_single(in.x.x, scale), __dequant_single(in.x.y, scale), + __dequant_single(in.x.z, scale), __dequant_single(in.x.w, scale)); + out.y = + float4(__dequant_single(in.y.x, scale), __dequant_single(in.y.y, scale), + __dequant_single(in.y.z, scale), __dequant_single(in.y.w, scale)); + return out; +} +template <> +inline Half8_ fp8_convert(const thread Uchar8_ &in, + float scale) { + Half8_ out; + out.x = half4(half(__dequant_single(in.x.x, scale)), + half(__dequant_single(in.x.y, scale)), + half(__dequant_single(in.x.z, scale)), + half(__dequant_single(in.x.w, scale))); + out.y = half4(half(__dequant_single(in.y.x, scale)), + half(__dequant_single(in.y.y, scale)), + half(__dequant_single(in.y.z, scale)), + half(__dequant_single(in.y.w, scale))); + return out; +} +template <> +inline Bfloat8_ fp8_convert(const thread Uchar8_ &in, + float scale) { + Bfloat8_ out; + // first 4 + out.x.x.x = bfloat16_t(__dequant_single(in.x.x, scale)); + out.x.x.y = bfloat16_t(__dequant_single(in.x.y, scale)); + out.x.y.x = bfloat16_t(__dequant_single(in.x.z, scale)); + out.x.y.y = bfloat16_t(__dequant_single(in.x.w, scale)); + // second 4 + out.y.x.x = bfloat16_t(__dequant_single(in.y.x, scale)); + out.y.x.y = bfloat16_t(__dequant_single(in.y.y, scale)); + out.y.y.x = bfloat16_t(__dequant_single(in.y.z, scale)); + out.y.y.y = bfloat16_t(__dequant_single(in.y.w, scale)); + return out; +} + +// ========================================== Dot product utilities + +// TODO(EricLBuehler): optimize with vectorization +template +inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) { + // Compute the 
parallel products for Q*K^T (treat vector lanes separately). + using A_vec = typename FloatVec::Type; + A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += simd_shuffle_xor(qk, mask); + } + return qk; +} + +template struct Qk_dot { + template + static inline float dot(const threadgroup Vec (&q)[N], + const thread Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +// ========================================== Block sum utility + +// Utility function for attention softmax. +template +inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid, + uint simd_lid) { + // Compute the sum per simdgroup. +#pragma unroll + for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) { + sum += simd_shuffle_xor(sum, mask); + } + + // Simd leaders store the data to shared memory. + if (simd_lid == 0) { + red_smem[simd_tid] = sum; + } + + // Make sure the data is in shared memory. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // The warps compute the final sums. + if (simd_lid < NUM_WARPS) { + sum = red_smem[simd_lid]; + } + + // Parallel reduction inside the simd group. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += simd_shuffle_xor(sum, mask); + } + + // Broadcast to other threads. + return simd_shuffle(sum, 0); +} + +// ========================================== Paged Attention kernel + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +constant bool use_partitioning [[function_constant(10)]]; +constant bool use_alibi [[function_constant(20)]]; +constant bool use_fp8_scales [[function_constant(30)]]; + +template +[[kernel]] void paged_attention( + device float *exp_sums + [[buffer(0)]], // [num_seqs, num_heads, max_num_partitions] - only used when + // use_partitioning + device float *max_logits + [[buffer(1)]], // [num_seqs, num_heads, max_num_partitions] - only used when + // use_partitioning + device T *out + [[buffer(2)]], // [num_seqs, num_heads, max_num_partitions, head_size] + device const T *q [[buffer(3)]], // [num_seqs, num_heads, head_size] + device const CACHE_T *k_cache + [[buffer(4)]], // [num_blocks, num_kv_heads, head_size/x, block_size, x] + device const CACHE_T *v_cache + [[buffer(5)]], // [num_blocks, num_kv_heads, head_size, block_size] + const device float *__restrict__ k_scale + [[buffer(6)]], // [1] - only used when use_fp8_scales + const device float *__restrict__ v_scale + [[buffer(7)]], // [1] - only used when use_fp8_scales + const constant int &num_kv_heads [[buffer(8)]], // [num_heads] + const constant float &scale [[buffer(9)]], + const constant float &softcapping [[buffer(10)]], + device const uint32_t *block_tables + [[buffer(11)]], // [num_seqs, max_num_blocks_per_seq] + device const uint32_t *context_lens [[buffer(12)]], // [num_seqs] + const constant int &max_num_blocks_per_seq [[buffer(13)]], + device const float *alibi_slopes + [[buffer(14)]], // [num_heads] - only used when use_alibi + const constant int &q_stride [[buffer(15)]], + const constant int &kv_block_stride [[buffer(16)]], + const constant int &kv_head_stride [[buffer(17)]], + threadgroup char *shared_mem [[threadgroup(0)]], + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], + uint3 threadgroups_per_grid [[threadgroups_per_grid]], + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], + uint 
simd_tid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + const int seq_idx = threadgroup_position_in_grid.y; + const int partition_idx = threadgroup_position_in_grid.z; + const int max_num_partitions = threadgroups_per_grid.z; + const int thread_idx = thread_position_in_threadgroup.x; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const uint32_t context_len = context_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. 
+ const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(NUM_SIMD_LANES / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, NUM_SIMD_LANES); + constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; + const int warp_idx = simd_tid; + const int lane = simd_lid; + + const int head_idx = threadgroup_position_in_grid.x; + const int num_heads = threadgroups_per_grid.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = !use_alibi ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Quant_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the thread group size is 4, then the first thread in the + // group has 0, 4, 8, ... 
th vectors of the query, and the second thread has + // 1, 5, 9, ... th vectors of the query, and so on. + const device T *q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + threadgroup Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Use fp32 on softmax logits for better accuracy + threadgroup float *logits = reinterpret_cast(shared_mem); + // Workspace for reduction + threadgroup float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(CACHE_T); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const device uint32_t *block_table = + block_tables + seq_idx * max_num_blocks_per_seq; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE: The block number is stored in int32. However, we cast it to int64 + // because int32 can lead to overflow when this variable is multiplied by + // large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the thread group size is 4, then the first thread in the + // group has 0, 4, 8, ... th vectors of the key, and the second thread has + // 1, 5, 9, ... th vectors of the key, and so on. 
+ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * NUM_SIMD_LANES) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const device CACHE_T *k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + + if constexpr (is_uchar()) { + // FP8 support + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8_convert(k_vec_quant, *k_scale); + } else { + // Non-FP8 default + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + + // Apply softcapping + if (softcapping != 1.0) { + qk = precise::tanh(qk / softcapping) * softcapping; + } + + // Add the ALiBi bias if slopes are given. + if (use_alibi && alibi_slope != 0) { + // Compute bias with explicit float precision to minimize precision loss + int position_offset = token_idx - int(context_len) + 1; + float alibi_bias = alibi_slope * float(position_offset); + qk += alibi_bias; + } + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE: It is required to zero out the masked logits. + const bool mask = token_idx >= context_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? 
qk_max : max(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = NUM_SIMD_LANES / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = simd_shuffle(qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = exp(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum, + simd_tid, simd_lid); + + // Compute softmax. + const float inv_sum = divide(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) { + device float *max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + device float *exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. 
+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(T), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + using V_quant_vec = typename Vec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = NUM_SIMD_LANES / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE: We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + T zero_value = 0; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE: The block number is stored in int32. However, we cast it to int64 + // because int32 can lead to overflow when this variable is multiplied by + // large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + Float_L_vec logits_float_vec = *reinterpret_cast( + logits + token_idx - start_token_idx); + from_float(logits_vec, logits_float_vec); + + const device CACHE_T *v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + // NOTE: When v_vec contains the tokens that are out of the context, + // we should explicitly zero out the values since they may contain NaNs. 
+ // See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + V_vec v_vec; + + if constexpr (is_uchar()) { + // FP8 support + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + v_vec = fp8_convert(v_quant_vec, *v_scale); + } else { + // Non-FP8 default + v_vec = *reinterpret_cast(v_ptr + offset); + } + + if (block_idx == num_context_blocks - 1) { + thread T *v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = + token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += simd_shuffle_xor(acc, mask); + } + accs[i] = acc; + } + + // NOTE: A barrier is required because the shared memory space for logits + // is reused for the output. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Perform reduction across warps. + threadgroup float *out_smem = + reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + threadgroup float *dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Lower warps update the output. 
+ if (warp_idx < mid) { + const threadgroup float *src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Write the final output. + if (warp_idx == 0) { + device T *out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + *(out_ptr + row_idx) = T(accs[i]); + } + } + } +} + +template +[[kernel]] void paged_attention_v2_reduce( + device T *out [[buffer(0)]], const device float *exp_sums [[buffer(1)]], + const device float *max_logits [[buffer(2)]], + const device T *tmp_out [[buffer(3)]], + device uint32_t *context_lens [[buffer(4)]], + const constant int &max_num_partitions [[buffer(5)]], + threadgroup char *shared_mem [[threadgroup(0)]], + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], + uint3 threadgroups_per_grid [[threadgroups_per_grid]], + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], + uint3 threads_per_threadgroup [[threads_per_threadgroup]], + uint simd_tid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + const int num_heads = threadgroups_per_grid.x; + const int head_idx = threadgroup_position_in_grid.x; + const int seq_idx = threadgroup_position_in_grid.y; + const uint32_t context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. 
+ device T *out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const device T *tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE; + i += threads_per_threadgroup.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; + const int warp_idx = simd_tid; + const int lane = simd_lid; + + // Workspace for reduction. + threadgroup float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + threadgroup float *shared_max_logits = + reinterpret_cast(shared_mem); + const device float *max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = thread_position_in_threadgroup.x; i < num_partitions; + i += threads_per_threadgroup.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = max(max_logit, l); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) { + max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = simd_shuffle(max_logit, 0); + + // Load rescaled exp sums to shared memory. 
+ threadgroup float *shared_exp_sums = reinterpret_cast( + shared_mem + sizeof(float) * num_partitions); + const device float *exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = thread_position_in_threadgroup.x; i < num_partitions; + i += threads_per_threadgroup.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * exp(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + global_exp_sum = block_sum( + &red_smem[NUM_WARPS], global_exp_sum, simd_tid, simd_lid); + const float inv_global_exp_sum = divide(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. + const device T *tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + device T *out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE; + i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; + } + out_ptr[i] = T(acc); + } +} + +#define instantiate_paged_attention_inner(type, cache_type, head_size, \ + block_size, num_threads, \ + num_simd_lanes, partition_size) \ + template [[host_name("paged_attention_" #type "_cache_" #cache_type \ + "_hs" #head_size "_bs" #block_size "_nt" #num_threads \ + "_nsl" #num_simd_lanes \ + "_ps" #partition_size)]] [[kernel]] void \ + paged_attention( \ + device float *exp_sums [[buffer(0)]], \ + device float *max_logits [[buffer(1)]], \ + device type *out [[buffer(2)]], device const type *q [[buffer(3)]], \ + device const cache_type *k_cache [[buffer(4)]], \ + device const cache_type *v_cache [[buffer(5)]], \ + const device float *__restrict__ k_scale 
[[buffer(6)]], \ + const device float *__restrict__ v_scale [[buffer(7)]], \ + const constant int &num_kv_heads [[buffer(8)]], \ + const constant float &scale [[buffer(9)]], \ + const constant float &softcapping [[buffer(10)]], \ + device const uint32_t *block_tables [[buffer(11)]], \ + device const uint32_t *context_lens [[buffer(12)]], \ + const constant int &max_num_blocks_per_seq [[buffer(13)]], \ + device const float *alibi_slopes [[buffer(14)]], \ + const constant int &q_stride [[buffer(15)]], \ + const constant int &kv_block_stride [[buffer(16)]], \ + const constant int &kv_head_stride [[buffer(17)]], \ + threadgroup char *shared_mem [[threadgroup(0)]], \ + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ + uint3 threadgroups_per_grid [[threadgroups_per_grid]], \ + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \ + uint simd_tid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_paged_attention_v2_reduce_inner( \ + type, head_size, num_threads, num_simd_lanes, partition_size) \ + template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \ + "_nt" #num_threads "_nsl" #num_simd_lanes \ + "_ps" #partition_size)]] [[kernel]] void \ + paged_attention_v2_reduce( \ + device type * out [[buffer(0)]], \ + const device float *exp_sums [[buffer(1)]], \ + const device float *max_logits [[buffer(2)]], \ + const device type *tmp_out [[buffer(3)]], \ + device uint32_t *context_lens [[buffer(4)]], \ + const constant int &max_num_partitions [[buffer(5)]], \ + threadgroup char *shared_mem [[threadgroup(0)]], \ + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ + uint3 threadgroups_per_grid [[threadgroups_per_grid]], \ + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \ + uint3 threads_per_threadgroup [[threads_per_threadgroup]], \ + uint simd_tid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid 
[[thread_index_in_simdgroup]]); + +#define instantiate_paged_attention_heads( \ + type, cache_type, block_size, num_threads, num_simd_lanes, partition_size) \ + instantiate_paged_attention_inner(type, cache_type, 32, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 64, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 80, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 96, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 112, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 120, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 128, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 192, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); \ + instantiate_paged_attention_inner(type, cache_type, 256, block_size, \ + num_threads, num_simd_lanes, \ + partition_size); + +#define instantiate_paged_attention_v2_reduce_heads( \ + type, num_threads, num_simd_lanes, partition_size) \ + instantiate_paged_attention_v2_reduce_inner(type, 32, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 64, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 80, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 96, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 112, num_threads, \ + num_simd_lanes, partition_size); \ + 
instantiate_paged_attention_v2_reduce_inner(type, 120, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 128, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 192, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner(type, 256, num_threads, \ + num_simd_lanes, partition_size); + +#define instantiate_paged_attention_block_size(type, cache_type, num_threads, \ + num_simd_lanes, partition_size) \ + instantiate_paged_attention_heads(type, cache_type, 8, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_heads(type, cache_type, 16, num_threads, \ + num_simd_lanes, partition_size); \ + instantiate_paged_attention_heads(type, cache_type, 32, num_threads, \ + num_simd_lanes, partition_size); + +// TODO: tune num_threads = 256 +// NOTE: partition_size = 0 +#define instantiate_paged_attention_v1(type, cache_type, num_simd_lanes) \ + instantiate_paged_attention_block_size(type, cache_type, 256, \ + num_simd_lanes, 0); + +// TODO: tune num_threads = 256 +// NOTE: partition_size = 512 +#define instantiate_paged_attention_v2(type, cache_type, num_simd_lanes) \ + instantiate_paged_attention_block_size(type, cache_type, 256, \ + num_simd_lanes, 512); + +// TODO: tune num_threads = 256 +// NOTE: partition_size = 512 +#define instantiate_paged_attention_v2_reduce(type, num_simd_lanes) \ + instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512); + +instantiate_paged_attention_v1(float, float, 32); +instantiate_paged_attention_v1(bfloat16_t, bfloat16_t, 32); +instantiate_paged_attention_v1(half, half, 32); + +instantiate_paged_attention_v1(float, uchar, 32); +instantiate_paged_attention_v1(bfloat16_t, uchar, 32); +instantiate_paged_attention_v1(half, uchar, 32); + +instantiate_paged_attention_v2_reduce(float, 32); +instantiate_paged_attention_v2_reduce(bfloat16_t, 32); 
+instantiate_paged_attention_v2_reduce(half, 32); + +instantiate_paged_attention_v2(float, float, 32); +instantiate_paged_attention_v2(bfloat16_t, bfloat16_t, 32); +instantiate_paged_attention_v2(half, half, 32); + +instantiate_paged_attention_v2(float, uchar, 32); +instantiate_paged_attention_v2(bfloat16_t, uchar, 32); +instantiate_paged_attention_v2(half, uchar, 32); diff --git a/paged-attention-metal/cache.mm b/paged-attention-metal/cache.mm new file mode 100644 index 0000000000000000000000000000000000000000..cf6726000a7206b8c96ed5fb0ebaa4cf09b5ba26 --- /dev/null +++ b/paged-attention-metal/cache.mm @@ -0,0 +1,562 @@ +#include +#include +#include + +#import +#import +#include +#include +#include + +static inline id getMTLBufferStorage(const torch::Tensor &tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + +static std::string getModuleDirectory() { + Dl_info dl_info; + if (dladdr((void *)getModuleDirectory, &dl_info)) { + std::string path(dl_info.dli_fname); + size_t pos = path.find_last_of('/'); + if (pos != std::string::npos) { + return path.substr(0, pos); + } + } + return "."; +} + +void swap_blocks(torch::Tensor &src, torch::Tensor &dst, + const torch::Tensor &block_mapping) { + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + + const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); + const int64_t num_blocks = block_mapping.size(0); + + // Handle different device combinations + if (src.device().is_mps() && dst.device().is_mps()) { + // MPS to MPS: Use Metal blit encoder + @autoreleasepool { + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get current MPS stream"); + + id commandBuffer = stream->commandBuffer(); + TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer"); + + dispatch_queue_t serialQueue = stream->queue(); + + dispatch_sync(serialQueue, ^{ + id blitEncoder = + [commandBuffer blitCommandEncoder]; + 
TORCH_CHECK(blitEncoder, "Failed to create blit command encoder"); + + id srcBuf = getMTLBufferStorage(src); + id dstBuf = getMTLBufferStorage(dst); + + for (int64_t i = 0; i < num_blocks; ++i) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); + NSUInteger src_offset = src_block_number * block_size_in_bytes; + NSUInteger dst_offset = dst_block_number * block_size_in_bytes; + + [blitEncoder copyFromBuffer:srcBuf + sourceOffset:src_offset + toBuffer:dstBuf + destinationOffset:dst_offset + size:block_size_in_bytes]; + } + + [blitEncoder endEncoding]; + stream->synchronize(at::mps::SyncType::COMMIT); + }); + } + } else { + // Cross-device transfers (MPS-CPU, CPU-MPS, CPU-CPU): Use PyTorch's copy + for (int64_t i = 0; i < num_blocks; ++i) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); + + // Copy the entire block + dst[dst_block_number].copy_(src[src_block_number]); + } + } +} + +void copy_blocks(const std::vector &key_caches, + const std::vector &value_caches, + const torch::Tensor &block_mapping) { + const int64_t num_layers = key_caches.size(); + TORCH_CHECK(num_layers == static_cast(value_caches.size()), + "key_caches and value_caches must have the same length"); + if (num_layers == 0) { + return; + } + + // --- Preconditions -------------------------------------------------- + torch::Device dev = key_caches[0].device(); + TORCH_CHECK(dev.is_mps(), "copy_blocks: expected MPS tensors"); + + // Move block_mapping to CPU if it's on MPS + torch::Tensor block_mapping_cpu = block_mapping; + if (block_mapping.device().is_mps()) { + block_mapping_cpu = block_mapping.cpu(); + } + + for (int64_t i = 0; i < num_layers; ++i) { + TORCH_CHECK(key_caches[i].device() == dev && + value_caches[i].device() == dev, + "All cache tensors must be on the same MPS device"); + TORCH_CHECK(key_caches[i].dtype() == value_caches[i].dtype(), + "Key/value cache 
dtype mismatch at layer ", i); + } + + const int64_t num_pairs = block_mapping.size(0); + const int32_t numel_per_block = + static_cast(key_caches[0][0].numel()); + + @autoreleasepool { + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get current MPS stream"); + + id device = stream->device(); + id cmdBuf = stream->commandBuffer(); + TORCH_CHECK(cmdBuf, "Failed to get command buffer"); + + // Construct the full path to the metallib file + std::string moduleDir = getModuleDirectory(); + std::string metallibPath = moduleDir + "/" + METALLIB_PATH; + + NSString *metallibPathStr = + [NSString stringWithUTF8String:metallibPath.c_str()]; + NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr]; + NSError *error = nil; + id lib = [device newLibraryWithURL:metallibURL error:&error]; + if (!lib) { + NSLog(@"[cache.mm] Failed to load pre-compiled Metal library at %@: %@", + metallibPathStr, error.localizedDescription); + } + + // Process each layer separately + for (int64_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) { + NSString *kernName = nil; + switch (key_caches[layer_idx].scalar_type()) { + case torch::kFloat: + kernName = @"copy_blocks_float"; + break; + case torch::kHalf: + kernName = @"copy_blocks_half"; + break; + case torch::kBFloat16: + kernName = @"copy_blocks_bfloat16_t"; + break; + case torch::kUInt8: + kernName = @"copy_blocks_uchar"; + break; + default: + TORCH_CHECK(false, "Unsupported dtype for copy_blocks"); + } + + id fn = [lib newFunctionWithName:kernName]; + TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String); + + id pso = + [device newComputePipelineStateWithFunction:fn error:&error]; + TORCH_CHECK(pso, error.localizedDescription.UTF8String); + + dispatch_queue_t q = stream->queue(); + dispatch_sync(q, ^{ + id enc = [cmdBuf computeCommandEncoder]; + TORCH_CHECK(enc, "Failed to create compute encoder"); + + [enc setComputePipelineState:pso]; + + // Set key and value cache 
buffers + [enc setBuffer:getMTLBufferStorage(key_caches[layer_idx]) + offset:key_caches[layer_idx].storage_offset() * + key_caches[layer_idx].element_size() + atIndex:0]; + [enc setBuffer:getMTLBufferStorage(value_caches[layer_idx]) + offset:value_caches[layer_idx].storage_offset() * + value_caches[layer_idx].element_size() + atIndex:1]; + + // Set block mapping buffer + id mappingBuf = + [device newBufferWithBytes:block_mapping_cpu.data_ptr() + length:num_pairs * 2 * sizeof(int64_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:mappingBuf offset:0 atIndex:2]; + + // Set numel_per_block as buffer + id numelBuf = + [device newBufferWithBytes:&numel_per_block + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:numelBuf offset:0 atIndex:3]; + + const uint32_t threadsPerThreadgroup = + std::min(256, numel_per_block); + MTLSize tg = MTLSizeMake(threadsPerThreadgroup, 1, 1); + MTLSize grid = MTLSizeMake(threadsPerThreadgroup * num_pairs, 1, 1); + + [enc dispatchThreads:grid threadsPerThreadgroup:tg]; + [enc endEncoding]; + }); + } + + stream->synchronize(at::mps::SyncType::COMMIT); + } +} + +void reshape_and_cache( + torch::Tensor &key, // [num_tokens, num_heads, head_size] + torch::Tensor &value, // [num_tokens, num_heads, head_size] + torch::Tensor + &key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor + &value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor &slot_mapping, // [num_tokens] + const std::string &kv_cache_dtype, torch::Tensor &k_scale, + torch::Tensor &v_scale) { + + // Determine cache dtype and FP8 usage + torch::ScalarType cache_dtype = key_cache.scalar_type(); + bool use_fp8_scales = (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3"); + if (use_fp8_scales) { + TORCH_CHECK(cache_dtype == torch::kUInt8, "FP8 cache requires UInt8 tensor type"); + TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1, "FP8 scales must be scalars"); + 
TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 && v_scale.scalar_type() == torch::kFloat32, + "FP8 scales must be float32"); + } + + TORCH_CHECK(key.device().is_mps() && value.device().is_mps() && + key_cache.device().is_mps() && value_cache.device().is_mps(), + "All tensors must be on MPS device"); + + // Move slot_mapping to CPU if it's on MPS + torch::Tensor slot_mapping_cpu = slot_mapping; + if (slot_mapping.device().is_mps()) { + slot_mapping_cpu = slot_mapping.cpu(); + } + + const int64_t num_tokens = key.size(0); + const int64_t num_heads = key.size(1); + const int64_t head_size = key.size(2); + const int64_t block_size = key_cache.size(3); + const int64_t x = key_cache.size(4); + + const int32_t key_stride = key.stride(0); + const int32_t value_stride = value.stride(0); + + @autoreleasepool { + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get current MPS stream"); + + id device = stream->device(); + id cmdBuf = stream->commandBuffer(); + TORCH_CHECK(cmdBuf, "Failed to get command buffer"); + + // Construct the full path to the metallib file + std::string moduleDir = getModuleDirectory(); + std::string metallibPath = moduleDir + "/" + METALLIB_PATH; + + NSString *metallibPathStr = + [NSString stringWithUTF8String:metallibPath.c_str()]; + NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr]; + NSError *error = nil; + id lib = [device newLibraryWithURL:metallibURL error:&error]; + if (!lib) { + NSLog(@"[cache.mm] Failed to load pre-compiled Metal library at %@: %@", + metallibPathStr, error.localizedDescription); + } + + NSString *kernName = nil; + std::string kv_dtype_str, cache_dtype_str; + + // Get KV dtype string + switch (key.scalar_type()) { + case torch::kFloat: + kv_dtype_str = "float"; + break; + case torch::kHalf: + kv_dtype_str = "half"; + break; + case torch::kBFloat16: + kv_dtype_str = "bfloat16_t"; + break; + default: + TORCH_CHECK(false, "Unsupported dtype for reshape_and_cache"); + 
} + + // Get cache dtype string + switch (cache_dtype) { + case torch::kFloat: + cache_dtype_str = "float"; + break; + case torch::kHalf: + cache_dtype_str = "half"; + break; + case torch::kBFloat16: + cache_dtype_str = "bfloat16_t"; + break; + case torch::kUInt8: + cache_dtype_str = "uchar"; + break; + default: + TORCH_CHECK(false, "Unsupported cache dtype for reshape_and_cache"); + } + + std::string kernName_str = "reshape_and_cache_kv_" + kv_dtype_str + "_cache_" + cache_dtype_str; + kernName = [NSString stringWithUTF8String:kernName_str.c_str()]; + + // Create function constants for FP8 support + MTLFunctionConstantValues *constants = [[MTLFunctionConstantValues alloc] init]; + [constants setConstantValue:&use_fp8_scales type:MTLDataTypeBool atIndex:10]; + + id fn = [lib newFunctionWithName:kernName constantValues:constants error:&error]; + TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String, + error ? [NSString stringWithFormat:@": %@", error.localizedDescription].UTF8String : ""); + + id pso = + [device newComputePipelineStateWithFunction:fn error:&error]; + TORCH_CHECK(pso, error.localizedDescription.UTF8String); + + dispatch_queue_t q = stream->queue(); + dispatch_sync(q, ^{ + id enc = [cmdBuf computeCommandEncoder]; + TORCH_CHECK(enc, "Failed to create compute encoder"); + + [enc setComputePipelineState:pso]; + + // Set tensor buffers + [enc setBuffer:getMTLBufferStorage(key) + offset:key.storage_offset() * key.element_size() + atIndex:0]; + [enc setBuffer:getMTLBufferStorage(value) + offset:value.storage_offset() * value.element_size() + atIndex:1]; + [enc setBuffer:getMTLBufferStorage(key_cache) + offset:key_cache.storage_offset() * key_cache.element_size() + atIndex:2]; + [enc setBuffer:getMTLBufferStorage(value_cache) + offset:value_cache.storage_offset() * value_cache.element_size() + atIndex:3]; + + // Set slot mapping buffer + id slotMappingBuf = + [device newBufferWithBytes:slot_mapping_cpu.data_ptr() + length:num_tokens * 
sizeof(int64_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:slotMappingBuf offset:0 atIndex:4]; + + // k_scale and v_scale buffers (for FP8) + if (use_fp8_scales) { + [enc setBuffer:getMTLBufferStorage(k_scale) + offset:k_scale.storage_offset() * k_scale.element_size() + atIndex:5]; + [enc setBuffer:getMTLBufferStorage(v_scale) + offset:v_scale.storage_offset() * v_scale.element_size() + atIndex:6]; + } else { + // For non-FP8, we still need to increment buffer indices + // The Metal kernel expects buffers at indices 5 and 6 even if unused + } + + // Set parameters as individual buffers (matching mistralrs pattern) + id keyStrideBuf = + [device newBufferWithBytes:&key_stride + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:keyStrideBuf offset:0 atIndex:7]; + + id valueStrideBuf = + [device newBufferWithBytes:&value_stride + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:valueStrideBuf offset:0 atIndex:8]; + + const int32_t num_heads_i32 = static_cast(num_heads); + id numHeadsBuf = + [device newBufferWithBytes:&num_heads_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:numHeadsBuf offset:0 atIndex:9]; + + const int32_t head_size_i32 = static_cast(head_size); + id headSizeBuf = + [device newBufferWithBytes:&head_size_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:headSizeBuf offset:0 atIndex:10]; + + const int32_t block_size_i32 = static_cast(block_size); + id blockSizeBuf = + [device newBufferWithBytes:&block_size_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:blockSizeBuf offset:0 atIndex:11]; + + const int32_t x_i32 = static_cast(x); + id xBuf = + [device newBufferWithBytes:&x_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:xBuf offset:0 atIndex:12]; + + const uint64_t threads_per_threadgroup = + std::min(512, num_heads * 
head_size); + MTLSize tg = MTLSizeMake(threads_per_threadgroup, 1, 1); + MTLSize grid = MTLSizeMake(num_tokens, 1, 1); + + [enc dispatchThreadgroups:grid threadsPerThreadgroup:tg]; + [enc endEncoding]; + }); + + stream->synchronize(at::mps::SyncType::COMMIT); + } +} + +void reshape_and_cache_flash( + torch::Tensor &key, // [num_tokens, num_heads, head_size] + torch::Tensor &value, // [num_tokens, num_heads, head_size] + torch::Tensor &key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor + &value_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor &slot_mapping, // [num_tokens] + const std::string &kv_cache_dtype, torch::Tensor &k_scale, + torch::Tensor &v_scale) { + + TORCH_CHECK(key.device().is_mps() && value.device().is_mps() && + key_cache.device().is_mps() && value_cache.device().is_mps(), + "All tensors must be on MPS device"); + + // Move slot_mapping to CPU if it's on MPS + torch::Tensor slot_mapping_cpu = slot_mapping; + if (slot_mapping.device().is_mps()) { + slot_mapping_cpu = slot_mapping.cpu(); + } + + const int64_t num_tokens = key.size(0); + const int64_t num_heads = key.size(1); + const int64_t head_size = key.size(2); + const int64_t block_size = key_cache.size(1); + + const int32_t key_stride = key.stride(0); + const int32_t value_stride = value.stride(0); + + @autoreleasepool { + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get current MPS stream"); + + id device = stream->device(); + id cmdBuf = stream->commandBuffer(); + TORCH_CHECK(cmdBuf, "Failed to get command buffer"); + + // Construct the full path to the metallib file + std::string moduleDir = getModuleDirectory(); + std::string metallibPath = moduleDir + "/" + METALLIB_PATH; + + NSString *metallibPathStr = + [NSString stringWithUTF8String:metallibPath.c_str()]; + NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr]; + NSError *error = nil; + id lib = [device newLibraryWithURL:metallibURL 
error:&error]; + if (!lib) { + NSLog(@"[cache.mm] Failed to load pre-compiled Metal library at %@: %@", + metallibPathStr, error.localizedDescription); + } + + NSString *kernName = nil; + switch (key.scalar_type()) { + case torch::kFloat: + kernName = @"reshape_and_cache_flash_float"; + break; + case torch::kHalf: + kernName = @"reshape_and_cache_flash_half"; + break; + case torch::kBFloat16: + kernName = @"reshape_and_cache_flash_bfloat16_t"; + break; + default: + TORCH_CHECK(false, "Unsupported dtype for reshape_and_cache_flash"); + } + + id fn = [lib newFunctionWithName:kernName]; + TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String); + + id pso = + [device newComputePipelineStateWithFunction:fn error:&error]; + TORCH_CHECK(pso, error.localizedDescription.UTF8String); + + dispatch_queue_t q = stream->queue(); + dispatch_sync(q, ^{ + id enc = [cmdBuf computeCommandEncoder]; + TORCH_CHECK(enc, "Failed to create compute encoder"); + + [enc setComputePipelineState:pso]; + + // Set tensor buffers + [enc setBuffer:getMTLBufferStorage(key) + offset:key.storage_offset() * key.element_size() + atIndex:0]; + [enc setBuffer:getMTLBufferStorage(value) + offset:value.storage_offset() * value.element_size() + atIndex:1]; + [enc setBuffer:getMTLBufferStorage(key_cache) + offset:key_cache.storage_offset() * key_cache.element_size() + atIndex:2]; + [enc setBuffer:getMTLBufferStorage(value_cache) + offset:value_cache.storage_offset() * value_cache.element_size() + atIndex:3]; + + // Set slot mapping buffer + id slotMappingBuf = + [device newBufferWithBytes:slot_mapping_cpu.data_ptr() + length:num_tokens * sizeof(int64_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:slotMappingBuf offset:0 atIndex:4]; + + // Set parameters as individual buffers + id keyStrideBuf = + [device newBufferWithBytes:&key_stride + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:keyStrideBuf offset:0 atIndex:5]; + + id valueStrideBuf = + 
[device newBufferWithBytes:&value_stride + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:valueStrideBuf offset:0 atIndex:6]; + + const int32_t num_heads_i32 = static_cast(num_heads); + id numHeadsBuf = + [device newBufferWithBytes:&num_heads_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:numHeadsBuf offset:0 atIndex:7]; + + const int32_t head_size_i32 = static_cast(head_size); + id headSizeBuf = + [device newBufferWithBytes:&head_size_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:headSizeBuf offset:0 atIndex:8]; + + const int32_t block_size_i32 = static_cast(block_size); + id blockSizeBuf = + [device newBufferWithBytes:&block_size_i32 + length:sizeof(int32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:blockSizeBuf offset:0 atIndex:9]; + + const uint64_t threads_per_threadgroup = + std::min(512, num_heads * head_size); + MTLSize tg = MTLSizeMake(threads_per_threadgroup, 1, 1); + MTLSize grid = MTLSizeMake(num_tokens, 1, 1); + + [enc dispatchThreadgroups:grid threadsPerThreadgroup:tg]; + [enc endEncoding]; + }); + + stream->synchronize(at::mps::SyncType::COMMIT); + } +} \ No newline at end of file diff --git a/paged-attention-metal/cache/copy_blocks.metal b/paged-attention-metal/cache/copy_blocks.metal new file mode 100644 index 0000000000000000000000000000000000000000..31595cf80a8689e2daf4228ad56de78fd4d156e4 --- /dev/null +++ b/paged-attention-metal/cache/copy_blocks.metal @@ -0,0 +1,51 @@ +#include "../utils.metal" +#include + +using namespace metal; + +template +[[kernel]] void copy_blocks(device T *key_cache [[buffer(0)]], + device T *value_cache [[buffer(1)]], + const device int64_t *block_mapping [[buffer(2)]], + device const int &numel_per_block, + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]], + uint threads_per_threadgroup + [[threads_per_threadgroup]]) { + const int pair_idx = tgid; + 
+ int64_t src_block_number = block_mapping[2 * pair_idx]; + int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; + + const int64_t src_block_offset = src_block_number * numel_per_block; + const int64_t dst_block_offset = dst_block_number * numel_per_block; + + // Copy key cache blocks + for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + key_cache[dst_offset] = key_cache[src_offset]; + } + + // Copy value cache blocks + for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + value_cache[dst_offset] = value_cache[src_offset]; + } +} + +#define instantiate_copy_blocks(type) \ + template [[host_name("copy_blocks_" #type)]] [[kernel]] void \ + copy_blocks(device type * key_cache [[buffer(0)]], \ + device type * value_cache [[buffer(1)]], \ + const device int64_t *block_mapping [[buffer(2)]], \ + device const int &numel_per_block, \ + uint tgid [[threadgroup_position_in_grid]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint threads_per_threadgroup [[threads_per_threadgroup]]); + +instantiate_copy_blocks(float); +instantiate_copy_blocks(bfloat16_t); +instantiate_copy_blocks(half); +instantiate_copy_blocks(uchar); diff --git a/paged-attention-metal/cache/reshape_and_cache.metal b/paged-attention-metal/cache/reshape_and_cache.metal new file mode 100644 index 0000000000000000000000000000000000000000..28ff7dbebaf30491ff3d16c02874fce7380a33b6 --- /dev/null +++ b/paged-attention-metal/cache/reshape_and_cache.metal @@ -0,0 +1,193 @@ +#include "../utils.metal" +#include "../float8.metal" +#include + +using namespace metal; + +template +inline CACHE_T to_cache(KV_T v) = delete; + +template <> inline uchar to_cache(float v) { + return float_to_fp8_e4m3(v); +} + +template <> inline uchar to_cache(bfloat16_t v) { + return float_to_fp8_e4m3((float)v); 
+} + +template <> inline uchar to_cache(half v) { + return float_to_fp8_e4m3((float)v); +} + +template <> inline float to_cache(float v) { return v; } + +template <> inline bfloat16_t to_cache(bfloat16_t v) { + return v; +} + +template <> inline half to_cache(half v) { return v; } + +constant bool use_fp8_scales [[function_constant(10)]]; + +template +[[kernel]] void reshape_and_cache( + const device KV_T *__restrict__ key + [[buffer(0)]], // [num_tokens, num_heads, head_size] + const device KV_T *__restrict__ value + [[buffer(1)]], // [num_tokens, num_heads, head_size] + device CACHE_T *__restrict__ key_cache + [[buffer(2)]], // [num_blocks, num_heads, head_size/x, block_size, x] + device CACHE_T *__restrict__ value_cache + [[buffer(3)]], // [num_blocks, num_heads, head_size, block_size] + const device int64_t *__restrict__ slot_mapping + [[buffer(4)]], // [num_tokens] + const device float *__restrict__ k_scale + [[buffer(5)]], // [1] - only used when use_fp8_scales + const device float *__restrict__ v_scale + [[buffer(6)]], // [1] - only used when use_fp8_scales + device const int &key_stride [[buffer(7)]], + device const int &value_stride [[buffer(8)]], + device const int &num_heads [[buffer(9)]], + device const int &head_size [[buffer(10)]], + device const int &block_size [[buffer(11)]], + device const int &x [[buffer(12)]], + uint gid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]], + uint threads_per_threadgroup [[threads_per_threadgroup]]) { + const int64_t token_idx = gid; + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. 
+ return; + } + + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = tid; i < n; i += threads_per_threadgroup) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; + + if (use_fp8_scales) { + key_cache[tgt_key_idx] = + to_cache(KV_T((float)key[src_key_idx] / *k_scale)); + value_cache[tgt_value_idx] = + to_cache(KV_T((float)value[src_value_idx] / *v_scale)); + } else { + key_cache[tgt_key_idx] = to_cache(key[src_key_idx]); + value_cache[tgt_value_idx] = to_cache(value[src_value_idx]); + } + } +} + +#define instantiate_reshape_and_cache(kv_type, cache_type) \ + template [[host_name("reshape_and_cache_kv_" #kv_type \ + "_cache_" #cache_type)]] [[kernel]] void \ + reshape_and_cache( \ + const device kv_type *__restrict__ key [[buffer(0)]], \ + const device kv_type *__restrict__ value [[buffer(1)]], \ + device cache_type *__restrict__ key_cache [[buffer(2)]], \ + device cache_type *__restrict__ value_cache [[buffer(3)]], \ + const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \ + const device float *__restrict__ k_scale [[buffer(5)]], \ + const device float *__restrict__ v_scale [[buffer(6)]], \ + device const int &key_stride [[buffer(7)]], \ + device const int &value_stride [[buffer(8)]], \ + device const int &num_heads [[buffer(9)]], \ + device const int &head_size 
[[buffer(10)]], \ + device const int &block_size [[buffer(11)]], \ + device const int &x [[buffer(12)]], \ + uint gid [[threadgroup_position_in_grid]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint threads_per_threadgroup [[threads_per_threadgroup]]); + +instantiate_reshape_and_cache(float, float); +instantiate_reshape_and_cache(bfloat16_t, bfloat16_t); +instantiate_reshape_and_cache(half, half); + +instantiate_reshape_and_cache(float, uchar); +instantiate_reshape_and_cache(bfloat16_t, uchar); +instantiate_reshape_and_cache(half, uchar); + +// Flash version with different cache layout: [num_blocks, block_size, +// num_heads, head_size] +template +[[kernel]] void reshape_and_cache_flash( + const device T *__restrict__ key + [[buffer(0)]], // [num_tokens, num_heads, head_size] + const device T *__restrict__ value + [[buffer(1)]], // [num_tokens, num_heads, head_size] + device T *__restrict__ key_cache + [[buffer(2)]], // [num_blocks, block_size, num_heads, head_size] + device T *__restrict__ value_cache + [[buffer(3)]], // [num_blocks, block_size, num_heads, head_size] + const device int64_t *__restrict__ slot_mapping + [[buffer(4)]], // [num_tokens] + device const int &key_stride, device const int &value_stride, + device const int &num_heads, device const int &head_size, + device const int &block_size, uint gid [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]], + uint threads_per_threadgroup [[threads_per_threadgroup]]) { + const int64_t token_idx = gid; + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. 
+ return; + } + + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = tid; i < n; i += threads_per_threadgroup) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + + // Flash cache layout: [num_blocks, block_size, num_heads, head_size] + const int64_t tgt_key_idx = block_idx * block_size * num_heads * head_size + + block_offset * num_heads * head_size + + head_idx * head_size + head_offset; + const int64_t tgt_value_idx = + block_idx * block_size * num_heads * head_size + + block_offset * num_heads * head_size + head_idx * head_size + + head_offset; + key_cache[tgt_key_idx] = key[src_key_idx]; + value_cache[tgt_value_idx] = value[src_value_idx]; + } +} + +#define instantiate_reshape_and_cache_flash(type) \ + template [[host_name("reshape_and_cache_flash_" #type)]] [[kernel]] void \ + reshape_and_cache_flash( \ + const device type *__restrict__ key [[buffer(0)]], \ + const device type *__restrict__ value [[buffer(1)]], \ + device type *__restrict__ key_cache [[buffer(2)]], \ + device type *__restrict__ value_cache [[buffer(3)]], \ + const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \ + device const int &key_stride, device const int &value_stride, \ + device const int &num_heads, device const int &head_size, \ + device const int &block_size, uint gid [[threadgroup_position_in_grid]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint threads_per_threadgroup [[threads_per_threadgroup]]); + +instantiate_reshape_and_cache_flash(float); +instantiate_reshape_and_cache_flash(bfloat16_t); +instantiate_reshape_and_cache_flash(half); diff --git a/paged-attention-metal/convert_fp8.metal b/paged-attention-metal/convert_fp8.metal new file mode 100644 index 
0000000000000000000000000000000000000000..22028ce705766bf536e7bd7748750eaf59ce26e7 --- /dev/null +++ b/paged-attention-metal/convert_fp8.metal @@ -0,0 +1,77 @@ +#include "float8.metal" +#include "utils.metal" +#include + +using namespace metal; + +// Convert between different precision formats for cache tensors +// This kernel handles conversions like float->fp8, fp8->float, etc. + +template +[[kernel]] void convert_fp8_kernel( + const device SRC_T *__restrict__ src [[buffer(0)]], + device DST_T *__restrict__ dst [[buffer(1)]], + const device float &scale [[buffer(2)]], + const device uint32_t &num_elements [[buffer(3)]], + uint gid [[thread_position_in_grid]]) { + + if (gid >= num_elements) { + return; + } + + // Load source value + SRC_T src_val = src[gid]; + + // Convert based on source and destination types + if constexpr (is_same_v && !is_same_v) { + // FP8 -> higher precision (dequantization) + float fp32_val = fp8_e4m3_to_float(src_val) * scale; + dst[gid] = static_cast(fp32_val); + } else if constexpr (!is_same_v && is_same_v) { + // Higher precision -> FP8 (quantization) + float fp32_val = static_cast(src_val) / scale; + dst[gid] = float_to_fp8_e4m3(fp32_val); + } else if constexpr (is_same_v && is_same_v) { + // FP8 -> FP8 (with rescaling) + float fp32_val = fp8_e4m3_to_float(src_val) * scale; + dst[gid] = float_to_fp8_e4m3(fp32_val); + } else { + // Regular precision -> regular precision (with scaling) + float fp32_val = static_cast(src_val) * scale; + dst[gid] = static_cast(fp32_val); + } +} + +// Instantiate all required combinations +#define INSTANTIATE_CONVERT_FP8(src_type, dst_type) \ + template [[host_name("convert_fp8_" #src_type "_to_" #dst_type)]] \ + [[kernel]] void convert_fp8_kernel( \ + const device src_type *__restrict__ src [[buffer(0)]], \ + device dst_type *__restrict__ dst [[buffer(1)]], \ + const device float &scale [[buffer(2)]], \ + const device uint32_t &num_elements [[buffer(3)]], \ + uint gid [[thread_position_in_grid]]); + +// 
FP8 to other formats (dequantization) +INSTANTIATE_CONVERT_FP8(uchar, float); +INSTANTIATE_CONVERT_FP8(uchar, half); +INSTANTIATE_CONVERT_FP8(uchar, bfloat16_t); + +// Other formats to FP8 (quantization) +INSTANTIATE_CONVERT_FP8(float, uchar); +INSTANTIATE_CONVERT_FP8(half, uchar); +INSTANTIATE_CONVERT_FP8(bfloat16_t, uchar); + +// FP8 to FP8 (rescaling) +INSTANTIATE_CONVERT_FP8(uchar, uchar); + +// Regular precision conversions with scaling +INSTANTIATE_CONVERT_FP8(float, float); +INSTANTIATE_CONVERT_FP8(float, half); +INSTANTIATE_CONVERT_FP8(float, bfloat16_t); +INSTANTIATE_CONVERT_FP8(half, float); +INSTANTIATE_CONVERT_FP8(half, half); +INSTANTIATE_CONVERT_FP8(half, bfloat16_t); +INSTANTIATE_CONVERT_FP8(bfloat16_t, float); +INSTANTIATE_CONVERT_FP8(bfloat16_t, half); +INSTANTIATE_CONVERT_FP8(bfloat16_t, bfloat16_t); \ No newline at end of file diff --git a/paged-attention-metal/convert_fp8.mm b/paged-attention-metal/convert_fp8.mm new file mode 100644 index 0000000000000000000000000000000000000000..e86b71a0b4f36a0eaef27f3665b48ec46544bfd6 --- /dev/null +++ b/paged-attention-metal/convert_fp8.mm @@ -0,0 +1,138 @@ +#include +#include +#include + +#import +#import +#include +#include +#include +#include +#include + +static inline id getMTLBufferStorage(const torch::Tensor &tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + +static std::string getModuleDirectory() { + Dl_info dl_info; + if (dladdr((void *)getModuleDirectory, &dl_info)) { + std::string path(dl_info.dli_fname); + size_t pos = path.find_last_of('/'); + if (pos != std::string::npos) { + return path.substr(0, pos); + } + } + return "."; +} + +// Helper function to get conversion kernel name +static std::string getConvertKernelName(torch::ScalarType src_dtype, torch::ScalarType dst_dtype) { + std::string src_str, dst_str; + + auto dtype_to_string = [](torch::ScalarType dtype) -> std::string { + switch (dtype) { + case torch::kFloat: return "float"; + case torch::kHalf: return "half"; 
+ case torch::kBFloat16: return "bfloat16_t"; + case torch::kUInt8: return "uchar"; + default: + TORCH_CHECK(false, "Unsupported dtype for convert_fp8: ", dtype); + } + }; + + src_str = dtype_to_string(src_dtype); + dst_str = dtype_to_string(dst_dtype); + + return "convert_fp8_" + src_str + "_to_" + dst_str; +} + +void convert_fp8(torch::Tensor &dst_cache, torch::Tensor &src_cache, + const double scale, const std::string &kv_cache_dtype) { + // Validate input tensors + TORCH_CHECK(src_cache.device().is_mps() && dst_cache.device().is_mps(), + "Both tensors must be on MPS device"); + TORCH_CHECK(src_cache.device() == dst_cache.device(), + "Source and destination tensors must be on the same device"); + TORCH_CHECK(src_cache.numel() == dst_cache.numel(), + "Source and destination tensors must have the same number of elements"); + TORCH_CHECK(src_cache.is_contiguous() && dst_cache.is_contiguous(), + "Both tensors must be contiguous"); + + const uint32_t num_elements = static_cast(src_cache.numel()); + if (num_elements == 0) { + return; // Nothing to convert + } + + // Determine conversion kernel name + std::string kernel_name = getConvertKernelName(src_cache.scalar_type(), dst_cache.scalar_type()); + + @autoreleasepool { + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get current MPS stream"); + + id device = stream->device(); + id cmdBuf = stream->commandBuffer(); + TORCH_CHECK(cmdBuf, "Failed to get command buffer"); + + // Load Metal library + std::string moduleDir = getModuleDirectory(); + std::string metallibPath = moduleDir + "/" + METALLIB_PATH; + NSString *metallibPathStr = [NSString stringWithUTF8String:metallibPath.c_str()]; + NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr]; + NSError *error = nil; + id lib = [device newLibraryWithURL:metallibURL error:&error]; + TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath, ": ", + error ? 
error.localizedDescription.UTF8String : "unknown error"); + + // Create kernel function + NSString *kernelNameStr = [NSString stringWithUTF8String:kernel_name.c_str()]; + id fn = [lib newFunctionWithName:kernelNameStr]; + TORCH_CHECK(fn, "Failed to find Metal kernel function: ", kernel_name); + + id pso = [device newComputePipelineStateWithFunction:fn error:&error]; + TORCH_CHECK(pso, "Failed to create compute pipeline state: ", + error ? error.localizedDescription.UTF8String : "unknown error"); + + dispatch_queue_t q = stream->queue(); + dispatch_sync(q, ^{ + id enc = [cmdBuf computeCommandEncoder]; + TORCH_CHECK(enc, "Failed to create compute encoder"); + + [enc setComputePipelineState:pso]; + + // Set buffers + [enc setBuffer:getMTLBufferStorage(src_cache) + offset:src_cache.storage_offset() * src_cache.element_size() + atIndex:0]; + [enc setBuffer:getMTLBufferStorage(dst_cache) + offset:dst_cache.storage_offset() * dst_cache.element_size() + atIndex:1]; + + // Set scale parameter + float scale_f32 = static_cast(scale); + id scaleBuf = [device newBufferWithBytes:&scale_f32 + length:sizeof(float) + options:MTLResourceStorageModeShared]; + [enc setBuffer:scaleBuf offset:0 atIndex:2]; + + // Set num_elements parameter + id numElementsBuf = [device newBufferWithBytes:&num_elements + length:sizeof(uint32_t) + options:MTLResourceStorageModeShared]; + [enc setBuffer:numElementsBuf offset:0 atIndex:3]; + + // Dispatch threads: one thread per element. The threadgroup width is + // capped by the pipeline's own limit (which can be lower than 1024 for a + // given kernel) as well as by the element count, because Metal rejects a + // dispatch whose threadgroup exceeds maxTotalThreadsPerThreadgroup. + const uint32_t max_threads = static_cast<uint32_t>([pso maxTotalThreadsPerThreadgroup]); + const uint32_t threads_per_threadgroup = std::min<uint32_t>(std::min<uint32_t>(1024, max_threads), num_elements); + const uint32_t threadgroups = (num_elements + threads_per_threadgroup - 1) / threads_per_threadgroup; + + MTLSize threadsPerThreadgroup = MTLSizeMake(threads_per_threadgroup, 1, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake(threadgroups, 1, 1); + + [enc dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup]; + [enc endEncoding]; + }); + + stream->synchronize(at::mps::SyncType::COMMIT); + } +} \ No newline at end of file diff --git 
a/paged-attention-metal/device.mm b/paged-attention-metal/device.mm new file mode 100644 index 0000000000000000000000000000000000000000..4d2e541bb64f1ec4baa75ef27e3fd14a745391bb --- /dev/null +++ b/paged-attention-metal/device.mm @@ -0,0 +1,17 @@ +#include "../torch-ext/torch_binding.h" +#import +#include + +int64_t get_device_attribute(int64_t attribute, int64_t device_id) { + TORCH_CHECK(false, "get_device_attribute is not supported on Metal"); +} + +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) { + // On macOS you can have multiple GPUs; fetch the N-th one. + NSArray> *all = MTLCopyAllDevices(); + TORCH_CHECK(device_id >= 0 && device_id < (int64_t)all.count, + "Invalid Metal device index"); + + id dev = all[device_id]; + return static_cast(dev.maxThreadgroupMemoryLength); +} \ No newline at end of file diff --git a/paged-attention-metal/float8.metal b/paged-attention-metal/float8.metal new file mode 100644 index 0000000000000000000000000000000000000000..b911eba237fec1ce6ea7f86653b12c4a5eec0000 --- /dev/null +++ b/paged-attention-metal/float8.metal @@ -0,0 +1,122 @@ +#include +using namespace metal; + +// Helpers ------------------------------------------------------------ +static inline uint as_bits(float x) { return as_type(x); } +static inline float from_bits(uint b) { return as_type(b); } + +// ------------------------------------------------------------------- +// FP8 E4M3 (bias = 7) +// ------------------------------------------------------------------- +inline float fp8_e4m3_to_float(uchar v) { + const uint s = v >> 7; + const uint exp = (v >> 3) & 0xF; + const uint man = v & 0x7; + + if (exp == 0) { // zero / sub-normal + if (man == 0) + return s ? -0.f : 0.f; + const float m = float(man) / 8.f; // already scaled by 2^-3 + float val = ldexp(m, 1 - 7); // 2^(1-bias) = 2^-6 + return s ? -val : val; + } + + if (exp == 0xF) { // Inf / NaN (E4M3FN keeps only NaN) + if (man != 0) + return NAN; + return s ? 
-INFINITY : INFINITY; + } + + const float m = 1.f + float(man) / 8.f; + float val = ldexp(m, int(exp) - 7); + return s ? -val : val; +} + +// ------------------------------------------------------------------- +// FP8 E5M2 (bias = 15) +// ------------------------------------------------------------------- +inline float fp8_e5m2_to_float(uchar v) { + const uint s = v >> 7; + const uint exp = (v >> 2) & 0x1F; + const uint man = v & 0x3; + + if (exp == 0) { + if (man == 0) + return s ? -0.f : 0.f; + const float m = float(man) / 4.f; + float val = ldexp(m, 1 - 15); // 2^(1-bias) = 2^-14 + return s ? -val : val; + } + + if (exp == 0x1F) { + if (man != 0) + return NAN; + return s ? -INFINITY : INFINITY; + } + + const float m = 1.f + float(man) / 4.f; + float val = ldexp(m, int(exp) - 15); + return s ? -val : val; +} + +// ------------------------------------------------------------------- +// Encoding helpers (round-to-nearest-even, gradual under-flow, sat-to-∞) +// ------------------------------------------------------------------- +namespace detail { +template +inline uchar fp32_to_fp8(float f) { + const uint bits = as_bits(f); + const uint s = bits >> 31; + const uint abs = bits & 0x7FFFFFFF; + + // NaN propagates, Inf saturates + if (abs >= 0x7F800000u) { + return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS) | + (abs != 0x7F800000u)); + } + + int e = int((abs >> 23) & 0xFF) - 127; // unbiased exponent + uint m = abs & 0x7FFFFFu; // 23-bit mantissa + const int EXP_MAX = (1 << EXP_BITS) - 2; // last finite exponent + + // ---------- Normal path ------------------------------------------------- + int e_fp8 = e + BIAS; + if (e_fp8 >= 1 && e_fp8 <= EXP_MAX) { + // round-to-nearest-even + const int shift = 23 - MAN_BITS; + uint mant = m >> shift; + const uint lsb = mant & 1u; + const uint round = (m >> (shift - 1)) & 1u; + const uint sticky = (m & ((1u << (shift - 1)) - 1u)) != 0u; + mant += (round & (sticky | lsb)); + if (mant >> MAN_BITS) { // mantissa overflow 
+ mant = 0; + ++e_fp8; + if (e_fp8 > EXP_MAX) + return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS)); // ∞ + } + return uchar((s << 7) | (uint(e_fp8) << MAN_BITS) | + (mant & ((1u << MAN_BITS) - 1u))); + } + + // ---------- Overflow ---------------------------------------------------- + // A finite magnitude above the largest normal saturates to ±∞. This must be + // handled before the sub-normal path, whose shift amount is only valid for + // e_fp8 <= 0 (it would otherwise shrink, or even go negative, and encode + // large inputs as garbage or zero). + if (e_fp8 > EXP_MAX) + return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS)); + + // ---------- Sub-normal / under-flow ------------------------------------ + if (e_fp8 < 1 - MAN_BITS) // too small -> ±0 + return uchar(s << 7); + + // shift so that exponent becomes 1 + int rshift = (1 - e_fp8) + (23 - MAN_BITS); + uint mant = (0x800000u | m); // implicit 1 + uint rounded = (mant + (1u << (rshift - 1))) >> rshift; + if (rounded == 0) + return uchar(s << 7); // rounds to zero + + // rounded is at most 1 << MAN_BITS. When rounding carries out of the + // sub-normal range that extra bit lands in the exponent field, giving + // exponent 1 / mantissa 0 (the smallest normal), so it must not be masked + // away - masking would turn values just below the smallest normal into 0. + return uchar((s << 7) | rounded); +} +} // namespace detail + +inline uchar float_to_fp8_e4m3(float f) { + return detail::fp32_to_fp8<4, 3, 7>(f); +} +inline uchar float_to_fp8_e5m2(float f) { + return detail::fp32_to_fp8<5, 2, 15>(f); +} \ No newline at end of file diff --git a/paged-attention-metal/paged_attention.mm b/paged-attention-metal/paged_attention.mm new file mode 100644 index 0000000000000000000000000000000000000000..71c398dbe896d0f13c95ba82bdf543d82f4d8c07 --- /dev/null +++ b/paged-attention-metal/paged_attention.mm @@ -0,0 +1,693 @@ +#include +#include +#include + +#import +#import +#include +#include +#include +#include +#include + +static inline id getMTLBufferStorage(const torch::Tensor &tensor) { + return __builtin_bit_cast(id, tensor.storage().data()); +} + +static std::string getModuleDirectory() { + Dl_info dl_info; + if (dladdr((void *)getModuleDirectory, &dl_info)) { + std::string path(dl_info.dli_fname); + size_t pos = path.find_last_of('/'); + if (pos != std::string::npos) { + return path.substr(0, pos); + } + } + return "."; +} + +// Helper function to get kernel name based on dtype and parameters +static std::string getKernelName(const std::string &base_name, + torch::ScalarType dtype, + torch::ScalarType cache_dtype, + int head_size, + int block_size, int num_threads, + int num_simd_lanes, int 
partition_size = 0) { + std::string dtype_str; + switch (dtype) { + case torch::kFloat: + dtype_str = "float"; + break; + case torch::kHalf: + dtype_str = "half"; + break; + case torch::kBFloat16: + dtype_str = "bfloat16_t"; + break; + default: + TORCH_CHECK(false, "Unsupported dtype for paged attention: ", dtype); + } + + std::string cache_dtype_str; + switch (cache_dtype) { + case torch::kFloat: + cache_dtype_str = "float"; + break; + case torch::kHalf: + cache_dtype_str = "half"; + break; + case torch::kBFloat16: + cache_dtype_str = "bfloat16_t"; + break; + case torch::kUInt8: + cache_dtype_str = "uchar"; + break; + default: + TORCH_CHECK(false, "Unsupported cache dtype for paged attention: ", cache_dtype); + } + + std::string kernel_name = + base_name + "_" + dtype_str + "_cache_" + cache_dtype_str + "_hs" + std::to_string(head_size) + "_bs" + + std::to_string(block_size) + "_nt" + std::to_string(num_threads) + + "_nsl" + std::to_string(num_simd_lanes); + + if (partition_size >= 0) { + kernel_name += "_ps" + std::to_string(partition_size); + } + + return kernel_name; +} + +// Helper function to calculate shared memory size +static size_t calculateSharedMemorySize(int max_seq_len, int head_size, + int num_threads, int num_simd_lanes) { + // Logits storage: max_seq_len * sizeof(float) + size_t logits_size = max_seq_len * sizeof(float); + + // Reduction workspace: 2 * (num_threads / num_simd_lanes) * sizeof(float) + size_t reduction_size = 2 * (num_threads / num_simd_lanes) * sizeof(float); + + // Output workspace for cross-warp reduction: head_size * sizeof(float) + size_t output_size = head_size * sizeof(float); + return std::max(logits_size + reduction_size, output_size); +} + +// Helper function to get supported configurations +static bool isValidConfiguration(int head_size, int block_size) { + // Supported head sizes from the Metal kernel instantiations + std::vector supported_head_sizes = {32, 64, 80, 96, 112, + 120, 128, 192, 256}; + std::vector 
supported_block_sizes = {8, 16, 32}; + + return std::find(supported_head_sizes.begin(), supported_head_sizes.end(), + head_size) != supported_head_sizes.end() && + std::find(supported_block_sizes.begin(), supported_block_sizes.end(), + block_size) != supported_block_sizes.end(); +} + +void paged_attention_v1( + torch::Tensor &out, // [num_seqs, num_heads, head_size] + torch::Tensor &query, // [num_seqs, num_heads, head_size] + torch::Tensor + &key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor + &value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor &seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const std::optional &alibi_slopes, + const std::string &kv_cache_dtype, torch::Tensor &k_scale, + torch::Tensor &v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + // Validate block sparse is not supported yet + // TODO: support blocksparse. 
+ TORCH_CHECK( + !is_block_sparse, + "Block sparse attention is not yet supported in Metal implementation"); + + // Determine cache dtype based on kv_cache_dtype + torch::ScalarType cache_dtype = key_cache.scalar_type(); + bool use_fp8_scales = (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3"); + if (use_fp8_scales) { + TORCH_CHECK(cache_dtype == torch::kUInt8, "FP8 cache requires UInt8 tensor type"); + TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1, "FP8 scales must be scalars"); + TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 && v_scale.scalar_type() == torch::kFloat32, + "FP8 scales must be float32"); + } + + // Validate input tensors + TORCH_CHECK(out.device().is_mps() && query.device().is_mps() && + key_cache.device().is_mps() && + value_cache.device().is_mps() && + block_tables.device().is_mps() && seq_lens.device().is_mps(), + "All tensors must be on MPS device"); + + const int64_t num_seqs = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_size = query.size(2); + const int64_t max_num_blocks_per_seq = block_tables.size(1); + + // Validate configurations + TORCH_CHECK(isValidConfiguration(head_size, block_size), + "Unsupported head_size/block_size combination: ", head_size, "/", + block_size); + + // For v1, no partitioning - each sequence processed by one threadgroup + // Kernel configuration (should match the instantiated kernels) + const int num_threads = 256; + const int num_simd_lanes = 32; + const int partition_size = 0; // v1 doesn't use partitioning + + // Calculate shared memory requirements (from mistral.rs) + const int num_simds = num_threads / num_simd_lanes; + const int padded_max_context_len = + ((max_seq_len + block_size - 1) / block_size) * block_size; + const int logits_size = padded_max_context_len * sizeof(float); + const int outputs_size = (num_simds / 2) * head_size * sizeof(float); + const size_t shared_memory_size = std::max(logits_size, outputs_size); + + // Get kernel name - v1 
kernels have partition_size=0 in their name + std::string kernel_name = + getKernelName("paged_attention", query.scalar_type(), cache_dtype, head_size, + block_size, num_threads, num_simd_lanes, partition_size); + + @autoreleasepool { + id device = MTLCreateSystemDefaultDevice(); + + // Load Metal library + std::string moduleDir = getModuleDirectory(); + std::string metallibPath = moduleDir + "/" + METALLIB_PATH; + NSString *metallibPathStr = + [NSString stringWithUTF8String:metallibPath.c_str()]; + NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr]; + NSError *error = nil; + id lib = [device newLibraryWithURL:metallibURL error:&error]; + TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath, ": ", + error ? error.localizedDescription.UTF8String + : "unknown error"); + + // Create function constants for conditional compilation + MTLFunctionConstantValues *constants = + [[MTLFunctionConstantValues alloc] init]; + bool use_partitioning = false; + bool use_alibi = alibi_slopes.has_value(); + [constants setConstantValue:&use_partitioning + type:MTLDataTypeBool + atIndex:10]; + [constants setConstantValue:&use_alibi type:MTLDataTypeBool atIndex:20]; + [constants setConstantValue:&use_fp8_scales type:MTLDataTypeBool atIndex:30]; + + NSString *kernelNameStr = + [NSString stringWithUTF8String:kernel_name.c_str()]; + id fn = [lib newFunctionWithName:kernelNameStr + constantValues:constants + error:&error]; + TORCH_CHECK( + fn, "Failed to create Metal function '", kernel_name, + "': ", error ? error.localizedDescription.UTF8String : "unknown error"); + + id pso = + [device newComputePipelineStateWithFunction:fn error:&error]; + TORCH_CHECK(pso, "Failed to create compute pipeline state: ", + error ? 
error.localizedDescription.UTF8String + : "unknown error"); + + // Setup command buffer and encoder + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get current MPS stream"); + + id cmdBuf = stream->commandBuffer(); + TORCH_CHECK(cmdBuf, "Failed to get MPS command buffer"); + + dispatch_queue_t q = stream->queue(); + dispatch_sync(q, ^{ + id enc = [cmdBuf computeCommandEncoder]; + TORCH_CHECK(enc, "Failed to create compute command encoder"); + + [enc setComputePipelineState:pso]; + + // Set threadgroup memory + [enc setThreadgroupMemoryLength:shared_memory_size atIndex:0]; + + // Buffer arguments (matching the Metal kernel signature) + int buffer_idx = 0; + + // Skip exp_sums and max_logits for v1 (buffers 0, 1) + buffer_idx = 2; + + // out buffer + [enc setBuffer:getMTLBufferStorage(out) + offset:out.storage_offset() * out.element_size() + atIndex:buffer_idx++]; + + // query buffer + [enc setBuffer:getMTLBufferStorage(query) + offset:query.storage_offset() * query.element_size() + atIndex:buffer_idx++]; + + // key_cache buffer + [enc setBuffer:getMTLBufferStorage(key_cache) + offset:key_cache.storage_offset() * key_cache.element_size() + atIndex:buffer_idx++]; + + // value_cache buffer + [enc setBuffer:getMTLBufferStorage(value_cache) + offset:value_cache.storage_offset() * value_cache.element_size() + atIndex:buffer_idx++]; + + // k_scale and v_scale (for FP8) + if (use_fp8_scales) { + [enc setBuffer:getMTLBufferStorage(k_scale) + offset:k_scale.storage_offset() * k_scale.element_size() + atIndex:buffer_idx++]; + [enc setBuffer:getMTLBufferStorage(v_scale) + offset:v_scale.storage_offset() * v_scale.element_size() + atIndex:buffer_idx++]; + } else { + buffer_idx += 2; // Skip k_scale and v_scale buffer slots + } + + // num_kv_heads + int32_t num_kv_heads_i32 = static_cast(num_kv_heads); + [enc setBytes:&num_kv_heads_i32 + length:sizeof(int32_t) + atIndex:buffer_idx++]; + + // scale + float scale_f32 = 
static_cast<float>(scale); // (tail of paged_attention_v1, logic unchanged)
      [enc setBytes:&scale_f32 length:sizeof(float) atIndex:buffer_idx++];

      // softcapping (default to 1.0 for no capping)
      float softcapping = 1.0f;
      [enc setBytes:&softcapping length:sizeof(float) atIndex:buffer_idx++];

      // block_tables buffer
      [enc setBuffer:getMTLBufferStorage(block_tables)
              offset:block_tables.storage_offset() * block_tables.element_size()
             atIndex:buffer_idx++];

      // seq_lens buffer (context_lens in kernel)
      [enc setBuffer:getMTLBufferStorage(seq_lens)
              offset:seq_lens.storage_offset() * seq_lens.element_size()
             atIndex:buffer_idx++];

      // max_num_blocks_per_seq
      int32_t max_num_blocks_per_seq_i32 =
          static_cast<int32_t>(max_num_blocks_per_seq);
      [enc setBytes:&max_num_blocks_per_seq_i32
             length:sizeof(int32_t)
            atIndex:buffer_idx++];

      // alibi_slopes (optional)
      if (use_alibi) {
        [enc setBuffer:getMTLBufferStorage(alibi_slopes.value())
                offset:alibi_slopes.value().storage_offset() *
                       alibi_slopes.value().element_size()
               atIndex:buffer_idx++];
      } else {
        buffer_idx++; // Skip this buffer slot
      }

      // Stride parameters
      int32_t q_stride = static_cast<int32_t>(query.stride(0));
      int32_t kv_block_stride = static_cast<int32_t>(key_cache.stride(0));
      int32_t kv_head_stride = static_cast<int32_t>(key_cache.stride(1));

      [enc setBytes:&q_stride length:sizeof(int32_t) atIndex:buffer_idx++];
      [enc setBytes:&kv_block_stride
             length:sizeof(int32_t)
            atIndex:buffer_idx++];
      [enc setBytes:&kv_head_stride
             length:sizeof(int32_t)
            atIndex:buffer_idx++];

      // Dispatch configuration
      // Grid: (num_heads, num_seqs, 1) - no partitioning for v1
      MTLSize grid = MTLSizeMake(num_heads, num_seqs, 1);
      MTLSize threadgroup = MTLSizeMake(num_threads, 1, 1);

      [enc dispatchThreadgroups:grid threadsPerThreadgroup:threadgroup];
      [enc endEncoding];

      stream->synchronize(at::mps::SyncType::COMMIT);
    });
  }
}

/// Paged attention, version 2 (partitioned), for the Metal/MPS backend.
///
/// Encodes two compute passes on the current MPS stream's command buffer:
///   1. the main "paged_attention" kernel with partitioning enabled, launched
///      on a (num_heads, num_seqs, max_num_partitions) grid; it is bound to
///      `exp_sums`, `max_logits` and `tmp_out`;
///   2. the "paged_attention_v2_reduce" kernel, launched on a
///      (num_heads, num_seqs, 1) grid; it is bound to those three buffers
///      plus `out`.
/// The stream is then synchronised with SyncType::COMMIT.
///
/// Shapes are given next to each parameter. `max_num_partitions` is taken
/// from `exp_sums.size(2)`; `max_seq_len`, `tp_rank` and the blocksparse
/// tuning parameters are accepted for interface compatibility but are not
/// used by this implementation.
///
/// Raises (via TORCH_CHECK) when block-sparse attention is requested, when the
/// FP8 cache/scale tensors have the wrong dtype/shape, when any tensor that is
/// bound as a Metal buffer is not on the MPS device, when the
/// head_size/block_size combination is unsupported, when the query dtype has
/// no reduce kernel, or when any Metal object cannot be created.
void paged_attention_v2(
    torch::Tensor &out,        // [num_seqs, num_heads, head_size]
    torch::Tensor &exp_sums,   // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor &max_logits, // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor
        &tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor &query, // [num_seqs, num_heads, head_size]
    torch::Tensor
        &key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor
        &value_cache,     // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads, // [num_heads]
    double scale,
    torch::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor &seq_lens,     // [num_seqs]
    int64_t block_size, int64_t max_seq_len,
    const std::optional<torch::Tensor> &alibi_slopes,
    const std::string &kv_cache_dtype, torch::Tensor &k_scale,
    torch::Tensor &v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  const bool is_block_sparse = (blocksparse_vert_stride > 1);

  // TODO: support blocksparse.
  // Block-sparse attention is not implemented here, so reject it up front.
  TORCH_CHECK(
      !is_block_sparse,
      "Block sparse attention is not yet supported in Metal implementation");

  // The kernel variant is selected by the cache tensor's element type; the
  // kv_cache_dtype string only decides whether the FP8 scales are bound.
  const torch::ScalarType cache_dtype = key_cache.scalar_type();
  const bool use_fp8_scales =
      (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3");
  if (use_fp8_scales) {
    TORCH_CHECK(cache_dtype == torch::kUInt8,
                "FP8 cache requires UInt8 tensor type");
    TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1,
                "FP8 scales must be scalars");
    TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 &&
                    v_scale.scalar_type() == torch::kFloat32,
                "FP8 scales must be float32");
    // The scales are bound as Metal buffers below, exactly like the tensors
    // covered by the device check that follows.
    TORCH_CHECK(k_scale.device().is_mps() && v_scale.device().is_mps(),
                "FP8 scales must be on MPS device");
  }

  // Validate input tensors
  TORCH_CHECK(out.device().is_mps() && query.device().is_mps() &&
                  key_cache.device().is_mps() &&
                  value_cache.device().is_mps() && exp_sums.device().is_mps() &&
                  max_logits.device().is_mps() && tmp_out.device().is_mps() &&
                  block_tables.device().is_mps() && seq_lens.device().is_mps(),
              "All tensors must be on MPS device");
  // alibi_slopes is optional but, when present, is bound as a buffer too.
  TORCH_CHECK(!alibi_slopes.has_value() || alibi_slopes->device().is_mps(),
              "alibi_slopes must be on MPS device");

  const int64_t num_seqs = query.size(0);
  const int64_t num_heads = query.size(1);
  const int64_t head_size = query.size(2);
  const int64_t max_num_blocks_per_seq = block_tables.size(1);
  const int64_t max_num_partitions = exp_sums.size(2);

  // Validate configurations
  TORCH_CHECK(isValidConfiguration(head_size, block_size),
              "Unsupported head_size/block_size combination: ", head_size, "/",
              block_size);

  // An empty batch has nothing to compute, and a dispatch grid with a zero
  // dimension is not a valid Metal dispatch.
  if (num_seqs == 0 || num_heads == 0) {
    return;
  }
  TORCH_CHECK(max_num_partitions > 0,
              "exp_sums must provide at least one partition");

  // For v2, use partitioning (matching the instantiated kernels)
  const int num_threads = 256;
  const int num_simd_lanes = 32;
  const int partition_size = 512; // v2 uses partitioning

  // Calculate shared memory requirements (from mistral.rs)
  const int num_simds = num_threads / num_simd_lanes;
  const int logits_size = partition_size * sizeof(float);
  const int outputs_size = (num_simds / 2) * head_size * sizeof(float);
  const size_t shared_memory_size = std::max(logits_size, outputs_size);

  // Reduction scratch: max_logits + exp_sums, one float each per partition.
  // Threadgroup memory lengths must be a multiple of 16 bytes, which the raw
  // size (8 bytes per partition) is not for an odd partition count, so round
  // up.
  const size_t reduce_shared_memory_size =
      (static_cast<size_t>(max_num_partitions) * sizeof(float) * 2 + 15) / 16 *
      16;

  // Get kernel names
  std::string kernel_name =
      getKernelName("paged_attention", query.scalar_type(), cache_dtype,
                    head_size, block_size, num_threads, num_simd_lanes,
                    partition_size);
  // Reduce kernel doesn't have block_size in its name
  std::string reduce_kernel_name = "paged_attention_v2_reduce";
  switch (query.scalar_type()) {
  case torch::kFloat:
    reduce_kernel_name += "_float";
    break;
  case torch::kHalf:
    reduce_kernel_name += "_half";
    break;
  case torch::kBFloat16:
    reduce_kernel_name += "_bfloat16_t";
    break;
  default:
    TORCH_CHECK(false,
                "Unsupported dtype for paged attention: ", query.scalar_type());
  }
  reduce_kernel_name += "_hs" + std::to_string(head_size) + "_nt" +
                        std::to_string(num_threads) + "_nsl" +
                        std::to_string(num_simd_lanes) + "_ps" +
                        std::to_string(partition_size);

  @autoreleasepool {
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();

    // Load Metal library
    // NOTE(review): the library and both pipeline states are rebuilt on every
    // call; caching them would avoid repeated work - confirm whether a cache
    // exists elsewhere in this file before adding one.
    std::string moduleDir = getModuleDirectory();
    std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
    NSString *metallibPathStr =
        [NSString stringWithUTF8String:metallibPath.c_str()];
    NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
    NSError *error = nil;
    id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
    TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath, ": ",
                error ? error.localizedDescription.UTF8String
                      : "unknown error");

    // ------------------------------------------------------------------
    // Build both pipeline states before touching the stream queue, so a
    // failure here raises outside the dispatch_sync block (as in v1).
    // ------------------------------------------------------------------

    // Function constants for the main kernel
    MTLFunctionConstantValues *mainConstants =
        [[MTLFunctionConstantValues alloc] init];
    const bool use_partitioning = true;
    const bool use_alibi = alibi_slopes.has_value();
    [mainConstants setConstantValue:&use_partitioning
                               type:MTLDataTypeBool
                            atIndex:10];
    [mainConstants setConstantValue:&use_alibi
                               type:MTLDataTypeBool
                            atIndex:20];
    [mainConstants setConstantValue:&use_fp8_scales
                               type:MTLDataTypeBool
                            atIndex:30];

    NSString *kernelNameStr =
        [NSString stringWithUTF8String:kernel_name.c_str()];
    NSError *mainError = nil;
    id<MTLFunction> mainFn = [lib newFunctionWithName:kernelNameStr
                                       constantValues:mainConstants
                                                error:&mainError];
    TORCH_CHECK(mainFn, "Failed to create Metal function '", kernel_name,
                "': ",
                mainError ? mainError.localizedDescription.UTF8String
                          : "unknown error");

    NSError *psoError = nil;
    id<MTLComputePipelineState> mainPso =
        [device newComputePipelineStateWithFunction:mainFn error:&psoError];
    TORCH_CHECK(mainPso, "Failed to create compute pipeline state: ",
                psoError ? psoError.localizedDescription.UTF8String
                         : "unknown error");

    // Reduction kernel
    NSString *reduceKernelNameStr =
        [NSString stringWithUTF8String:reduce_kernel_name.c_str()];
    id<MTLFunction> reduceFn = [lib newFunctionWithName:reduceKernelNameStr];
    TORCH_CHECK(reduceFn, "Failed to create Metal function '",
                reduce_kernel_name, "'");

    NSError *reducePsoError = nil;
    id<MTLComputePipelineState> reducePso =
        [device newComputePipelineStateWithFunction:reduceFn
                                              error:&reducePsoError];
    TORCH_CHECK(
        reducePso, "Failed to create compute pipeline state for reduction: ",
        reducePsoError ? reducePsoError.localizedDescription.UTF8String
                       : "unknown error");

    // Fail with a readable message instead of an encoder error when the
    // requested scratch space cannot fit in threadgroup memory.
    const size_t max_threadgroup_memory = [device maxThreadgroupMemoryLength];
    TORCH_CHECK(shared_memory_size <= max_threadgroup_memory &&
                    reduce_shared_memory_size <= max_threadgroup_memory,
                "Threadgroup memory request (", shared_memory_size, " / ",
                reduce_shared_memory_size,
                " bytes) exceeds the device limit of ", max_threadgroup_memory,
                " bytes");

    // Setup command buffer and queue
    at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
    TORCH_CHECK(stream, "Failed to get current MPS stream");

    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
    TORCH_CHECK(cmdBuf, "Failed to get MPS command buffer");

    dispatch_queue_t q = stream->queue();
    dispatch_sync(q, ^{
      // Binds a tensor's backing buffer at `index`, offset to its first
      // element.
      auto bindTensor = [](id<MTLComputeCommandEncoder> encoder,
                           const torch::Tensor &tensor, int index) {
        [encoder setBuffer:getMTLBufferStorage(tensor)
                    offset:tensor.storage_offset() * tensor.element_size()
                   atIndex:index];
      };
      // Passes a scalar to the kernel as a 32-bit integer at `index`.
      auto bindInt32 = [](id<MTLComputeCommandEncoder> encoder, int64_t value,
                          int index) {
        const int32_t value_i32 = static_cast<int32_t>(value);
        [encoder setBytes:&value_i32 length:sizeof(int32_t) atIndex:index];
      };

      // ==================================================================
      // Phase 1: Main paged attention kernel with partitioning
      // ==================================================================
      id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
      TORCH_CHECK(enc, "Failed to create compute command encoder");

      [enc setComputePipelineState:mainPso];
      [enc setThreadgroupMemoryLength:shared_memory_size atIndex:0];

      // Buffer slots are assigned in order; the order must match the kernel
      // signature.
      int buffer_idx = 0;
      bindTensor(enc, exp_sums, buffer_idx++);
      bindTensor(enc, max_logits, buffer_idx++);
      bindTensor(enc, tmp_out, buffer_idx++);
      bindTensor(enc, query, buffer_idx++);
      bindTensor(enc, key_cache, buffer_idx++);
      bindTensor(enc, value_cache, buffer_idx++);

      // k_scale and v_scale (for FP8)
      if (use_fp8_scales) {
        bindTensor(enc, k_scale, buffer_idx++);
        bindTensor(enc, v_scale, buffer_idx++);
      } else {
        buffer_idx += 2; // Skip k_scale and v_scale buffer slots
      }

      // num_kv_heads
      bindInt32(enc, num_kv_heads, buffer_idx++);

      // scale
      float scale_f32 = static_cast<float>(scale);
      [enc setBytes:&scale_f32 length:sizeof(float) atIndex:buffer_idx++];

      // softcapping (default to 1.0 for no capping)
      float softcapping = 1.0f;
      [enc setBytes:&softcapping length:sizeof(float) atIndex:buffer_idx++];

      bindTensor(enc, block_tables, buffer_idx++);
      // seq_lens buffer (context_lens in kernel)
      bindTensor(enc, seq_lens, buffer_idx++);

      // max_num_blocks_per_seq
      bindInt32(enc, max_num_blocks_per_seq, buffer_idx++);

      // alibi_slopes (optional)
      if (use_alibi) {
        bindTensor(enc, alibi_slopes.value(), buffer_idx++);
      } else {
        buffer_idx++; // Skip this buffer slot
      }

      // Stride parameters: q_stride, kv_block_stride, kv_head_stride
      bindInt32(enc, query.stride(0), buffer_idx++);
      bindInt32(enc, key_cache.stride(0), buffer_idx++);
      bindInt32(enc, key_cache.stride(1), buffer_idx++);

      // Dispatch main kernel
      // Grid: (num_heads, num_seqs, max_num_partitions) - with partitioning for
      // v2
      MTLSize mainGrid = MTLSizeMake(num_heads, num_seqs, max_num_partitions);
      MTLSize mainThreadgroup = MTLSizeMake(num_threads, 1, 1);

      [enc dispatchThreadgroups:mainGrid threadsPerThreadgroup:mainThreadgroup];
      [enc endEncoding];

      // ==================================================================
      // Phase 2: Reduction kernel to combine partitions
      // ==================================================================
      id<MTLComputeCommandEncoder> reduceEnc = [cmdBuf computeCommandEncoder];
      TORCH_CHECK(reduceEnc,
                  "Failed to create compute command encoder for reduction");

      [reduceEnc setComputePipelineState:reducePso];
      [reduceEnc setThreadgroupMemoryLength:reduce_shared_memory_size
                                    atIndex:0];

      // Set buffers for reduction kernel
      buffer_idx = 0;
      bindTensor(reduceEnc, out, buffer_idx++); // final output
      bindTensor(reduceEnc, exp_sums, buffer_idx++);
      bindTensor(reduceEnc, max_logits, buffer_idx++);
      bindTensor(reduceEnc, tmp_out, buffer_idx++);
      // seq_lens buffer (context_lens in kernel)
      bindTensor(reduceEnc, seq_lens, buffer_idx++);

      // max_num_partitions
      bindInt32(reduceEnc, max_num_partitions, buffer_idx++);

      // Dispatch reduction kernel
      // Grid: (num_heads, num_seqs) - one threadgroup per sequence/head
      // combination
      MTLSize reduceGrid = MTLSizeMake(num_heads, num_seqs, 1);
      MTLSize reduceThreadgroup = MTLSizeMake(num_threads, 1, 1);

      [reduceEnc dispatchThreadgroups:reduceGrid
                threadsPerThreadgroup:reduceThreadgroup];
      [reduceEnc endEncoding];

      stream->synchronize(at::mps::SyncType::COMMIT);
    });
  }
}
\ No newline at end of file
diff --git a/paged-attention-metal/utils.metal b/paged-attention-metal/utils.metal
new file mode 100644
index 0000000000000000000000000000000000000000..d3b638aa37d7f8af3e15ee8060a3bc226fed7b0d
--- /dev/null
+++ b/paged-attention-metal/utils.metal
@@ -0,0 +1,246 @@
#include <metal_stdlib>
using namespace metal;

#if defined(__HAVE_BFLOAT__)

typedef bfloat bfloat16_t;

#else

/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////

// Converts a float to the 16-bit bfloat pattern (the upper half of the float's
// bits), rounding to nearest-even. NaN inputs map to the fixed pattern 0x7FC0.
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
  // Check for nan
  if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
      _fp_encoding_traits<float>::inf_mask) {
    return uint16_t(as_type<uint32_t>(0x7FC0));
  }
  // Take bits
  uint32_t float_bits = as_type<uint32_t>(x);

  // Round to nearest even
  float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);

  // Take upper 16 bits
  return float_bits >> 16;
}

// Widens a 16-bit bfloat pattern back to float.
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
  // Upper 16 bits are the data and lower 16 bits are 0s
  return as_type<float>((uint32_t)x << 16);
}

struct _MLX_BFloat16;

// True for every type other than _MLX_BFloat16 that converts to float.
template <typename T>
static constexpr constant bool can_convert_to_bfloat =
    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;

// True for every type other than _MLX_BFloat16 that float converts to.
template <typename T>
static constexpr constant bool can_convert_from_bfloat =
    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;

/////////////////////////////////////////////////////////////////////////////
// Bfloat struct
/////////////////////////////////////////////////////////////////////////////

// Software bfloat16 used when the toolchain has no native `bfloat`
// (this whole section sits in the #else branch of __HAVE_BFLOAT__).
// The value is stored as its raw 16-bit pattern in `bits_`; every conversion
// goes through float via float_to_bfloat_bits / bfloat_bits_to_float.
// Constructors and conversion operators are repeated once per Metal address
// space so objects can live in thread, threadgroup, device or constant memory.
struct _MLX_BFloat16 {
  /////////////////////////////////////////////////////////////////////////////
  // Constructors
  uint16_t bits_;
  _MLX_BFloat16() thread = default;
  _MLX_BFloat16() threadgroup = default;
  _MLX_BFloat16() device = default;
  _MLX_BFloat16() constant = default;

  // Tag type selecting the "construct from raw bits" constructor, so it is not
  // confused with the numeric conversion from uint16_t.
  struct bits_to_bfloat_struct {};
  static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
    return bits_to_bfloat_struct();
  }
  constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
      : bits_(bits) {}

  /////////////////////////////////////////////////////////////////////////////
  // Conversions to bfloat

  template <typename T,
            typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) thread
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <typename T,
            typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <typename T,
            typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) device
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <typename T,
            typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) constant
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  /////////////////////////////////////////////////////////////////////////////
  // Conversions from bfloat

  template <typename T,
            typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const thread {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <typename T,
            typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const threadgroup {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <typename T,
            typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const device {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  // const-qualified like the three overloads above, so const-qualified
  // objects in the constant address space can be read as well.
  template <typename T,
            typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const constant {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }
};

/////////////////////////////////////////////////////////////////////////////
// Bfloat operators
/////////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////////////////////////
// Unary ops
// Negation is computed in float and converted back on return.
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
  return -static_cast<float>(x);
}

/////////////////////////////////////////////////////////////////////////////
// Binary operators
// Every binary operator casts both operands to `ctype`, applies the operator
// there, and lets the result convert to `otype`.
// bfloat_binop_base: generic (atype, btype) overload.
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
  constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) {           \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);          \
  }

// bfloat_binop_helper: the two mixed overloads (bfloat, itype) and
// (itype, bfloat).
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype)    \
  constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);        \
  }                                                                       \
  constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);        \
  }

/////////////////////////////////////////////////////////////////////////////
// Arithmetic Operators
// bfloat with float/half yields float; bfloat with bfloat or an integer type
// yields bfloat. All arithmetic is carried out in float.
#define bfloat_binop(_op_, _operator_)                                       \
  bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16,          \
                    _MLX_BFloat16, float);                                   \
  bfloat_binop_helper(_op_, _operator_, float, float, float);                \
  bfloat_binop_helper(_op_, _operator_, float, half, float);                 \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float);      \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float);     \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float);      \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);

bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);

/////////////////////////////////////////////////////////////////////////////
// Comparison ops
// Same operand combinations as the arithmetic operators, all returning bool.
#define bfloat_compop(__op__, __operator__)                                   \
  bfloat_binop_base(__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, \
                    float);                                                   \
  bfloat_binop_helper(__op__, __operator__, bool, float, float);              \
  bfloat_binop_helper(__op__, __operator__, bool, half, float);               \
  bfloat_binop_helper(__op__, __operator__, bool, int32_t, float);            \
  bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float);           \
  bfloat_binop_helper(__op__, __operator__, bool, int64_t, float);            \
  bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);

bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);

#undef bfloat_compop
#undef bfloat_binop_base
#undef bfloat_binop_helper
#undef bfloat_binop

/////////////////////////////////////////////////////////////////////////////
// Inplace Operators
// Compound assignment between a bfloat and another scalar type, in both
// directions, for a reference living in the given address space. The operation
// is evaluated in float and assigned back to the left-hand side.
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
  constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__(            \
      addr_space _MLX_BFloat16 &lhs, itype rhs) {                         \
    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs);         \
    return lhs;                                                           \
  }                                                                       \
  constexpr METAL_FUNC addr_space itype &__operator__(addr_space itype &lhs, \
                                                      _MLX_BFloat16 rhs) {   \
    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs);         \
    return lhs;                                                           \
  }

#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
  bfloat_inplace_op_helper(__op__, __operator__, itype, device);         \
  bfloat_inplace_op_helper(__op__, __operator__, itype, thread);         \
  bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);

#define bfloat_inplace_op(itype)                             \
  bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
  bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
  bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
  bfloat_inplace_op_addr_space_helper(/, operator/=, itype);

bfloat_inplace_op(float);
bfloat_inplace_op(half);
bfloat_inplace_op(int16_t);
bfloat_inplace_op(int32_t);
bfloat_inplace_op(int64_t);
bfloat_inplace_op(uint16_t);
bfloat_inplace_op(uint32_t);
bfloat_inplace_op(uint64_t);

#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
#undef bfloat_inplace_op

// Compound assignment where both operands are bfloat.
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
  constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__(     \
      addr_space _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs) {          \
    lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs);  \
    return lhs;                                                    \
  }

#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
  bfloat_inplace_op_helper(__op__, __operator__, device);         \
  bfloat_inplace_op_helper(__op__, __operator__, thread);         \
  bfloat_inplace_op_helper(__op__, __operator__, threadgroup);

bfloat_inplace_op_addr_space_helper(+, operator+=);
bfloat_inplace_op_addr_space_helper(-, operator-=);
bfloat_inplace_op_addr_space_helper(*, operator*=);
bfloat_inplace_op_addr_space_helper(/, operator/=);

#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper

/////////////////////////////////////////////////////////////////////////////
// Bfloat typedef
/////////////////////////////////////////////////////////////////////////////

typedef struct _MLX_BFloat16 bfloat16_t;

#endif
diff --git a/paged-attention/attention/attention_dtypes.h b/paged-attention/attention/attention_dtypes.h
new file mode 100644
index 0000000000000000000000000000000000000000..64f86381d9db902a6ff04ebe9520d332d40ff1ff
--- /dev/null
+++ b/paged-attention/attention/attention_dtypes.h
@@ -0,0 +1,7 @@
#pragma once

#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8.cuh"
diff --git a/paged-attention/attention/attention_generic.cuh b/paged-attention/attention/attention_generic.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..62409c0cce93e696cebcb69cb7b34526d6b26a47
--- /dev/null
+++ b/paged-attention/attention/attention_generic.cuh
@@ -0,0 +1,65 @@
/*
 * Adapted from
 *
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <stdint.h>

namespace vllm {

// A vector type to store Q, K, V elements.
// Primary template is empty; specialisations are expected to provide `Type`.
template <typename T, int VEC_SIZE>
struct Vec {};

// A vector type to store FP32 accumulators.
template <typename T>
struct FloatVec {};

// Template vector operations.
// Declared here; presumably defined per data type in the dtype_*.cuh headers.
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

template <typename T>
inline __device__ float sum(T v);

// Dot product that accumulates in the operand type T.
template <typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<T, T, T>(a, b));
}

// Dot product that accumulates in an explicitly chosen type A.
template <typename A, typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<A, T, T>(a, b));
}

// Sets every 32-bit word of `dst` to zero. The word count is sizeof(T) / 4,
// so T is expected to be a multiple of 4 bytes in size.
template <typename T>
inline __device__ void zero(T& dst) {
  constexpr int WORDS = sizeof(T) / 4;
  union {
    T raw;
    uint32_t words[WORDS];
  } tmp;

#pragma unroll
  for (int ii = 0; ii < WORDS; ++ii) {
    tmp.words[ii] = 0u;
  }
  dst = tmp.raw;
}

}  // namespace vllm
diff --git a/paged-attention/attention/attention_kernels.cuh b/paged-attention/attention/attention_kernels.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..8f24be89578b87c2451d8f3759dc15f3385b4e90
--- /dev/null
+++ b/paged-attention/attention/attention_kernels.cuh
@@ -0,0 +1,670 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
+ */ + +#include +#include +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" +#include "cuda_compat.h" + +#ifdef USE_ROCM + #include + #include "../quantization/fp8/amd/quant_utils.cuh" +typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +namespace vllm { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Broadcast to other threads. + return VLLM_SHFL_SYNC(sum, 0); +} + +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). +template // Zero means no partitioning. 
+__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float* k_scale, const float* v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? 
partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. 
+ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Quant_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(cache_t); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. 
+ // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + if constexpr (IS_BLOCK_SPARSE) { + // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, + // blocksparse_block_size); + q_bs_block_id = (seq_len - 1) / blocksparse_block_size; + if (blocksparse_head_sliding_step >= 0) + // sliding on q heads + bs_block_offset = + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; + else + // sliding on kv heads + bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * + (-blocksparse_head_sliding_step) + + 1; + } + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + const bool is_remote = + ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); + const bool is_local = + (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); + if (!is_remote && !is_local) { + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + if (thread_group_offset == 0) { + // NOTE(linxihui): assign very large number to skipped tokens to + // avoid contribution to the sumexp softmax normalizer. This will + // not be used at computing sum(softmax*v) as the blocks will be + // skipped. 
+ logits[token_idx - start_token_idx] = -FLT_MAX; + } + } + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the key, and the second thread + // has 1, 5, 9, ... th vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const cache_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } else { + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, *k_scale); + } + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= seq_len; + logits[token_idx - start_token_idx] = mask ? 
0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = VLLM_SHFL_SYNC(qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. 
+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using V_quant_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + scalar_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && + !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); + + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx 
* BLOCK_SIZE + physical_block_offset; + V_vec v_vec; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + v_vec = *reinterpret_cast(v_ptr + offset); + } else { + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. + v_vec = fp8::scaled_convert(v_quant_vec, + *v_scale); + } + if (block_idx == num_seq_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += VLLM_SHFL_XOR_SYNC(acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. 
+ if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + +// Grid: (num_heads, num_seqs, 1). +template +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float* k_scale, const float* v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, 
alibi_slopes, q_stride, kv_block_stride, + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); +} + +// Grid: (num_heads, num_seqs, max_num_partitions). +template +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float* k_scale, const float* v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); +} + +// Grid: (num_heads, num_seqs). 
// Merges the per-partition partial results written by
// paged_attention_v2_kernel (tmp_out, exp_sums, max_logits) into the final
// attention output for one (sequence, head) pair.
//
// Launch layout as read by the kernel: blockIdx.x = head index
// (gridDim.x = num_heads), blockIdx.y = sequence index.
// Dynamic shared memory: 2 * num_partitions floats (the per-partition max
// logits followed by the rescaled exp sums).
//
// Each partition p contributes
//   tmp_out[p] * exp_sums[p] * exp(max_logits[p] - global_max)
//              / (sum over partitions of the rescaled exp sums + 1e-6).
//
// NOTE(review): the loops below mix blockDim.x and NUM_THREADS as strides;
// presumably the kernel is always launched with blockDim.x == NUM_THREADS —
// confirm at the launch site.
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int seq_len = seq_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);

  // Flattened (sequence, head) row. Held in 64 bits so the element offsets
  // derived from it cannot overflow int32 when
  // num_seqs * num_heads * max_num_partitions * HEAD_SIZE is large.
  const int64_t row = static_cast<int64_t>(seq_idx) * num_heads + head_idx;
  scalar_t* out_ptr = out + row * HEAD_SIZE;
  const scalar_t* tmp_out_ptr = tmp_out + row * max_num_partitions * HEAD_SIZE;

  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block. The condition depends only on seq_idx, so
    // every thread of the block leaves together, before any barrier.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions.
  extern __shared__ char shared_mem[];
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr = max_logits + row * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = VLLM_SHFL_SYNC(max_logit, 0);

  // Load rescaled exp sums to shared memory (second half of shared_mem).
  float* shared_exp_sums =
      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr = exp_sums + row * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  // The 1e-6 term keeps the division finite when the sum is zero.
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
             inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}

}  // namespace vllm

#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
diff --git a/paged-attention/attention/attention_utils.cuh b/paged-attention/attention/attention_utils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..826b0edffae67f772828aefcd44f8a073bf892b9
--- /dev/null
+++ b/paged-attention/attention/attention_utils.cuh
@@ -0,0 +1,57 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "../cuda_compat.h"
#include "attention_dtypes.h"

#include <float.h>
#include <type_traits>

namespace vllm {

// Q*K^T operation.
// Dot product of this thread's slices of a query and a key, followed by a
// reduction over the THREAD_GROUP_SIZE lanes of the thread group.
//
// q, k: N vectors of type Vec held by the calling thread.
// Returns the group-wide sum; the XOR-shuffle pattern leaves the same value
// in every participating lane.
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
  using A_vec = typename FloatVec<Vec>::Type;
  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = vllm::fma(q[ii], k[ii], qk_vec);
  }

  // Finalize the reduction across lanes.
  float qk = sum(qk_vec);
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += VLLM_SHFL_XOR_SYNC(qk, mask);
  }
  return qk;
}

// Dispatch point for the Q*K dot product; this primary template simply
// forwards to qk_dot_ with the group size fixed at compile time.
template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
  template <typename Vec, int N>
  static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
  }
};

}  // namespace vllm
diff --git a/paged-attention/attention/dtype_bfloat16.cuh b/paged-attention/attention/dtype_bfloat16.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..97a25baa1fc0de977f3068a7a6a901d27fcfa6ad
--- /dev/null
+++ b/paged-attention/attention/dtype_bfloat16.cuh
@@ -0,0 +1,463 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>

namespace vllm {

// Define custom BF16 vector data types.
// Four BF16 values packed as two __nv_bfloat162 pairs.
struct bf16_4_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
};

// Eight BF16 values packed as four __nv_bfloat162 pairs.
struct bf16_8_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
  __nv_bfloat162 z;
  __nv_bfloat162 w;
};

// BF16 vector types for Q, K, V.
template <>
struct Vec<__nv_bfloat16, 1> {
  using Type = __nv_bfloat16;
};
template <>
struct Vec<__nv_bfloat16, 2> {
  using Type = __nv_bfloat162;
};
template <>
struct Vec<__nv_bfloat16, 4> {
  using Type = bf16_4_t;
};
template <>
struct Vec<__nv_bfloat16, 8> {
  using Type = bf16_8_t;
};

// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<__nv_bfloat16> {
  using Type = float;
};
template <>
struct FloatVec<__nv_bfloat162> {
  using Type = float2;
};
template <>
struct FloatVec<bf16_4_t> {
  using Type = Float4_;
};
template <>
struct FloatVec<bf16_8_t> {
  using Type = Float8_;
};

// Utility functions for type conversions.
// Widens a BF16 pair to float2. Device code compiled for
// __CUDA_ARCH__ < 800 hits assert(false) instead.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat1622float2(val);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Broadcasts one BF16 value into both halves of a BF16 pair. Device code
// compiled for __CUDA_ARCH__ < 800 hits assert(false) instead.
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat162bfloat162(val);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Vector addition.
// Scalar BF16 addition. Device code compiled for __CUDA_ARCH__ < 800 hits
// assert(false); otherwise CUDA uses operator+ and ROCm uses __hadd.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#elif !defined(USE_ROCM)
  return a + b;
#else
  return __hadd(a, b);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Pairwise BF16 addition of two packed pairs.
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hadd2(a, b);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Element-wise addition of two 4-wide BF16 vectors, one pair at a time.
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
  bf16_4_t out;
  out.x = add(a.x, b.x);
  out.y = add(a.y, b.y);
  return out;
}

// Element-wise addition of two 8-wide BF16 vectors, one pair at a time.
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
  bf16_8_t out;
  out.x = add(a.x, b.x);
  out.y = add(a.y, b.y);
  out.z = add(a.z, b.z);
  out.w = add(a.w, b.w);
  return out;
}

// Mixed-precision addition: the BF16 pair is widened to float2 first and the
// sum is returned in FP32.
inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
  return add(bf1622float2(a), fb);
}

// 4-wide BF16 vector plus FP32 accumulator, result in FP32.
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
  Float4_ acc;
  acc.x = add(a.x, fb.x);
  acc.y = add(a.y, fb.y);
  return acc;
}

// 8-wide BF16 vector plus FP32 accumulator, result in FP32.
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
  Float8_ acc;
  acc.x = add(a.x, fb.x);
  acc.y = add(a.y, fb.y);
  acc.z = add(a.z, fb.z);
  acc.w = add(a.w, fb.w);
  return acc;
}

// Vector multiplication.
// Specialisations of the three-parameter mul<Acc, A, B>(A, B) template.
// Acc is the return type and cannot be deduced from the arguments, so every
// nested call below names all three template arguments explicitly.

// BF16 x BF16 -> BF16. Device code for __CUDA_ARCH__ < 800 hits
// assert(false).
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul(a, b);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Pairwise BF16 product of two packed pairs.
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul2(a, b);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Scalar x pair: the scalar is broadcast to both halves first.
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

// 4-wide x 4-wide, BF16 result.
template <>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return c;
}

// Scalar x 4-wide, BF16 result.
template <>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return c;
}

// 8-wide x 8-wide, BF16 result.
template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return c;
}

// Scalar x 8-wide, BF16 result.
template <>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return c;
}

// FP32-result variants: operands are widened to float before multiplying.

// BF16 x BF16 -> float.
template <>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  float fa = __bfloat162float(a);
  float fb = __bfloat162float(b);
  return fa * fb;
}

// Pair x pair -> float2.
template <>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

// Scalar x pair -> float2.
template <>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

// 4-wide x 4-wide -> Float4_.
template <>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return fc;
}

// Scalar x 4-wide -> Float4_.
template <>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return fc;
}

// 8-wide x 8-wide -> Float8_.
template <>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return fc;
}

// Scalar x 8-wide -> Float8_.
template <>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
// Pairwise BF16 fused multiply-add: a * b + c. Device code compiled for
// __CUDA_ARCH__ < 800 hits assert(false).
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
                                     __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(a, b, c);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Scalar-times-pair variant: the scalar is broadcast to both halves.
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
                                     __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(bf162bf162(a), b, c);
#endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// 4-wide BF16 fused multiply-add, one pair at a time.
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
  bf16_4_t out;
  out.x = fma(a.x, b.x, c.x);
  out.y = fma(a.y, b.y, c.y);
  return out;
}

// Scalar-times-4-wide BF16 fused multiply-add.
inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
  const __nv_bfloat162 splat = bf162bf162(a);
  bf16_4_t out;
  out.x = fma(splat, b.x, c.x);
  out.y = fma(splat, b.y, c.y);
  return out;
}

// 8-wide BF16 fused multiply-add, one pair at a time.
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
  bf16_8_t out;
  out.x = fma(a.x, b.x, c.x);
  out.y = fma(a.y, b.y, c.y);
  out.z = fma(a.z, b.z, c.z);
  out.w = fma(a.w, b.w, c.w);
  return out;
}

// Scalar-times-8-wide BF16 fused multiply-add.
inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
  const __nv_bfloat162 splat = bf162bf162(a);
  bf16_8_t out;
  out.x = fma(splat, b.x, c.x);
  out.y = fma(splat, b.y, c.y);
  out.z = fma(splat, b.z, c.z);
  out.w = fma(splat, b.w, c.w);
  return out;
}

// FP32-accumulator variants: BF16 operands are widened before the
// multiply-add, and the result stays in FP32.

inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
  return __bfloat162float(a) * __bfloat162float(b) + fc;
}

inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
  const float2 a_f32 = bf1622float2(a);
  const float2 b_f32 = bf1622float2(b);
  return fma(a_f32, b_f32, fc);
}

inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
  return fma(bf162bf162(a), b, fc);
}

inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
  Float4_ acc;
  acc.x = fma(a.x, b.x, fc.x);
  acc.y = fma(a.y, b.y, fc.y);
  return acc;
}

inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
  const __nv_bfloat162 splat = bf162bf162(a);
  Float4_ acc;
  acc.x = fma(splat, b.x, fc.x);
  acc.y = fma(splat, b.y, fc.y);
  return acc;
}

inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
  Float8_ acc;
  acc.x = fma(a.x, b.x, fc.x);
  acc.y = fma(a.y, b.y, fc.y);
  acc.z = fma(a.z, b.z, fc.z);
  acc.w = fma(a.w, b.w, fc.w);
  return acc;
}

inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
  const __nv_bfloat162 splat = bf162bf162(a);
  Float8_ acc;
  acc.x = fma(splat, b.x, fc.x);
  acc.y = fma(splat, b.y, fc.y);
  acc.z = fma(splat, b.z, fc.z);
  acc.w = fma(splat, b.w, fc.w);
  return acc;
}

// Vector sum.
// Horizontal sums: every element is widened to FP32 and added up.
template <>
inline __device__ float sum(__nv_bfloat16 v) {
  return __bfloat162float(v);
}

template <>
inline __device__ float sum(__nv_bfloat162 v) {
  const float2 widened = bf1622float2(v);
  return widened.x + widened.y;
}

template <>
inline __device__ float sum(bf16_4_t v) {
  return sum(v.x) + sum(v.y);
}

template <>
inline __device__ float sum(bf16_8_t v) {
  return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}

// From float32 to bfloat16.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
  dst = __float2bfloat16(src);
}

// The packed variants convert with __float22bfloat162_rn. Device code
// compiled for __CUDA_ARCH__ < 800 hits assert(false) and leaves dst
// untouched.
inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst = __float22bfloat162_rn(src);
#endif
}

inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
#endif
}

inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
#endif
}

// From bfloat16 to float32.
// Widens a single BF16 value to FP32.
inline __device__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}

// Zero-out a variable.
// Device code compiled for __CUDA_ARCH__ < 800 hits assert(false) and leaves
// dst untouched.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
  dst = __ushort_as_bfloat16(static_cast<unsigned short>(0x0000U));
#endif
}

}  // namespace vllm
diff --git a/paged-attention/attention/dtype_float16.cuh b/paged-attention/attention/dtype_float16.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..3a1815f0ed4fc4706840d0136abfe7f96b6fd48a
--- /dev/null
+++ b/paged-attention/attention/dtype_float16.cuh
@@ -0,0 +1,504 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifdef USE_ROCM
  #include <hip/hip_fp16.h>
#endif

#include <stdint.h>

namespace vllm {

// FP16 vector types for Q, K, V.
// NOTE(review): the four Vec specializations below appear without template
// arguments in this view, so which element type / vector width each one maps
// is not visible here - confirm against the primary template before editing.
template <>
struct Vec {
  using Type = uint16_t;
};
template <>
struct Vec {
  using Type = uint32_t;
};
template <>
struct Vec {
  using Type = uint2;
};
template <>
struct Vec {
  using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
// NOTE(review): template arguments are likewise not visible for FloatVec.
template <>
struct FloatVec {
  using Type = float;
};
template <>
struct FloatVec {
  using Type = float2;
};
template <>
struct FloatVec {
  using Type = Float4_;
};
template <>
struct FloatVec {
  using Type = Float8_;
};

// Utility functions for type conversions.

// Replicates the 16-bit value `a` into both 16-bit halves of a 32-bit word.
// The callers below use the result as the scalar operand of packed f16x2
// instructions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
  uint32_t b;
  // Pack {a, a} into one 32-bit register.
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
#else
  // Same packing done through a union instead of inline assembly.
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u16[0] = a;
  tmp.u16[1] = a;
  return tmp.u32;
#endif
}

// Converts one f16 value, passed as its raw 16-bit pattern, to float.
inline __device__ float half_to_float(uint16_t h) {
  float f;
#ifndef USE_ROCM
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
  return f;
}

// Converts a packed pair of f16 values to a float2. The first 16-bit half
// (`lo` / u16[0]) becomes .x and the second (`hi` / u16[1]) becomes .y.
inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
  uint16_t lo, hi;
  // Unpack the 32-bit word into its two 16-bit halves.
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u32 = v;
  float2 ret;
  ret.x = half_to_float(tmp.u16[0]);
  ret.y = half_to_float(tmp.u16[1]);
  return ret;
#endif
}

// Converts a float to f16 and returns the raw 16-bit pattern.
inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  // ".rn" selects round-to-nearest-even for the narrowing conversion.
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
  // The ROCm path writes the whole 32-bit register; only u16[0] is returned.
  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
  return tmp.u16[0];
}

// Converts a float2 to a packed pair of f16 values: f.x goes to u16[0] and
// f.y to u16[1] of the returned 32-bit word.
inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // Single packed conversion. Operands are passed as (f.y, f.x); per the PTX
  // ISA the first source fills the upper half of the destination, mirroring
  // the fallback branch below which stores f.x in u16[0] and f.y in u16[1].
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
               : "=r"(tmp.u32)
               : "f"(f.y), "f"(f.x));
  #else
  // Fallback when __CUDA_ARCH__ < 800: two scalar conversions.
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
  #endif
#else
  tmp.u16[0] = float_to_half(f.x);
  tmp.u16[1] = float_to_half(f.y);
#endif
  return tmp.u32;
}

// Vector addition.

// f16 + f16 on raw 16-bit patterns; the result is also an f16 bit pattern.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

// Packed add of two f16 pairs, each pair held in one 32-bit word.
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

// Component-wise packed add for two 32-bit words (.x, .y).
inline __device__ uint2 add(uint2 a, uint2 b) {
  uint2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

// Component-wise packed add for four 32-bit words (.x, .y, .z, .w).
inline __device__ uint4 add(uint4 a, uint4 b) {
  uint4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Mixed-precision add: widens the packed f16 pair `a` to float2 first, then
// adds it to the float accumulator `fb` using the float2 add overload.
inline __device__ float2 add(uint32_t a, float2 fb) {
  float2 fa = half2_to_float2(a);
  return add(fa, fb);
}

// Mixed-precision add applied to each of the two float2 parts of Float4_.
inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

// Mixed-precision add applied to each of the four float2 parts of Float8_.
inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
// NOTE(review): every `mul` below is an explicit specialization whose template
// arguments are not visible in this view. Several specializations share a
// parameter list and differ only in return type (e.g. uint32_t vs float2 for
// (uint32_t, uint32_t)), so the return type is presumably selected by an
// explicit template argument at each call site - confirm against the primary
// template.

// f16 * f16 on raw 16-bit patterns, result kept in f16.
template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

// Packed multiply of two f16 pairs, result kept as a packed f16 pair.
template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

// Scalar * packed pair: the scalar is replicated into both halves first.
template <>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  return mul(h0_h0(a), b);
}

// Component-wise packed multiply over two 32-bit words.
template <>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  uint2 c;
  c.x = mul(a.x, b.x);
  c.y = mul(a.y, b.y);
  return c;
}

// Scalar * two packed words; the scalar is replicated once and reused.
template <>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  uint2 c;
  c.x = mul(s, b.x);
  c.y = mul(s, b.y);
  return c;
}

// Component-wise packed multiply over four 32-bit words.
template <>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  uint4 c;
  c.x = mul(a.x, b.x);
  c.y = mul(a.y, b.y);
  c.z = mul(a.z, b.z);
  c.w = mul(a.w, b.w);
  return c;
}

// Scalar * four packed words; the scalar is replicated once and reused.
template <>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  uint4 c;
  c.x = mul(s, b.x);
  c.y = mul(s, b.y);
  c.z = mul(s, b.z);
  c.w = mul(s, b.w);
  return c;
}

// The specializations from here on return float types: both operands are
// widened to float before multiplying, so the product is not rounded to f16.
template <>
inline __device__ float mul(uint16_t a, uint16_t b) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb;
}

// Packed pair * packed pair, widened to float2 before the multiply.
template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return mul(fa, fb);
}

// Scalar * packed pair with a float2 result.
template <>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
  return mul(h0_h0(a), b);
}

// Two packed words * two packed words with a Float4_ result.
template <>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
  Float4_ fc;
  fc.x = mul(a.x, b.x);
  fc.y = mul(a.y, b.y);
  return fc;
}

// Scalar * two packed words with a Float4_ result.
template <>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  Float4_ fc;
  fc.x = mul(s, b.x);
  fc.y = mul(s, b.y);
  return fc;
}

// Four packed words * four packed words with a Float8_ result.
template <>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
  Float8_ fc;
  fc.x = mul(a.x, b.x);
  fc.y = mul(a.y, b.y);
  fc.z = mul(a.z, b.z);
  fc.w = mul(a.w, b.w);
  return fc;
}

// Scalar * four packed words with a Float8_ result.
template <>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  Float8_ fc;
  fc.x = mul(s, b.x);
  fc.y = mul(s, b.y);
  fc.z = mul(s, b.z);
  fc.w = mul(s, b.w);
  return fc;
}

// Vector fused multiply-add.

// Packed a * b + c on f16 pairs, done by one packed FMA instruction.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
#ifndef USE_ROCM
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
               : "=r"(d)
               : "r"(a), "r"(b), "r"(c));
#else
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
               : "=v"(d)
               : "v"(a), "v"(b), "v"(c));
#endif
  return d;
}

// Scalar f16 `a` replicated into both halves, then packed FMA.
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  return fma(h0_h0(a), b, c);
}

// Component-wise packed FMA over two 32-bit words.
inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

// Scalar f16 `a` times two packed words plus `c`.
inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

// Component-wise packed FMA over four 32-bit words.
inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

// Scalar f16 `a` times four packed words plus `c`.
inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

// The overloads from here on take a float accumulator (`fc`): the f16 inputs
// are widened to float and the multiply-add is carried out in float.
inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

// Packed pair * packed pair + float2 accumulator.
inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

// Scalar * packed pair + float2 accumulator.
inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
  return fma(h0_h0(a), b, fc);
}

// Two packed words * two packed words + Float4_ accumulator.
inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

// Scalar * two packed words + Float4_ accumulator.
inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

// Four packed words * four packed words + Float8_ accumulator.
inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

// Scalar * four packed words + Float8_ accumulator.
inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.

// A single f16 value: the "sum" is just the value widened to float.
template <>
inline __device__ float sum(uint16_t v) {
  return half_to_float(v);
}

// Sum of the two f16 lanes of one packed word, added in float.
template <>
inline __device__ float sum(uint32_t v) {
  float2 tmp = half2_to_float2(v);
  return tmp.x + tmp.y;
}

// Sum over two packed words. The words are first combined with the packed
// f16 add, so that partial result is held in f16 before the final widening.
template <>
inline __device__ float sum(uint2 v) {
  uint32_t c = add(v.x, v.y);
  return sum(c);
}

// Sum over four packed words; as above, the three packed adds produce f16
// partial results and only the last pair is added in float.
template <>
inline __device__ float sum(uint4 v) {
  uint32_t c = add(v.x, v.y);
  c = add(c, v.z);
  c = add(c, v.w);
  return sum(c);
}

// From float32 to float16.

// Writes the f16 bit pattern of `src` into `dst`.
inline __device__ void from_float(uint16_t& dst, float src) {
  dst = float_to_half(src);
}

// Writes the packed f16 pair for (src.x, src.y) into `dst`.
inline __device__ void from_float(uint32_t& dst, float2 src) {
  dst = float2_to_half2(src);
}

// Converts each float2 part of `src` into one packed word of `dst`.
inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

// Converts each float2 part of `src` into one packed word of `dst`.
inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}

// From float16 to float32.
+inline __device__ float to_float(uint16_t u) { return half_to_float(u); } + +inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } + +inline __device__ Float4_ to_float(uint2 u) { + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +inline __device__ Float8_ to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +// Zero-out a variable. +inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } + +} // namespace vllm diff --git a/paged-attention/attention/dtype_float32.cuh b/paged-attention/attention/dtype_float32.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7c6a686db3ba94f114bb965b6a7c94c6a71ecdb7 --- /dev/null +++ b/paged-attention/attention/dtype_float32.cuh @@ -0,0 +1,251 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "attention_generic.cuh" + +#include + +namespace vllm { + +// Define custom FP32 vector data types. +struct Float4_ { + float2 x; + float2 y; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// FP32 vector types for Q, K, V. +template <> +struct Vec { + using Type = float; +}; +template <> +struct Vec { + using Type = float2; +}; +template <> +struct Vec { + using Type = float4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = float4; +}; + +// Vector addition. +inline __device__ float add(float a, float b) { return a + b; } + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +// Vector multiplication. +template <> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template <> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template <> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template <> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +// Vector fused multiply-add. 
+inline __device__ float fma(float a, float b, float c) { return a * b + c; } + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +// Vector sum. +template <> +inline __device__ float sum(float v) { + return v; +} + +template <> +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +template <> +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +template <> +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +template <> +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +// Vector dot product. 
+inline __device__ float dot(float a, float b) { return a * b; } + +inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); + return c.x + c.y; +} + +inline __device__ float dot(Float4_ a, Float4_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +inline __device__ float dot(Float8_ a, Float8_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +// From float to float. +inline __device__ void from_float(float& dst, float src) { dst = src; } + +inline __device__ void from_float(float2& dst, float2 src) { dst = src; } + +inline __device__ void from_float(float4& dst, float4 src) { dst = src; } + +// From float to float. +inline __device__ float to_float(float u) { return u; } + +inline __device__ float2 to_float(float2 u) { return u; } + +inline __device__ float4 to_float(float4 u) { return u; } + +inline __device__ Float4_ to_float(Float4_ u) { return u; } + +inline __device__ Float8_ to_float(Float8_ u) { return u; } + +// Zero-out a variable. 
+inline __device__ void zero(float& dst) { dst = 0.f; } + +} // namespace vllm diff --git a/paged-attention/attention/dtype_fp8.cuh b/paged-attention/attention/dtype_fp8.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e714e321b0beb2bd4b03bdabbdcd118502ccea46 --- /dev/null +++ b/paged-attention/attention/dtype_fp8.cuh @@ -0,0 +1,41 @@ +#pragma once + +#include "attention_generic.cuh" + +#include +#ifdef ENABLE_FP8 + #ifndef USE_ROCM + #include + #endif // USE_ROCM +#endif // ENABLE_FP8 + +namespace vllm { + +enum class Fp8KVCacheDataType { + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, +}; + +// fp8 vector types for quantization of kv cache +template <> +struct Vec { + using Type = uint8_t; +}; + +template <> +struct Vec { + using Type = uint16_t; +}; + +template <> +struct Vec { + using Type = uint32_t; +}; + +template <> +struct Vec { + using Type = uint2; +}; + +} // namespace vllm diff --git a/paged-attention/attention/paged_attention_v1.cu b/paged-attention/attention/paged_attention_v1.cu new file mode 100644 index 0000000000000000000000000000000000000000..7a5ef10f8ef3b7c3ea9c8d014a333717904838f3 --- /dev/null +++ b/paged-attention/attention/paged_attention_v1.cu @@ -0,0 +1,187 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "attention_kernels.cuh" +#include "cuda_compat.h" + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +// TODO(woosuk): Tune NUM_THREADS. +template +void paged_attention_v1_launcher( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seq_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len + // Keep that in sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. 
+ case 32: + LAUNCH_PAGED_ATTENTION_V1(32); + break; + case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V1(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V1(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + if (is_block_sparse) { \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + } else { \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v1( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE) +} + +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/paged-attention/attention/paged_attention_v2.cu b/paged-attention/attention/paged_attention_v2.cu new file mode 100644 index 0000000000000000000000000000000000000000..b45b28dad05eab20e8e593a9d52d44a5b89778a5 --- /dev/null +++ b/paged-attention/attention/paged_attention_v2.cu @@ -0,0 +1,197 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 
2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "attention_kernels.cuh" +#include "cuda_compat.h" + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +template +void paged_attention_v2_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int 
blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. 
+ dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 32: + LAUNCH_PAGED_ATTENTION_V2(32); + break; + case 64: + LAUNCH_PAGED_ATTENTION_V2(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V2(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V2(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V2(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V2(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V2(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V2(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V2(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + if (is_block_sparse) { \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + } else { \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v2( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) +} + +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/paged-attention/cache_kernels.cu b/paged-attention/cache_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..88559c8fe718377249f0230f7d940f807ec5303c --- /dev/null +++ b/paged-attention/cache_kernels.cu @@ -0,0 +1,734 @@ +#include +#include +#include + +#include 
"cuda_utils.h" +#include "cuda_compat.h" +#include "dispatch_utils.h" + +#ifdef USE_ROCM + #include "quantization/fp8/amd/quant_utils.cuh" +#else + #include "quantization/fp8/nvidia/quant_utils.cuh" +#endif + +#include +#include +#include +#include + +#ifdef USE_ROCM + #include +typedef __hip_bfloat16 __nv_bfloat16; +#endif + +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { + torch::Device src_device = src.device(); + torch::Device dst_device = dst.device(); + cudaMemcpyKind memcpy_type; + if (src_device.is_cuda() && dst_device.is_cuda()) { + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + memcpy_type = cudaMemcpyDeviceToDevice; + } else if (src_device.is_cuda() && dst_device.is_cpu()) { + memcpy_type = cudaMemcpyDeviceToHost; + } else if (src_device.is_cpu() && dst_device.is_cuda()) { + memcpy_type = cudaMemcpyHostToDevice; + } else { + TORCH_CHECK(false, "Invalid device combination"); + } + + // NOTE(youkaichao): keep in mind that `block_mapping` should be + // a cpu tensor, otherwise every `item` call will require a gpu-cpu + // synchronization. + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + + char* src_ptr = static_cast(src.data_ptr()); + char* dst_ptr = static_cast(dst.data_ptr()); + + // We use the stride instead of numel in case the cache is padded for memory + // alignment reasons, we assume the blocks data (inclusive of any padding) + // is contiguous in memory + const int64_t block_size_in_bytes = src.element_size() * src.stride(0); + const at::cuda::OptionalCUDAGuard device_guard( + src_device.is_cuda() ? src_device : dst_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // NOTE(woosuk): This can be slow if the number of blocks is large. 
+ const int64_t num_blocks = block_mapping.size(0); + for (size_t i = 0; i < num_blocks; i++) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); + int64_t src_offset = src_block_number * block_size_in_bytes; + int64_t dst_offset = dst_block_number * block_size_in_bytes; + cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, + block_size_in_bytes, memcpy_type, stream); + } +} + +namespace vllm { + +// Grid: (num_layers, num_pairs) +template +__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, + int64_t* value_cache_ptrs, + const int64_t* __restrict__ block_mapping, + const int numel_per_block) { + const int layer_idx = blockIdx.x; + const int pair_idx = blockIdx.y; + + scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); + scalar_t* value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); + int64_t src_block_number = block_mapping[2 * pair_idx]; + int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; + + const int64_t src_block_offset = src_block_number * numel_per_block; + const int64_t dst_block_offset = dst_block_number * numel_per_block; + for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + key_cache[dst_offset] = key_cache[src_offset]; + } + for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + value_cache[dst_offset] = value_cache[src_offset]; + } +} + +// Kernel for MLA, which works on a single joint kv_cache +// Grid: (num_layers, num_pairs) +template +__global__ void copy_blocks_mla_kernel( + int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping, + const int mem_footprint_per_block) { + const int layer_idx = blockIdx.x; + const int pair_idx = blockIdx.y; + scalar_t* cache = reinterpret_cast(cache_ptrs[layer_idx]); + int64_t src_block = 
block_mapping[2 * pair_idx]; + int64_t dst_block = block_mapping[2 * pair_idx + 1]; + int64_t src_offset = src_block * mem_footprint_per_block; + int64_t dst_offset = dst_block * mem_footprint_per_block; + for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) { + cache[dst_offset + i] = cache[src_offset + i]; + } +} + +} // namespace vllm + +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, + const torch::Tensor& block_mapping) { + int num_layers = key_caches.size(); + TORCH_CHECK(num_layers == value_caches.size()); + if (num_layers == 0) { + return; + } + torch::Device cache_device = key_caches[0].device(); + TORCH_CHECK(cache_device.is_cuda()); + + // Create data structures for the kernel. + // Create an array of pointers to the key and value caches. + int64_t key_cache_ptrs[num_layers]; + int64_t value_cache_ptrs[num_layers]; + for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { + key_cache_ptrs[layer_idx] = + reinterpret_cast(key_caches[layer_idx].data_ptr()); + value_cache_ptrs[layer_idx] = + reinterpret_cast(value_caches[layer_idx].data_ptr()); + } + + // block_mapping is a 2D tensor with shape (num_pairs, 2). + int num_pairs = block_mapping.size(0); + + // Move the data structures to the GPU. + // NOTE: This synchronizes the CPU and GPU. + torch::Tensor key_cache_ptrs_tensor = + torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); + torch::Tensor value_cache_ptrs_tensor = + torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); + + // Launch the kernel. 
+ const int numel_per_block = key_caches[0][0].numel(); + dim3 grid(num_layers, num_pairs); + dim3 block(std::min(1024, numel_per_block)); + const at::cuda::OptionalCUDAGuard device_guard(cache_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( + key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { + vllm::copy_blocks_kernel<<>>( + key_cache_ptrs_tensor.data_ptr(), + value_cache_ptrs_tensor.data_ptr(), + block_mapping.data_ptr(), numel_per_block); + })); +} + +// copy blocks kernel for MLA (assumes a joint KV-cache) +void copy_blocks_mla(std::vector const& kv_caches, + const torch::Tensor& block_mapping) { + int num_layers = kv_caches.size(); + if (num_layers == 0) { + return; + } + torch::Device cache_device = kv_caches[0].device(); + TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA"); + + std::vector cache_ptrs(num_layers); + for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { + cache_ptrs[layer_idx] = + reinterpret_cast(kv_caches[layer_idx].data_ptr()); + } + torch::Tensor cache_ptrs_tensor = + torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64) + .to(cache_device); + + int num_pairs = block_mapping.size(0); + // We use the stride instead of numel in case the cache is padded for memory + // alignment reasons, we assume the blocks data (inclusive of any padding) + // is contiguous in memory + int mem_footprint_per_block = kv_caches[0].stride(0); + dim3 grid(num_layers, num_pairs); + dim3 block(std::min(1024, mem_footprint_per_block)); + const at::cuda::OptionalCUDAGuard device_guard(cache_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( + kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] { + vllm::copy_blocks_mla_kernel<<>>( + cache_ptrs_tensor.data_ptr(), + block_mapping.data_ptr(), mem_footprint_per_block); + })); +} + +namespace vllm { + +template +__global__ void 
reshape_and_cache_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, + // block_size, x] + cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, + // block_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x, + const float* k_scale, const float* v_scale) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. + return; + } + + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + key_cache[tgt_key_idx] = tgt_key; + value_cache[tgt_value_idx] = tgt_value; + } else { + key_cache[tgt_key_idx] = + fp8::scaled_convert(tgt_key, *k_scale); + value_cache[tgt_value_idx] = + fp8::scaled_convert(tgt_value, *v_scale); + } + } +} + 
+template +__global__ void reshape_and_cache_flash_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, + // head_size] + cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, + // head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int64_t block_stride, const int64_t page_stride, + const int64_t head_stride, const int64_t key_stride, + const int64_t value_stride, const int num_heads, const int head_size, + const int block_size, const float* k_scale, const float* v_scale) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int64_t tgt_key_value_idx = block_idx * block_stride + + block_offset * page_stride + + head_idx * head_stride + head_offset; + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + key_cache[tgt_key_value_idx] = tgt_key; + value_cache[tgt_key_value_idx] = tgt_value; + } else { + key_cache[tgt_key_value_idx] = + fp8::scaled_convert(tgt_key, *k_scale); + value_cache[tgt_key_value_idx] = + fp8::scaled_convert(tgt_value, *v_scale); + } + } +} + +template +__global__ void concat_and_cache_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, 
pe_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int entry_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst, + int src_stride, int dst_stride, int size, int offset) { + for (int i = threadIdx.x; i < size; i += blockDim.x) { + const int64_t src_idx = token_idx * src_stride + i; + const int64_t dst_idx = + block_idx * block_stride + block_offset * entry_stride + i + offset; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = + fp8::scaled_convert(src[src_idx], *scale); + } + } + }; + + copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); + copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); +} + +} // namespace vllm + +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. 
+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), key_stride, value_stride, \ + num_heads, head_size, block_size, x, \ + reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); + +void reshape_and_cache( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { + int num_tokens = slot_mapping.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE) +} + +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. 
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_flash_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, page_stride, \ + head_stride, key_stride, value_stride, num_heads, head_size, \ + block_size, reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); + +void reshape_and_cache_flash( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& + value_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { + // NOTE(woosuk): In vLLM V1, key.size(0) can be different from + // slot_mapping.size(0) because of padding for CUDA graphs. + // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because + // both include padding. + // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) + // since key includes padding for CUDA graphs, while slot_mapping does not. + // In this case, slot_mapping.size(0) represents the actual number of tokens + // before padding. + // For compatibility with both cases, we use slot_mapping.size(0) as the + // number of tokens. 
+ int num_tokens = slot_mapping.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(1); + + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int64_t block_stride = key_cache.stride(0); + int64_t page_stride = key_cache.stride(1); + int64_t head_stride = key_cache.stride(2); + TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE_FLASH); +} + +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ + vllm::concat_and_cache_mla_kernel \ + <<>>( \ + reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, entry_stride, \ + kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ + reinterpret_cast(scale.data_ptr())); + +void concat_and_cache_mla( + torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] + torch::Tensor& k_pe, // [num_tokens, pe_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + + // pe_dim)] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, torch::Tensor& scale) { + // NOTE(woosuk): In vLLM V1, key.size(0) can be different from + // slot_mapping.size(0) because of padding for CUDA graphs. + // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because + // both include padding. + // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) + // since key includes padding for CUDA graphs, while slot_mapping does not. 
+ // In this case, slot_mapping.size(0) represents the actual number of tokens + // before padding. + // For compatibility with both cases, we use slot_mapping.size(0) as the + // number of tokens. + int num_tokens = slot_mapping.size(0); + int kv_lora_rank = kv_c.size(1); + int pe_dim = k_pe.size(1); + int block_size = kv_cache.size(1); + + TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + + int kv_c_stride = kv_c.stride(0); + int k_pe_stride = k_pe.stride(0); + int block_stride = kv_cache.stride(0); + int entry_stride = kv_cache.stride(1); + + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_MLA); +} + +namespace vllm { + +template +__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, + Tout* __restrict__ dst_cache, + const float scale, + const int64_t block_stride) { + const int64_t block_idx = blockIdx.x; + for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { + int64_t idx = block_idx * block_stride + i; + dst_cache[idx] = + fp8::scaled_convert(src_cache[idx], scale); + } +} + +} // namespace vllm + +#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), scale, block_stride); + +// Only for testing. 
+void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const double scale, const std::string& kv_cache_dtype) { + torch::Device src_device = src_cache.device(); + torch::Device dst_device = dst_cache.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + + int64_t num_blocks = src_cache.size(0); + int64_t block_stride = src_cache.stride(0); + + dim3 grid(num_blocks); + dim3 block(std::min(block_stride, int64_t(512))); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (kv_cache_dtype == "auto") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, + 
vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } + } else { + TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); + } +} + +namespace vllm { + +// grid is launched with dimensions (batch, num_splits) +template +__global__ void gather_cache( + const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, + // ENTRIES...] + scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...] + const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] + const int32_t* __restrict__ cu_seq_lens, // [BATCH+1] + const int32_t block_size, const int32_t entry_size, + const int64_t block_table_stride, const int64_t cache_block_stride, + const int64_t cache_entry_stride, const int64_t dst_entry_stride, + const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per + // batch + + const int64_t bid = blockIdx.x; // Batch ID + const int32_t num_splits = gridDim.y; + const int32_t split = blockIdx.y; + const int32_t seq_start = cu_seq_lens[bid]; + const int32_t seq_end = cu_seq_lens[bid + 1]; + const int32_t seq_len = seq_end - seq_start; + const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size); + const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits); + + const int32_t split_start = split * split_blocks; + const int32_t split_end = min((split + 1) * split_blocks, tot_blocks); + + const bool is_active_split = (split_start < tot_blocks); + const bool is_last_split = (split_end == tot_blocks); + + if (!is_active_split) return; + + int32_t full_blocks_end = split_end; + int32_t partial_block_size = 0; + + // Adjust the pointer for 
the block_table for this batch. + // If seq_starts is provided, compute an offset based on (seq_starts[bid] / + // page_size) + const int32_t batch_offset = bid * block_table_stride; + int32_t offset = 0; + if (seq_starts != nullptr) { + offset = seq_starts[bid] / block_size; + } + const int32_t* batch_block_table = block_table + batch_offset + offset; + + // Adjust dst pointer based on the cumulative sequence lengths. + dst += seq_start * dst_entry_stride; + + if (is_last_split) { + partial_block_size = seq_len % block_size; + if (partial_block_size) full_blocks_end -= 1; + } + + auto copy_entry = [&](const scalar_t* __restrict__ _src, + scalar_t* __restrict__ _dst) { + for (int i = threadIdx.x; i < entry_size; i += blockDim.x) + _dst[i] = _src[i]; + }; + + for (int pid = split_start; pid < full_blocks_end; ++pid) { + auto block_id = batch_block_table[pid]; + auto block_start_ptr = src_cache + block_id * cache_block_stride; + auto block_dst_ptr = dst + pid * block_size * dst_entry_stride; + for (int eid = 0; eid < block_size; ++eid) { + copy_entry(block_start_ptr + eid * cache_entry_stride, + block_dst_ptr + eid * dst_entry_stride); + } + } + + if (partial_block_size) { + auto block_id = batch_block_table[full_blocks_end]; + auto block_start_ptr = src_cache + block_id * cache_block_stride; + auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride; + for (int eid = 0; eid < partial_block_size; ++eid) { + copy_entry(block_start_ptr + eid * cache_entry_stride, + block_dst_ptr + eid * dst_entry_stride); + } + } +} + +} // namespace vllm + +// Macro to dispatch the kernel based on the data type. 
+#define CALL_GATHER_CACHE(CPY_DTYPE) \ + vllm::gather_cache<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + block_size, entry_size, block_table_stride, cache_block_stride, \ + cache_entry_stride, dst_entry_stride, seq_starts_ptr); + +// Gather sequences from the cache into the destination tensor. +// - cu_seq_lens contains the cumulative sequence lengths for each batch +// - block_table contains the cache block indices for each sequence +// - Optionally, seq_starts (if provided) offsets the starting block index by +// (seq_starts[bid] / page_size) +void gather_cache( + torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...] + torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] + torch::Tensor const& cu_seq_lens, // [BATCH+1] + int64_t batch_size, + std::optional seq_starts = std::nullopt) { + at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int32_t block_size = src_cache.size(1); + int32_t entry_size = src_cache.flatten(2, -1).size(2); + + TORCH_CHECK(block_table.dtype() == torch::kInt32, + "block_table must be int32"); + TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32, + "cu_seq_lens must be int32"); + if (seq_starts.has_value()) { + TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32, + "seq_starts must be int32"); + } + + TORCH_CHECK(src_cache.device() == dst.device(), + "src_cache and dst must be on the same device"); + TORCH_CHECK(src_cache.device() == block_table.device(), + "src_cache and block_table must be on the same device"); + TORCH_CHECK(src_cache.device() == cu_seq_lens.device(), + "src_cache and cu_seq_lens must be on the same device"); + if (seq_starts.has_value()) { + TORCH_CHECK(src_cache.device() == seq_starts.value().device(), + "src_cache and seq_starts must be on the same device"); + } 
+ + int64_t block_table_stride = block_table.stride(0); + int64_t cache_block_stride = src_cache.stride(0); + int64_t cache_entry_stride = src_cache.stride(1); + int64_t dst_entry_stride = dst.stride(0); + + // Decide on the number of splits based on the batch size. + int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; + dim3 grid(batch_size, num_splits); + dim3 block(1024); + + TORCH_CHECK(src_cache.dtype() == dst.dtype(), + "src_cache and dst must have the same dtype"); + + const int dtype_bits = src_cache.element_size() * 8; + const int32_t* seq_starts_ptr = + seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr; + + if (dtype_bits == 32) { + CALL_GATHER_CACHE(uint32_t); + } else if (dtype_bits == 16) { + CALL_GATHER_CACHE(uint16_t); + } else if (dtype_bits == 8) { + CALL_GATHER_CACHE(uint8_t); + } else { + TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); + } +} diff --git a/paged-attention/cuda_compat.h b/paged-attention/cuda_compat.h new file mode 100644 index 0000000000000000000000000000000000000000..affa051c759512f2816c51ce25e35ee80f960f5e --- /dev/null +++ b/paged-attention/cuda_compat.h @@ -0,0 +1,49 @@ +#pragma once + +#ifdef USE_ROCM + #include +#endif + +#if defined(USE_ROCM) && defined(__GFX9__) + #define WARP_SIZE 64 +#else + #define WARP_SIZE 32 +#endif + +#ifndef USE_ROCM + #define VLLM_LDG(arg) __ldg(arg) +#else + #define VLLM_LDG(arg) *(arg) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor(var, lane_mask, width) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) +#else + #define VLLM_SHFL_SYNC(var, src_lane) 
__shfl(var, src_lane) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ + __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + +#ifndef USE_ROCM + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) +#else + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) +#endif diff --git a/paged-attention/dispatch_utils.h b/paged-attention/dispatch_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f7b75c48373f68e9025020eea507415fb9405e2e --- /dev/null +++ b/paged-attention/dispatch_utils.h @@ -0,0 +1,83 @@ +/* + * Adapted from + * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h + */ +#pragma once + +#include + +// Need a special dispatch case macro since we will nest the FP8 dispatch. +// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'. +#define AT_DISPATCH_FP8_CASE(enum_type, ...) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__) + +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +// ROCm devices might use either fn or fnuz, so set up dispatch table for both. +// A host-based check at runtime will create a preferred FP8 type for ROCm +// such that the correct kernel is dispatched. +#ifdef USE_ROCM + #define VLLM_DISPATCH_CASE_FP8_TYPES(...) 
\ + AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) + + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#else + #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \ + AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) + + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#endif + +// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'. +// See AT_DISPATCH_FP8_CASE above. +#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) 
\ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__) + +#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__)) diff --git a/paged-attention/quantization/fp8/amd/quant_utils.cuh b/paged-attention/quantization/fp8/amd/quant_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e51a4e14e518f83e11f4c56b54bbb046ff28c77d --- /dev/null +++ b/paged-attention/quantization/fp8/amd/quant_utils.cuh @@ -0,0 +1,671 @@ +#pragma once +#include + +#include +#include +#include + +#include "../../../attention/attention_dtypes.h" + +namespace vllm { +#ifdef USE_ROCM + +namespace fp8 { + #ifdef ENABLE_FP8 + +// Use hardware cvt instruction for fp8 on rocm +template +__device__ __forceinline__ fp8_type cvt_c10(float const r) { + return {}; +} + +// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro +// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes +// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES +// on ROCm instantiates both OCP and FNUZ kernels, we need to replace +// the new HW cvt with something reasonable that doesn't rely on the +// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer. 
+template <> +__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) { + #if HIP_FP8_TYPE_OCP + return c10::Float8_e4m3fn( + __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation, + __hip_fp8_e4m3::__default_interpret), + c10::Float8_e4m3fn::from_bits()); + #else + // Cast implemented by pytorch. Uses bit manipulation instead of HW cvt. + // HW cvt above is faster when it is available (ROCm 6.3 or newer). + return static_cast(r); + #endif +} + +template <> +__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) { + return c10::Float8_e4m3fnuz( + __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation, + __hip_fp8_e4m3_fnuz::__default_interpret), + c10::Float8_e4m3fnuz::from_bits()); +} + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) { + return x; +} + +template +__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, + const float scale) { + return x; +} + + #if HIP_FP8_TYPE_OCP +using fp8_type = __hip_fp8_e4m3; +using fp8x2_type = __hip_fp8x2_e4m3; + #else +using fp8_type = __hip_fp8_e4m3_fnuz; +using fp8x2_type = __hip_fp8x2_e4m3_fnuz; + #endif + +// fp8 -> half +template <> +__inline__ __device__ uint16_t +vec_conversion(const uint8_t& a) { + return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t +vec_conversion(const uint16_t& a) { + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); + return tmp.ui32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion(const uint32_t& a) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a); + tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion(const uint2& a) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + 
tmp.u64[0] = vec_conversion(a.x); + tmp.u64[1] = vec_conversion(a.y); + return tmp.u64x2; +} + +using __nv_bfloat16 = __hip_bfloat16; + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 +vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) { + fp8_type f8; + f8.__x = a; + return __float2bfloat16(static_cast(f8)); +} + +using __nv_bfloat162 = __hip_bfloat162; + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t +vec_conversion(const uint32_t& a) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); + res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float vec_conversion(const uint8_t& a) { + fp8_type f8; + f8.__x = a; + return static_cast(f8); +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 +vec_conversion(const uint16_t& a) { + fp8x2_type f8x2; + f8x2.__x = a; + return static_cast(f8x2); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ +vec_conversion(const uint32_t& a) { + Float4_ res; + res.x = vec_conversion((uint16_t)a); + res.y = vec_conversion((uint16_t)(a >> 16U)); + return res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 +vec_conversion(const uint32_t& a) { + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, 
tmp.y.x, tmp.y.y); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ vec_conversion(const uint2& a) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t +vec_conversion(const uint16_t& a) { + __half_raw tmp; + tmp.x = a; + return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation, + fp8_type::__default_interpret); +} + +template <> +__inline__ __device__ uint16_t +vec_conversion(const uint32_t& a) { + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation, + fp8_type::__default_interpret); +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t +vec_conversion(const __nv_bfloat16& a) { + return __hip_cvt_float_to_fp8(__bfloat162float(a), + fp8_type::__default_saturation, + fp8_type::__default_interpret); +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion(const float& a) { + return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation, + fp8_type::__default_interpret); +} + +// float2 -> half2 +template <> +__inline__ __device__ uint32_t +vec_conversion(const float2& a) { + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +// Float4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion(const Float4_& a) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + return b; +} + +// Float4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion(const Float4_& a) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +// Float8 -> half2x4 +template <> +__inline__ __device__ uint4 
vec_conversion(const Float8_& a) { + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; +} + +// float2 -> bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, float2>(const float2& a) { + __nv_bfloat162 b = __float22bfloat162_rn(a); + return b; +} + +// Float4 -> bfloat162x2 +template <> +__inline__ __device__ bf16_4_t +vec_conversion(const Float4_& a) { + bf16_4_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + return b; +} + +// Float8 -> bfloat162x4 +template <> +__inline__ __device__ bf16_8_t +vec_conversion(const Float8_& a) { + bf16_8_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + b.z = __float22bfloat162_rn(a.z); + b.w = __float22bfloat162_rn(a.w); + return b; +} + +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains + + Convention of the scale in API, e.g: FP8_data = Quantization( + High_Precision_data / scale ) s.t. 
Quantize(HP / scale) => FP8 Dequant(FP8) * + scale => HP + + */ + +using __nv_bfloat16 = __hip_bfloat16; + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { + fp8_type f8; + f8.__x = a; + return __float2bfloat16(static_cast(f8) * scale); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, + float scale) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); + res.y = + scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t +scaled_vec_conversion(const uint32_t& a, float scale) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t +scaled_vec_conversion(const uint2& a, float scale) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float scaled_vec_conversion( + const uint8_t& a, float scale) { + fp8_type f8; + f8.__x = a; + return static_cast(f8) * scale; +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 +scaled_vec_conversion(const uint16_t& a, float scale) { + fp8x2_type f8x2; + f8x2.__x = a; + return static_cast(f8x2) * scale; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ +scaled_vec_conversion(const uint32_t& a, const float scale) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale); + res.y = 
scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 +scaled_vec_conversion(const uint32_t& a, float scale) { + Float4_ res = scaled_vec_conversion(a, scale); + return {res.x.x, res.x.y, res.y.x, res.y.y}; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ +scaled_vec_conversion(const uint2& a, float scale) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint8_t& a, float scale) { + __half_raw res; + res.data = scaled_vec_conversion(a, scale); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint16_t& a, float scale) { + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); + tmp.h2r.x.data *= scale; + tmp.h2r.y.data *= scale; + return tmp.ui32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 +scaled_vec_conversion(const uint32_t& a, float scale) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); + tmp.u32[1] = + scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, + float scale) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale); + tmp.u64[1] = scaled_vec_conversion(a.y, scale); + return tmp.u64x2; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t +scaled_vec_conversion(const uint16_t& a, float scale) { + __half_raw tmp; + tmp.x = a; + tmp.data /= scale; + return __hip_cvt_halfraw_to_fp8(tmp, 
fp8_type::__default_saturation, + fp8_type::__default_interpret); +} + +// halfx2 -> fp8x2 +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint32_t& a, float scale) { + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + tmp.h2r.x.data /= scale; + tmp.h2r.y.data /= scale; + return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation, + fp8_type::__default_interpret); +} + +// half2x2 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint2& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// half2x4 -> fp8x8 +template <> +__inline__ __device__ uint2 scaled_vec_conversion(const uint4& a, + float scale) { + union { + uint2 ui2[2]; + uint4 ui4; + } tmp; + tmp.ui4 = a; + uint2 res; + res.x = scaled_vec_conversion(tmp.ui2[0], scale); + res.y = scaled_vec_conversion(tmp.ui2[1], scale); + return res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16& a, float scale) { + return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale, + fp8_type::__default_saturation, + fp8_type::__default_interpret); +} + +// bf16x2 -> fp8x2 +template <> +__inline__ __device__ uint16_t scaled_vec_conversion( + const __nv_bfloat162& a, float scale) { + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; +} + +// bf16x4 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const bf16_4_t& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// bf16x8 -> fp8x8 +template <> +__inline__ __device__ uint2 
+scaled_vec_conversion(const bf16_8_t& a, float scale) { + uint2 res; + res.x = scaled_vec_conversion({a.x, a.y}, scale); + res.y = scaled_vec_conversion({a.z, a.w}, scale); + return res; +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t +scaled_vec_conversion(const float& a, float scale) { + return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation, + fp8_type::__default_interpret); +} + +// floatx2 -> fp8x2 +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const float2& a, float scale) { + return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation, + fp8_type::__default_interpret); +} + +// floatx4 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const float4& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion({a.x, a.y}, scale); + tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale); + return tmp.ui32; +} + #endif // ENABLE_FP8 + +template +__inline__ __device__ Tout convert(const Tin& x) { + #ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x); + } + #endif + assert(false); + return {}; // Squash missing return statement warning +} + +template +__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { + #ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale); + } + #endif + assert(false); + return {}; // Squash missing return statement warning +} + + // The following macro is used to dispatch the conversion function based on + // the data type of the key and value cache. The FN is a macro that calls a + // function with template. 
+ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // namespace fp8 +#endif // USE_ROCM +} // namespace vllm diff --git a/paged-attention/quantization/fp8/nvidia/quant_utils.cuh b/paged-attention/quantization/fp8/nvidia/quant_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f8cd1dcba4ab337703885bc0a3d577c8305777a6 --- /dev/null +++ b/paged-attention/quantization/fp8/nvidia/quant_utils.cuh @@ -0,0 +1,573 @@ +#pragma once + +#include "../../../attention/attention_dtypes.h" +#include +#include +#include +#include + +namespace vllm { +#ifndef USE_ROCM + +namespace fp8 { + #ifdef ENABLE_FP8 + + #if 0 // Disable the following code to reduce the binary size. 
+template +__inline__ __device__ Tout +vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t vec_conversion( + const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); + tmp.u16[0] = res.x; + tmp.u16[1] = res.y; + return tmp.u32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a, fp8_type); + tmp.u32[1] = + vec_conversion((uint16_t)(a >> 16U), fp8_type); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x, fp8_type); + tmp.u64[1] = vec_conversion(a.y, fp8_type); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>( + const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { + // Note there is no direct convert function from fp8 to bf16. 
+ // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type); + res.y = + vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x, fp8_type); + tmp2 = vec_conversion(a.y, fp8_type); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float +vec_conversion(const uint8_t &a, + const __nv_fp8_interpretation_t fp8_type) { + // fp8 -> half + uint16_t tmp = vec_conversion(a, fp8_type); + // half -> float + return half_to_float(tmp); +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + // fp8x2 -> half2 + uint32_t tmp = vec_conversion(a, fp8_type); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ res; + res.x = vec_conversion((uint16_t)a, fp8_type); + res.y = 
vec_conversion((uint16_t)(a >> 16U), fp8_type); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x, fp8_type); + tmp2 = vec_conversion(a.y, fp8_type); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + __half_raw tmp; + tmp.x = a; + __nv_fp8_storage_t res = + __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); + #else + __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8( + __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type); + return (uint8_t)res; + #endif +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const float &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp = vec_conversion(a, fp8_type); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +template <> +__inline__ __device__ uint32_t vec_conversion( + const float2 &a, const __nv_fp8_interpretation_t fp8_type) { + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template <> +__inline__ __device__ uint2 vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = 
a.x.y; + b.x = vec_conversion(val, fp8_type); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val, fp8_type); + + return b; +} + +template <> +__inline__ __device__ float4 vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +template <> +__inline__ __device__ uint4 vec_conversion( + const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { + uint4 b; + b.x = vec_conversion(a.x, fp8_type); + b.y = vec_conversion(a.y, fp8_type); + b.z = vec_conversion(a.z, fp8_type); + b.w = vec_conversion(a.w, fp8_type); + return b; +} + +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>( + const float2 &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 b; + from_float(b, a); + return b; +} + +template <> +__inline__ __device__ bf16_4_t vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t b; + from_float(b, a); + return b; +} + +template <> +__inline__ __device__ bf16_8_t vec_conversion( + const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_8_t b; + from_float(b, a); + return b; +} + #endif + +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains Convention of the scale in API, e.g: FP8_data = + Quantization( High_Precision_data / scale ) s.t. 
Quantize(HP / scale) => FP8 + Dequant(FP8) * scale => HP + */ + +template +__inline__ __device__ Tout scaled_vec_conversion( + const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t scaled_vec_conversion( + const uint8_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type); + return float_to_half(half_to_float(tmp.x) * scale); +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion( + const uint16_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); + tmp.u16[0] = float_to_half(half_to_float(res.x) * scale); + tmp.u16[1] = float_to_half(half_to_float(res.y) * scale); + return tmp.u32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 scaled_vec_conversion( + const uint32_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = + scaled_vec_conversion((uint16_t)a, scale, fp8_type); + tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), + scale, fp8_type); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 +scaled_vec_conversion(const uint2& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale, fp8_type); + tmp.u64[1] = scaled_vec_conversion(a.y, scale, fp8_type); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>( + const uint8_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // Note there is no direct convert function from fp8 to bf16. 
+ // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp * scale); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>( + const uint16_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, + fp8_type); + res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), + scale, fp8_type); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t scaled_vec_conversion( + const uint32_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, + fp8_type); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale, fp8_type); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t scaled_vec_conversion( + const uint2& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); + tmp2 = scaled_vec_conversion(a.y, scale, fp8_type); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float scaled_vec_conversion( + const uint8_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + uint16_t tmp = res.x; + + // half -> float + return half_to_float(tmp) * scale; +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 scaled_vec_conversion( + const uint16_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // fp8x2 -> half2 + uint32_t tmp = scaled_vec_conversion(a, 
scale, fp8_type); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ scaled_vec_conversion( + const uint32_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale, fp8_type); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale, + fp8_type); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ scaled_vec_conversion( + const uint2& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); + tmp2 = scaled_vec_conversion(a.y, scale, fp8_type); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const uint16_t& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); + #else + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale, + __NV_SATFINITE, fp8_type); + return (uint8_t)res; + #endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const float& a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 scaled_vec_conversion( + const uint32_t& a, const float 
scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp = scaled_vec_conversion(a, scale, fp8_type); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + #endif // ENABLE_FP8 + +template +__inline__ __device__ Tout convert(const Tin& x) { + #if 0 // Disable the following code to reduce the binary size. + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x, __NV_E4M3); + } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { + return vec_conversion(x, __NV_E5M2); + } + #endif + assert(false); + __builtin_unreachable(); // Suppress missing return statement warning +} + +template +__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { + #ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale, __NV_E4M3); + } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { + return scaled_vec_conversion(x, scale, __NV_E5M2); + } + #endif + assert(false); + __builtin_unreachable(); // Suppress missing return statement warning +} + + // The following macro is used to dispatch the conversion function based on + // the data type of the key and value cache. The FN is a macro that calls a + // function with template. 
+ #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else if (KV_DTYPE == "fp8_e5m2") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // namespace fp8 +#endif // not USE_ROCM +} // namespace vllm diff --git a/tests/kernels/__init__.py b/tests/kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/kernels/allclose_default.py b/tests/kernels/allclose_default.py new file mode 100644 index 
0000000000000000000000000000000000000000..80eb1eeb9fb738d70efe28d64df98b2ff7223463 --- /dev/null +++ b/tests/kernels/allclose_default.py @@ -0,0 +1,14 @@ +import torch + +# Reference default values of atol and rtol are from +# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67 +default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5} +default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6} + + +def get_default_atol(output) -> float: + return default_atol[output.dtype] + + +def get_default_rtol(output) -> float: + return default_rtol[output.dtype] diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4f8eff8522b65544ed85459f856fcd6edc95c6 --- /dev/null +++ b/tests/kernels/conftest.py @@ -0,0 +1,157 @@ +from typing import List, Optional, Tuple, Union + +import paged_attention as ops +import pytest +import torch + + +@pytest.fixture() +def kv_cache_factory(): + return create_kv_caches_with_random + + +@pytest.fixture() +def kv_cache_factory_flashinfer(): + return create_kv_caches_with_random_flash + + +STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, +} + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + if cache_dtype == "fp8" and head_size % 16: + raise ValueError( + f"Does not support key cache of type fp8 with head_size {head_size}" + ) + from paged_attention.platforms import current_platform + + current_platform.seed_everything(seed) + + torch_dtype = 
get_kv_cache_torch_dtype(cache_dtype, model_dtype) + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=torch_dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches: List[torch.Tensor] = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(key_cache, -scale, scale) + else: + raise ValueError(f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches: List[torch.Tensor] = [] + for _ in range(num_layers): + value_cache = torch.empty( + size=value_cache_shape, dtype=torch_dtype, device=device + ) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + value_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(value_cache, -scale, scale) + else: + raise ValueError(f"Does not support value cache of type {cache_dtype}") + value_caches.append(value_cache) + return key_caches, value_caches + + +def create_kv_caches_with_random_flash( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + from paged_attention.platforms import current_platform + + current_platform.seed_everything(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + scale = head_size**-0.5 + + key_caches: List[torch.Tensor] = [] + value_caches: List[torch.Tensor] = [] + + for _ in range(num_layers): + key_value_cache = torch.empty( + 
size=key_value_cache_shape, dtype=torch_dtype, device=device + ) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_value_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(key_value_cache, -scale, scale) + else: + raise ValueError(f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def get_kv_cache_torch_dtype( + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, +) -> torch.dtype: + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str): + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in ["half", "bfloat16", "float"]: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + elif cache_dtype == "fp8": + torch_dtype = torch.uint8 + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def _generate_random_fp8( + tensor: torch.Tensor, + low: float, + high: float, +) -> None: + # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, + # it may occur Inf or NaN if we directly use torch.randint + # to generate random data for fp8 data. + # For example, s.11111.00 in fp8e5m2 format represents Inf. 
+ # | E4M3 | E5M2 + # -----|-------------|------------------- + # Inf | N/A | s.11111.00 + # NaN | s.1111.111 | s.11111.{01,10,11} + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) + tensor_tmp.uniform_(low, high) + ops.convert_fp8(tensor, tensor_tmp) + del tensor_tmp diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..953d8677c7643449df27aa7b13fee06a68891759 --- /dev/null +++ b/tests/kernels/test_attention.py @@ -0,0 +1,439 @@ +import random +from typing import List, Optional, Tuple + +import paged_attention as ops +import pytest +import torch +from paged_attention.platforms import current_platform + +from .allclose_default import get_default_atol, get_default_rtol +from .utils import get_max_shared_memory_bytes, opcheck + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# There may not be enough gpu memory due to large NUM_BLOCKS. +# Reduce NUM_BLOCKS when it happens. 
+NUM_BLOCKS = 4321 # Arbitrary values for testing +PARTITION_SIZE = 512 +# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} +DTYPES = ( + [torch.half, torch.bfloat16, torch.float] + if not current_platform.is_rocm() + else [torch.half, torch.bfloat16] +) +NUM_GEN_SEQS = [7] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing + +# This should be sync with get_supported_head_sizes() in +# vllm.attention.ops.paged_attn.PagedAttention +HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256] + +BLOCK_SIZES = [16, 32] +USE_ALIBI = [False, True] +if current_platform.is_mps(): + KV_CACHE_DTYPE = ["auto", "fp8"] +else: + KV_CACHE_DTYPE = ["auto", "fp8"] +SEEDS = [0] +if current_platform.is_mps(): + DEVICES = ["mps:0"] +else: + DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + num_seqs = query.shape[0] + + block_tables_lst = block_tables.cpu().tolist() + seq_lens_lst = seq_lens.cpu().tolist() + for i in range(num_seqs): + q = 
query[i].unsqueeze(0) + block_table = block_tables_lst[i] + seq_len = int(seq_lens_lst[i]) + + keys_lst: List[torch.Tensor] = [] + values_lst: List[torch.Tensor] = [] + for j in range(seq_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys_lst.append(k) + + v = value_cache[block_number, :, :, block_offset] + values_lst.append(v) + keys = torch.stack(keys_lst, dim=0) + values = torch.stack(values_lst, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. + position_ids = torch.arange(seq_len).int() + alibi_bias = (position_ids - seq_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize( + "version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"] +) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) +def test_paged_attention( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + seed: int, + device: str, +) -> None: + if 
(kv_cache_dtype == "fp8" and head_size % 16) or ( + version == "rocm" and head_size not in (64, 128) + ): + pytest.skip() + + current_platform.seed_everything(seed) + torch.set_default_device(device) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) + + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int) + + # Create the block tables. + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables_lst: List[List[int]] = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq) + ] + block_tables_lst.append(block_table) + + block_tables = torch.tensor(block_tables_lst, dtype=torch.int) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Using default kv_scale + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + + # Call the paged attention kernel. 
+ output = torch.empty_like(query) + if version == "v1": + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck( + ops.ops.paged_attention_v1, + ( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + 0, + 0, + 0, + 64, + 0, + ), + cond=( + head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0] + and not device.startswith("mps") + ), + ) + + elif version in ("v2", "rocm"): + num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + if version == "v2": + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck( + ops.ops.paged_attention_v2, + ( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + 0, + 0, + 0, + 64, + 0, + ), + cond=( + head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0] + and not device.startswith("mps") + ), + ) + + else: + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + 
max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck( + torch.ops._rocm_C.paged_attention, + ( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ), + cond=( + head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0] + and not device.startswith("mps") + ), + ) + + else: + raise AssertionError(f"Unknown version: {version}") + + # Run the reference implementation. + if kv_cache_dtype == "fp8": + # Convert cache data back to dtype. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) + dequantized_key_cache = torch.empty( + size=key_cache_shape, dtype=dtype, device=device + ) + ops.convert_fp8(dequantized_key_cache, key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty( + size=value_cache_shape, dtype=dtype, device=device + ) + ops.convert_fp8(dequantized_value_cache, value_cache) + value_cache = dequantized_value_cache + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + seq_lens, + scale, + alibi_slopes, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. + atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 + rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 + + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. 
+ atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + # NOTE: bfloat16 with ALiBi can have slightly higher precision differences + elif dtype == torch.bfloat16 and use_alibi: + atol, rtol = 2e-3, 1e-5 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + + +def ref_multi_query_kv_attention( + cu_seq_lens: List[int], + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + ref_outputs: List[torch.Tensor] = [] + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + + # Create attention mask. + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype) + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + + return torch.cat(ref_outputs, dim=0) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6410bc76f3d3beeb4366916a33870efceb7d45 --- /dev/null +++ b/tests/kernels/test_cache.py @@ -0,0 +1,500 @@ +import random +from typing import List, Tuple + +import paged_attention as ops +import pytest +import torch +from paged_attention.platforms import current_platform + +from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck + +COPYING_DIRECTION = [("gpu", "cpu"), ("gpu", "gpu"), ("cpu", "gpu")] +DTYPES = [torch.half, torch.bfloat16, torch.float] +NUM_TOKENS = [42] # Arbitrary values for testing +NUM_LAYERS = [1] # Arbitrary values for testing +NUM_HEADS = [8] # Arbitrary values for testing +HEAD_SIZES = [64, 80, 120, 256] +BLOCK_SIZES = [8, 16, 32] + +# Arbitrary values for testing +# don't make it too large. e.g. 
[1024, 36000] will OOM +NUM_BLOCKS = [1024, 10000] + +NUM_MAPPINGS = [256] # Arbitrary values for testing +SEEDS = [0] +if current_platform.is_mps(): + DEVICES = ["mps:0"] +else: + DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] + +if current_platform.is_mps(): + KV_CACHE_DTYPE = ["auto", "fp8"] +else: + KV_CACHE_DTYPE = ["auto", "fp8"] + + +@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) +@pytest.mark.parametrize("num_layers", NUM_LAYERS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@torch.inference_mode() +def test_copy_blocks( + kv_cache_factory, + num_mappings: int, + num_layers: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + kv_cache_dtype: str, + device: str, +) -> None: + if kv_cache_dtype == "fp8" and head_size % 16: + pytest.skip() + current_platform.seed_everything(seed) + # Don't set MPS as default device to avoid placeholder storage error + if not device.startswith("mps"): + torch.set_default_device(device) + # Generate random block mappings where each source block is mapped to two + # destination blocks. + assert 2 * num_mappings <= num_blocks + src_blocks = random.sample(range(num_blocks), num_mappings) + remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) + dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) + block_mapping: List[Tuple[int, int]] = [] + for i in range(num_mappings): + src = src_blocks[i] + dst1 = dst_blocks[2 * i] + dst2 = dst_blocks[2 * i + 1] + block_mapping.append((src, dst1)) + block_mapping.append((src, dst2)) + + # Create the KV caches. 
+ key_caches, value_caches = kv_cache_factory( + num_blocks, + block_size, + num_layers, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) + + # Clone the KV caches. + cloned_key_caches = [key_cache.clone() for key_cache in key_caches] + cloned_value_caches = [value_cache.clone() for value_cache in value_caches] + + # Call the copy blocks kernel. + block_mapping_tensor = torch.tensor( + block_mapping, dtype=torch.int64, device=device + ).view(-1, 2) + + opcheck( + ops.ops.copy_blocks, + (key_caches, value_caches, block_mapping_tensor), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + cond=(head_size == HEAD_SIZES[0]), + ) + ops.copy_blocks(key_caches, value_caches, block_mapping_tensor) + + # Run the reference implementation. + for src, dst in block_mapping: + for cloned_key_cache in cloned_key_caches: + cloned_key_cache[dst].copy_(cloned_key_cache[src]) + for cloned_value_cache in cloned_value_caches: + cloned_value_cache[dst].copy_(cloned_value_cache[src]) + + # Compare the results. 
+ for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): + torch.testing.assert_close(key_cache, cloned_key_cache) + for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches): + torch.testing.assert_close(value_cache, cloned_value_cache) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@torch.inference_mode() +def test_reshape_and_cache( + kv_cache_factory, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, + kv_cache_dtype: str, +) -> None: + if kv_cache_dtype == "fp8" and head_size % 16: + pytest.skip() + current_platform.seed_everything(seed) + # Don't set MPS as default device to avoid placeholder storage error + if not device.startswith("mps"): + torch.set_default_device(device) + # Create a random slot mapping. + num_slots = block_size * num_blocks + slot_mapping_lst = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) + + qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device) + _, key, value = qkv.unbind(dim=1) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory( + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Clone the KV caches. 
+ if kv_cache_dtype == "fp8": + cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + ops.convert_fp8(cloned_key_cache, key_cache) + cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + ops.convert_fp8(cloned_value_cache, value_cache) + else: + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Using default kv_scale + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + + # Call the reshape_and_cache kernel. + opcheck( + ops.ops.reshape_and_cache, + ( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ), + cond=(head_size == HEAD_SIZES[0]), + ) + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + if kv_cache_dtype == "fp8": + result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + ops.convert_fp8(result_key_cache, key_cache) + result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + ops.convert_fp8(result_value_cache, value_cache) + + # Run the reference implementation. 
+ reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) + block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor") + block_indicies_lst = block_indicies.cpu().tolist() + block_offsets = slot_mapping % block_size + block_offsets_lst = block_offsets.cpu().tolist() + for i in range(num_tokens): + block_idx = block_indicies_lst[i] + block_offset = block_offsets_lst[i] + cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] + cloned_value_cache[block_idx, :, :, block_offset] = value[i] + + if kv_cache_dtype == "fp8": + torch.testing.assert_close( + result_key_cache, cloned_key_cache, atol=0.02, rtol=0.2 + ) + torch.testing.assert_close( + result_value_cache, cloned_value_cache, atol=0.02, rtol=0.2 + ) + else: + torch.testing.assert_close(key_cache, cloned_key_cache) + torch.testing.assert_close(value_cache, cloned_value_cache) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@torch.inference_mode() +def test_reshape_and_cache_flash( + kv_cache_factory_flashinfer, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, + kv_cache_dtype: str, +) -> None: + # Flash variant doesn't support FP8 on MPS devices yet + if current_platform.is_mps() and kv_cache_dtype == "fp8": + pytest.skip("reshape_and_cache_flash doesn't support FP8 on MPS") + current_platform.seed_everything(seed) + # Don't set MPS as default device to avoid placeholder storage error + if not device.startswith("mps"): + torch.set_default_device(device) + + # Create a random slot 
mapping. + num_slots = block_size * num_blocks + slot_mapping_lst = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) + + qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device) + _, key, value = qkv.unbind(dim=1) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory_flashinfer( + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + ) + key_cache, value_cache = key_caches[0].contiguous(), value_caches[0].contiguous() + del key_caches + del value_caches + + k_scale = (key.amax() / 256.0).to(torch.float32) + v_scale = (value.amax() / 256.0).to(torch.float32) + + # Clone the KV caches. + if kv_cache_dtype == "fp8": + cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype) + cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + ops.convert_fp8(cloned_value_cache, value_cache, v_scale, kv_cache_dtype) + else: + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Call the reshape_and_cache kernel. + opcheck( + ops.ops.reshape_and_cache_flash, + ( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ), + cond=(head_size == HEAD_SIZES[0]), + ) + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + if kv_cache_dtype == "fp8": + result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + ops.convert_fp8( + result_key_cache, key_cache, k_scale.item(), kv_dtype=kv_cache_dtype + ) + result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + ops.convert_fp8( + result_value_cache, value_cache, v_scale.item(), kv_dtype=kv_cache_dtype + ) + + # Run the reference implementation. 
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Swap randomly chosen blocks between two caches and compare each
    destination block against a snapshot of the source taken beforehand."""
    # fp8 caches are skipped when either endpoint is the CPU, and when the
    # head size is not a multiple of 16.
    if kv_cache_dtype == "fp8" and ("cpu" in direction or head_size % 16):
        pytest.skip()

    current_platform.seed_everything(seed)

    source_dev = device if direction[0] == "gpu" else "cpu"
    target_dev = device if direction[1] == "gpu" else "cpu"

    from_ids = random.sample(range(num_blocks), num_mappings)
    if source_dev == target_dev:
        # For the same device, mapping must not overlap
        free_ids = list(set(range(num_blocks)) - set(from_ids))
        to_ids = random.sample(free_ids, num_mappings)
    else:
        to_ids = random.sample(range(num_blocks), num_mappings)

    pairs = list(zip(from_ids, to_ids))
    pairs_tensor = torch.tensor(pairs, dtype=torch.int64, device="cpu").view(-1, 2)

    def build_caches(target: str):
        # Single-layer KV cache pair on the requested device.
        return kv_cache_factory(
            num_blocks,
            block_size,
            1,
            num_heads,
            head_size,
            kv_cache_dtype,
            dtype,
            seed,
            target,
        )

    # Source caches first, then destination caches (same order as before).
    src_keys, src_values = build_caches(source_dev)
    dst_keys, dst_values = build_caches(target_dev)

    key_snapshot = src_keys[0].clone()
    value_snapshot = src_values[0].clone()

    transfers = (
        (src_keys[0], dst_keys[0]),
        (src_values[0], dst_values[0]),
    )

    # opcheck only runs for the first parametrized head size.
    run_opcheck = head_size == HEAD_SIZES[0]
    for src_cache, dst_cache in transfers:
        opcheck(
            ops.ops.swap_blocks,
            (src_cache, dst_cache, pairs_tensor),
            cond=run_opcheck,
        )

    # Call the swap_blocks kernel.
    for src_cache, dst_cache in transfers:
        ops.swap_blocks(src_cache, dst_cache, pairs_tensor)

    for src_id, dst_id in pairs:
        torch.testing.assert_close(
            key_snapshot[src_id].cpu(), dst_keys[0][dst_id].cpu()
        )
        torch.testing.assert_close(
            value_snapshot[src_id].cpu(), dst_values[0][dst_id].cpu()
        )


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random tensor through ``ops.convert_fp8`` (into a uint8
    buffer and back) and check the result is close to the input."""
    current_platform.seed_everything(seed)

    bound = 224.0
    original = torch.empty(
        (num_blocks, num_heads, head_size, block_size), dtype=dtype, device=device
    ).uniform_(-bound, bound)

    packed = torch.empty_like(original, dtype=torch.uint8)
    ops.convert_fp8(packed, original)

    restored = torch.empty_like(original)
    ops.convert_fp8(restored, packed)

    torch.testing.assert_close(original, restored, atol=0.02, rtol=0.2)
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose

    Both inputs are widened to a common floating type before comparison, and
    the element-wise ``torch.isclose`` result is reduced to a single bool.

    Args:
        a, b: tensors to compare.
        rtol, atol: tolerances forwarded to ``torch.isclose``.
        equal_nan: forwarded to ``torch.isclose``.

    Returns:
        True when every element pair is close.
    """
    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    # MPS doesn't support float64, so use float32 for comparison
    if a.device.type == "mps" or b.device.type == "mps":
        a_cmp, b_cmp = a.float(), b.float()
    else:
        a_cmp, b_cmp = a.double(), b.double()

    return bool(
        torch.all(
            torch.isclose(a_cmp, b_cmp, rtol=rtol, atol=atol, equal_nan=equal_nan)
        ).item()
    )


def compute_max_diff(output, output_ref):
    """Return ``mean(|output - output_ref|) / mean(|output_ref|)``.

    Despite the name, this is a ratio of means, not a maximum.
    """
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )


# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(
    op: Union[
        torch._ops.OpOverload,
        torch._ops.OpOverloadPacket,
        "torch._library.custom_ops.CustomOpDef",
    ],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
    raise_exception: bool = True,
    cond: bool = True,
) -> Dict[str, str]:
    """Run ``torch.library.opcheck`` with ``torch.allclose`` replaced by
    ``fp8_allclose`` for the duration of the call.

    Args:
        op: operator to check.
        args, kwargs: sample inputs forwarded to ``torch.library.opcheck``.
        test_utils: names of the opcheck tests to run.
        raise_exception: forwarded to ``torch.library.opcheck``.
        cond: when False the check is skipped entirely.

    Returns:
        ``{}`` when ``cond`` is False, otherwise whatever
        ``torch.library.opcheck`` returns.
    """
    # Skip before touching any global state: no need to patch torch.allclose
    # when nothing is going to run.
    if not cond:
        return {}

    # `import unittest` alone does not load the `unittest.mock` submodule, so
    # import it explicitly rather than rely on another module having done so.
    from unittest import mock

    with mock.patch("torch.allclose", new=fp8_allclose):
        return torch.library.opcheck(
            op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
        )


@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    # Imported lazily so this module can be loaded without the extension.
    from paged_attention import ops

    max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
    # will fail
    assert max_shared_mem > 0, "max_shared_mem can not be zero"
    return int(max_shared_mem)
# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward every argument, positionally and unchanged, to
    ``ops.paged_attention_v1``.

    ``k_scale`` and ``v_scale`` are annotated as tensors because the
    registered op schema declares them as ``Tensor``.
    """
    ops.paged_attention_v1(
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Forward every argument, positionally and unchanged, to
    ``ops.paged_attention_v2``.

    Compared with ``paged_attention_v1`` this variant takes three extra
    tensors (``exp_sum``, ``max_logits``, ``tmp_out``) after ``out``.
    """
    ops.paged_attention_v2(
        out,
        exp_sum,
        max_logits,
        tmp_out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> None:
    """Forward the arguments to ``ops.reshape_and_cache``, whose schema marks
    ``key_cache`` and ``value_cache`` as mutated."""
    ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> None:
    """Forward the arguments to ``ops.reshape_and_cache_flash``."""
    ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )


def copy_blocks(
    key_caches: List[torch.Tensor],
    value_caches: List[torch.Tensor],
    block_mapping: torch.Tensor,
) -> None:
    """Forward the cache lists and ``block_mapping`` to ``ops.copy_blocks``."""
    ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(
    src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
) -> None:
    """Forward to ``ops.swap_blocks``; the op schema marks ``dst`` as mutated."""
    ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(
    output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
) -> None:
    """Forward to ``ops.convert_fp8``, which writes into ``output``."""
    ops.convert_fp8(output, input, scale, kv_dtype)


# Every public wrapper defined above. The package __init__ imports
# reshape_and_cache_flash and swap_blocks from here, so they are listed too.
__all__ = [
    "convert_fp8",
    "paged_attention_v1",
    "paged_attention_v2",
    "reshape_and_cache",
    "reshape_and_cache_flash",
    "copy_blocks",
    "swap_blocks",
]
IS_ROCM = torch.version.hip is not None
IS_MPS = torch.backends.mps.is_available()


class Platform(ABC):
    """Common interface for the per-backend platform helpers."""

    @classmethod
    def seed_everything(cls, seed: int) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    @abstractmethod
    def get_device_name(self, device_id: int = 0) -> str: ...

    @abstractmethod
    def is_cuda(self) -> bool: ...

    @abstractmethod
    def is_rocm(self) -> bool: ...

    @abstractmethod
    def is_mps(self) -> bool: ...


class CudaPlatform(Platform):
    """Platform helper selected when CUDA is available."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of device ``device_id`` (cached per index)."""
        # Query the requested index; this used to be hard-coded to device 0.
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return True

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return False


class RocmPlatform(Platform):
    """Platform helper selected when ``torch.version.hip`` is set."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of device ``device_id`` (cached per index)."""
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return False

    def is_rocm(self) -> bool:
        return True

    def is_mps(self) -> bool:
        return False


class MpsPlatform(Platform):
    """Platform helper selected when the MPS backend is available."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a fixed name for the MPS device.

        ``device_id`` is accepted for interface compatibility but unused.
        """
        # The CUDA query used here before is not applicable on an MPS-only
        # build, so report the backend name instead of calling torch.cuda.
        return "mps"

    def is_cuda(self) -> bool:
        return False

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return True


# Selection order: ROCm, then MPS, then CUDA. None when no backend is found.
current_platform = (
    RocmPlatform() if IS_ROCM else
    MpsPlatform() if IS_MPS else
    CudaPlatform() if torch.cuda.is_available() else
    None
)
+// They must be kept in sync with the signature for X. Generally, only +// functions that return Tensors require a meta function. +// +// See the following links for detailed docs on op registration and function +// schemas. +// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + // Attention ops + // Compute the attention between an input query and the cached + // keys/values using PagedAttention. + ops.def( + "paged_attention_v1(" + " Tensor! out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," + " int tp_rank, int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); +#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) + ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); +#elif defined(METAL_KERNEL) + ops.impl("paged_attention_v1", torch::kMPS, paged_attention_v1); +#endif + + // PagedAttention V2. + ops.def( + "paged_attention_v2(" + " Tensor! out, Tensor! exp_sums, Tensor! max_logits," + " Tensor! tmp_out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? 
alibi_slopes," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," + " int tp_rank, int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); +#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) + ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); +#elif defined(METAL_KERNEL) + ops.impl("paged_attention_v2", torch::kMPS, paged_attention_v2); +#endif + + // Swap in (out) the cache blocks from src to dst. + ops.def( + "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); +#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) + ops.impl("swap_blocks", torch::kCUDA, &swap_blocks); +#elif defined(METAL_KERNEL) + ops.impl("swap_blocks", torch::kMPS, swap_blocks); +#endif + + // Copy the cache blocks from src to dst. + ops.def( + "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, " + "Tensor block_mapping) -> ()"); +#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) + ops.impl("copy_blocks", torch::kCUDA, ©_blocks); +#elif defined(METAL_KERNEL) + ops.impl("copy_blocks", torch::kMPS, copy_blocks); +#endif + + // Reshape the key and value tensors and cache them. + ops.def( + "reshape_and_cache(Tensor key, Tensor value," + " Tensor! key_cache, Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " Tensor k_scale, Tensor v_scale) -> ()"); +#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) + ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); +#elif defined(METAL_KERNEL) + ops.impl("reshape_and_cache", torch::kMPS, reshape_and_cache); +#endif + + // Reshape the key and value tensors and cache them. + ops.def( + "reshape_and_cache_flash(Tensor key, Tensor value," + " Tensor! key_cache," + " Tensor! 
value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " Tensor k_scale, Tensor v_scale) -> ()"); +#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) + ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); +#elif defined(METAL_KERNEL) + ops.impl("reshape_and_cache_flash", torch::kMPS, reshape_and_cache_flash); +#endif + + // Gets the specified device attribute. + ops.def("get_device_attribute(int attribute, int device_id) -> int"); + ops.impl("get_device_attribute", &get_device_attribute); + + // Gets the maximum shared memory per block device attribute. + ops.def( + "get_max_shared_memory_per_block_device_attribute(int device_id) -> int"); + ops.impl("get_max_shared_memory_per_block_device_attribute", + &get_max_shared_memory_per_block_device_attribute); + + // Convert the key and value cache to fp8 data type. + ops.def( + "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " + "str kv_cache_dtype) -> ()"); +#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) + ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); +#elif defined(METAL_KERNEL) + ops.impl("convert_fp8", torch::kMPS, convert_fp8); +#endif +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h new file mode 100644 index 0000000000000000000000000000000000000000..4811bb5c9f9daf20e54b6847658300c4390cf427 --- /dev/null +++ b/torch-ext/torch_binding.h @@ -0,0 +1,56 @@ +#pragma once + +#include + +void paged_attention_v1( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t 
blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); + +void paged_attention_v2( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); + +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping); + +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, + const torch::Tensor& block_mapping); + +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale); + +void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale); + +int64_t get_device_attribute(int64_t attribute, int64_t device_id); + +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); + +void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const double scale, const std::string& kv_cache_dtype);