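"""Attention layer with a paged KV cache.

A module-level ForwardContext carries the per-step metadata (varlen offsets,
slot mapping, block tables); store_kvcache copies new K/V vectors into the
paged cache with a Triton kernel, and Attention dispatches to
flash_attn_varlen_func for prefill and flash_attn_with_kvcache for decode.
"""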
from dataclasses import dataclass

import torch
import triton
import triton.language as tl
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from torch import nn


@dataclass
class ForwardContext:
    is_prefill: bool = False
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    slot_mapping: torch.Tensor | None = None
    context_lens: torch.Tensor | None = None
    block_tables: torch.Tensor | None = None


_FORWARD_CONTEXT = ForwardContext()


def get_forward_context():
    return _FORWARD_CONTEXT


def set_forward_context(
    is_prefill,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    max_seqlen_q=0,
    max_seqlen_k=0,
    slot_mapping=None,
    context_lens=None,
    block_tables=None,
):
    global _FORWARD_CONTEXT
    _FORWARD_CONTEXT = ForwardContext(
        is_prefill,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        slot_mapping,
        context_lens,
        block_tables,
    )


def reset_forward_context():
    global _FORWARD_CONTEXT
    _FORWARD_CONTEXT = ForwardContext()


@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    # One program instance per token: copy that token's flattened K/V vector
    # (D = num_kv_heads * head_dim elements) into its assigned cache slot,
    # chunking by BLOCK_SIZE in case D exceeds a single block.
    BLOCK_SIZE: tl.constexpr = 2048
    idx = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + idx)
    if slot == -1:
        # Token has no assigned cache slot; nothing to store.
        return
    d_offset = 0
    while d_offset < D:
        cur_block_size = min(BLOCK_SIZE, D - d_offset)
        key_offsets = idx * key_stride + d_offset + tl.arange(0, BLOCK_SIZE)
        value_offsets = idx * value_stride + d_offset + tl.arange(0, BLOCK_SIZE)
        cache_offsets = slot * D + d_offset + tl.arange(0, BLOCK_SIZE)
        mask = tl.arange(0, BLOCK_SIZE) < cur_block_size
        key = tl.load(key_ptr + key_offsets, mask=mask, other=0.0)
        value = tl.load(value_ptr + value_offsets, mask=mask, other=0.0)
        tl.store(k_cache_ptr + cache_offsets, key, mask=mask)
        tl.store(v_cache_ptr + cache_offsets, value, mask=mask)
        d_offset += BLOCK_SIZE


def store_kvcache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    # The kernel treats each token's K/V as one contiguous vector of D elements,
    # so the head and head_dim dimensions must be contiguous in memory.
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](
        key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D
    )


class Attention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scale: float,
        num_kv_heads: int,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        # Empty placeholders until the engine allocates and assigns the paged caches.
        self.k_cache = self.v_cache = torch.tensor([])

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        context = get_forward_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        # Write the new tokens' K/V into the paged cache before attending.
        if k_cache.numel() and v_cache.numel() and context.slot_mapping is not None:
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        if context.is_prefill:
            if context.block_tables is not None:
                # Prefix-cached prefill: read K/V from the paged cache instead of
                # the freshly projected tensors.
                k, v = k_cache, v_cache
            o = flash_attn_varlen_func(
                q,
                k,
                v,
                max_seqlen_q=context.max_seqlen_q,
                cu_seqlens_q=context.cu_seqlens_q,
                max_seqlen_k=context.max_seqlen_k,
                cu_seqlens_k=context.cu_seqlens_k,
                softmax_scale=self.scale,
                causal=True,
                block_table=context.block_tables,
            )
        else:
            # Decode: one query token per sequence attends over the paged cache.
            o = flash_attn_with_kvcache(
                q.unsqueeze(1),
                k_cache,
                v_cache,
                cache_seqlens=context.context_lens,
                block_table=context.block_tables,
                softmax_scale=self.scale,
                causal=True,
            )
        return o
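

if __name__ == "__main__":
    # Illustrative decode-step smoke test, not part of the original module. It
    # assumes a CUDA device with flash-attn and triton installed; the shapes,
    # sequence lengths, and block size below are made up for the demo (flash-attn's
    # paged KV cache typically requires the block size to be a multiple of 256).
    if torch.cuda.is_available():
        num_heads = num_kv_heads = 8
        head_dim, block_size, num_blocks = 64, 256, 2
        attn = Attention(num_heads, head_dim, head_dim**-0.5, num_kv_heads)
        # Paged cache layout: (num_blocks, block_size, num_kv_heads, head_dim),
        # so stride(1) == num_kv_heads * head_dim, as store_kvcache expects.
        attn.k_cache = torch.zeros(
            num_blocks, block_size, num_kv_heads, head_dim,
            device="cuda", dtype=torch.float16,
        )
        attn.v_cache = torch.zeros_like(attn.k_cache)
        # Two sequences of lengths 3 and 5, one cache block each.
        context_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
        block_tables = torch.tensor([[0], [1]], dtype=torch.int32, device="cuda")
        # Slot of each sequence's newest token: block_id * block_size + offset.
        slot_mapping = torch.tensor(
            [0 * block_size + 2, 1 * block_size + 4],
            dtype=torch.int32, device="cuda",
        )
        q = torch.randn(2, num_heads, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(2, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)
        v = torch.randn_like(k)
        set_forward_context(
            is_prefill=False,
            slot_mapping=slot_mapping,
            context_lens=context_lens,
            block_tables=block_tables,
        )
        out = attn(q, k, v)
        reset_forward_context()
        print(out.shape)  # (2, 1, num_heads, head_dim) for the decode path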