from __future__ import annotations
from dataclasses import dataclass
from typing import List
import torch
@dataclass
class CacheSlot:
    """Pre-allocated key/value cache for a single attention layer.

    ``keys`` and ``values`` are fixed-size ``(batch, heads, max_steps,
    head_dim)`` tensors.  Chunks are written in place at the current fill
    position, and an additive attention mask hides the still-unwritten tail.
    """

    keys: torch.Tensor    # (batch, heads, max_steps, head_dim)
    values: torch.Tensor  # same shape/device/dtype as ``keys``

    def __post_init__(self) -> None:
        # Cache geometry, derived once from the key tensor.
        self.max_steps = self.keys.shape[2]
        self.head_dim = self.keys.shape[3]
        self.flat_heads = self.keys.shape[0] * self.keys.shape[1]
        device = self.keys.device
        # Fill level kept as a 0-dim long tensor so it lives on the same
        # device as the cache (avoids host round-trips on increment).
        self.length = torch.zeros((), dtype=torch.long, device=device)
        # Static position ids 0..max_steps-1, reused for writes and masking.
        self.positions = torch.arange(self.max_steps, dtype=torch.long, device=device)

    @classmethod
    def allocate(
        cls,
        *,
        batch_size: int,
        heads: int,
        max_steps: int,
        head_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "CacheSlot":
        """Create a zero-initialized slot with the given geometry."""
        keys = torch.zeros(batch_size, heads, max_steps, head_dim, device=device, dtype=dtype)
        values = torch.zeros_like(keys)
        return cls(keys, values)

    def reset(self) -> None:
        """Mark the slot empty.  Stale tensor data is left in place; it is
        unreachable because the attention mask only exposes ``length`` slots."""
        self.length.zero_()

    def write_and_view(
        self,
        key_chunk: torch.Tensor,
        value_chunk: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Append a chunk to the cache and return ``(keys, values, attn_mask)``.

        Args:
            key_chunk: ``(batch, heads, step, head_dim)`` keys to append.
            value_chunk: matching values, same shape as ``key_chunk``.

        Returns:
            The full cache tensors plus an additive mask of shape
            ``(1, 1, 1, max_steps)``: 0 at valid positions, dtype-min at
            unfilled ones.

        Raises:
            RuntimeError: propagated from torch if ``step`` exceeds the
                remaining capacity (indices fall outside ``max_steps``).
        """
        step = key_chunk.shape[2]
        # Target slots [length, length + step).
        indices = self.positions[:step] + self.length
        flat_keys = self.keys.view(self.flat_heads, self.max_steps, self.head_dim)
        flat_values = self.values.view(self.flat_heads, self.max_steps, self.head_dim)
        # index_copy_ performs the same contiguous per-row write the original
        # expand + scatter_ pair did, without materializing the broadcast
        # index tensor.  Writes through the views land in self.keys/values.
        flat_keys.index_copy_(1, indices, key_chunk.reshape(self.flat_heads, step, self.head_dim))
        flat_values.index_copy_(1, indices, value_chunk.reshape(self.flat_heads, step, self.head_dim))
        self.length.add_(step)
        # Positions at or beyond the new length are invalid; hide them with
        # the most negative representable value.
        # NOTE(review): torch.finfo assumes a floating-point cache dtype —
        # confirm integer/quantized dtypes are never allocated here.
        invalid = (self.positions >= self.length).view(1, 1, 1, self.max_steps)
        mask_dtype = self.keys.dtype
        mask_value = torch.finfo(mask_dtype).min
        attn_mask = torch.zeros_like(invalid, dtype=mask_dtype)
        attn_mask = attn_mask.masked_fill(invalid, mask_value)
        return self.keys, self.values, attn_mask
class KVCache:
    """Ordered collection of per-layer :class:`CacheSlot` objects."""

    def __init__(self, slots: List[CacheSlot]) -> None:
        # One slot per transformer layer, indexed by layer position.
        self.slots = slots

    @classmethod
    def allocate(
        cls,
        *,
        num_layers: int,
        batch_size: int,
        heads: int,
        max_steps: int,
        head_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "KVCache":
        """Allocate ``num_layers`` identically-shaped slots and wrap them."""
        slot_kwargs = dict(
            batch_size=batch_size,
            heads=heads,
            max_steps=max_steps,
            head_dim=head_dim,
            device=device,
            dtype=dtype,
        )
        layer_slots = [CacheSlot.allocate(**slot_kwargs) for _ in range(num_layers)]
        return cls(layer_slots)

    def get_slot(self, index: int) -> CacheSlot:
        """Return the slot for layer ``index``."""
        return self.slots[index]

    def reset(self) -> None:
        """Empty every layer's slot in place."""
        for layer_slot in self.slots:
            layer_slot.reset()

    # Backwards-compatible alias for ``reset``.
    clear = reset
# Public API of this module.
__all__ = ["CacheSlot", "KVCache"]