| | from typing import Tuple |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| |
|
class Cache:
    """Preallocated buffer that accumulates one attention projection (e.g. keys or values) over time.

    Storage layout is ``(num_samples, num_heads, max_tokens, embed_dim // num_heads)``.
    Only the first ``self._size`` positions along dim 2 are considered filled; the
    rest of the buffer comes from ``torch.empty`` and is uninitialised until written.
    """

    def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        assert embed_dim % num_heads == 0

        def allocate(batch: int) -> torch.Tensor:
            # Fresh, uninitialised storage with room for max_tokens positions per head.
            return torch.empty(batch, num_heads, max_tokens, embed_dim // num_heads, device=device)

        self._n = num_samples
        self._cache = None
        self._size = None
        self._reset = allocate
        self.reset()

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Return ``(batch, heads, filled_tokens, head_dim)``; dim 2 is the filled length, not capacity."""
        batch, heads, _capacity, head_dim = self._cache.shape
        return batch, heads, self._size, head_dim

    def reset(self) -> None:
        """Reallocate the buffer for the current batch size (which reflects earlier prunes) and mark it empty."""
        self._cache = self._reset(self._n)
        self._size = 0

    def prune(self, mask: np.ndarray) -> None:
        """Keep only the batch rows selected by ``mask`` (presumably a boolean keep-mask — confirm at call sites)."""
        assert mask.ndim == 1 and mask.shape[0] == self.shape[0]
        kept = self._cache[mask]
        self._cache = kept
        self._n = kept.shape[0]

    def get(self) -> torch.Tensor:
        """Return a view of the filled prefix along the token dimension."""
        return self._cache[..., :self._size, :]

    def update(self, x: torch.Tensor) -> None:
        """Append ``x`` along dim 2, right after the currently filled positions.

        ``x`` must match the buffer in rank and in dims 0, 1 and 3, and must fit in the
        remaining capacity.
        """
        buffer = self._cache
        compatible = x.ndim == buffer.ndim and all(x.size(axis) == buffer.size(axis) for axis in (0, 1, 3))
        assert compatible
        start = self._size
        stop = start + x.size(2)
        assert stop <= buffer.shape[2]
        # Write in place without tripping autograd's version-counter check.
        self._cache = AssignWithoutInplaceCheck.apply(buffer, x, 2, start, stop)
        self._size = stop
| |
|
| |
|
class KVCache:
    """Key and value ``Cache`` buffers for a single layer, always operated on together (keys first)."""

    def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        def make_buffer() -> "Cache":
            return Cache(n, num_heads, max_tokens, embed_dim, device)

        self._k_cache = make_buffer()
        self._v_cache = make_buffer()

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Shape reported by the key buffer (dim 2 is the filled token count)."""
        return self._k_cache.shape

    def reset(self) -> None:
        """Reallocate both buffers and mark them empty."""
        for buffer in (self._k_cache, self._v_cache):
            buffer.reset()

    def prune(self, mask: np.ndarray) -> None:
        """Apply the same batch-row ``mask`` to keys and values."""
        for buffer in (self._k_cache, self._v_cache):
            buffer.prune(mask)

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(keys, values)`` restricted to their filled prefixes."""
        return tuple(buffer.get() for buffer in (self._k_cache, self._v_cache))

    def update(self, k: torch.Tensor, v: torch.Tensor) -> None:
        """Append ``k`` to the key buffer, then ``v`` to the value buffer."""
        for buffer, chunk in ((self._k_cache, k), (self._v_cache, v)):
            buffer.update(chunk)
| |
|
| |
|
class KeysValues:
    """Fixed-length, indexable collection holding one ``KVCache`` per layer."""

    def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None:
        self._keys_values = tuple(
            KVCache(n, num_heads, max_tokens, embed_dim, device) for _layer in range(num_layers)
        )

    def __getitem__(self, key: int) -> KVCache:
        return self._keys_values[key]

    def __len__(self) -> int:
        return len(self._keys_values)

    @property
    def size(self) -> int:
        """Filled token count of layer 0's key buffer.

        NOTE(review): presumes every layer is updated in lockstep — only layer 0 is inspected.
        """
        first_layer = self._keys_values[0]
        return first_layer.shape[2]

    def reset(self) -> None:
        """Reset every layer's cache, in layer order."""
        for layer_cache in self._keys_values:
            layer_cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        """Apply the same batch-row ``mask`` to every layer, in layer order."""
        for layer_cache in self._keys_values:
            layer_cache.prune(mask)
| |
|
| |
|
class AssignWithoutInplaceCheck(torch.autograd.Function):
    """
    Copy ``value`` into ``input[start:stop]`` along dimension ``dim`` in place, writing through
    ``input.data`` so that autograd's in-place correctness (version) check is not triggered.

    Backward passes the incoming gradient through unchanged to ``input``. The ``start:stop``
    slice of that gradient along ``dim`` is routed to ``value``.

    Inspired from : https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4
    Warning : do not use it to overwrite a slice twice. The gradient for ``input`` is not
    masked over the overwritten slice, so a value previously written to the same slice would
    still receive the gradient of the value that replaced it.
    """

    @staticmethod
    def get_slice(dim: int, start: int, stop: int) -> Tuple[slice, ...]:
        """Return an index that selects ``start:stop`` along ``dim`` and everything along earlier dims.

        Dimensions after ``dim`` are left implicit, so they are fully selected as well.
        """
        return (slice(None),) * dim + (slice(start, stop),)

    @staticmethod
    def forward(ctx, input: torch.Tensor, value: torch.Tensor, dim: int, start: int, stop: int) -> torch.Tensor:
        """Write ``value`` into the slice of ``input`` and return ``input`` itself (now modified)."""
        # Keep the slice geometry for backward; these are plain Python ints, not tensors.
        ctx.dim = dim
        ctx.start = start
        ctx.stop = stop
        index = AssignWithoutInplaceCheck.get_slice(dim, start, stop)
        # Writing through .data hides this mutation from autograd's version counter.
        input.data[index] = value
        return input

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
        """Return gradients for ``(input, value, dim, start, stop)``; the three ints get ``None``."""
        index = AssignWithoutInplaceCheck.get_slice(ctx.dim, ctx.start, ctx.stop)
        # input's gradient is not zeroed over the written slice (see the class warning).
        return grad_out, grad_out[index], None, None, None
| |
|