# NOTE: removed non-Python residue left over from the file-hosting page header
# ("ShaswatRobotics's picture" / "Upload 35 files" / "23bc32f verified").
from typing import Optional, Tuple

import numpy as np
import torch
class Cache:
    """Pre-allocated (B, T, E) token buffer that is filled left to right.

    The tensor is allocated once at full capacity `max_tokens`; `update`
    writes new tokens along the time axis and `_size` tracks how many
    time slots currently hold valid data.
    """

    def __init__(self, num_samples: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        self._n = num_samples
        self._cache = None
        self._size = None

        # Allocator for a fresh, uninitialized (batch, max_tokens, embed_dim) buffer.
        def _make_buffer(batch: int) -> torch.Tensor:
            return torch.empty(batch, max_tokens, embed_dim, device=device)  # (B, T, E)

        self._reset = _make_buffer
        self.reset()

    @property
    def shape(self) -> Tuple[int, int, int]:
        """(batch, valid_tokens, embed_dim) — time axis counts filled slots only."""
        batch, _, dim = self._cache.shape
        return batch, self._size, dim

    def reset(self) -> None:
        """Discard all cached tokens and reallocate an empty buffer."""
        self._cache = self._reset(self._n)
        self._size = 0

    def prune(self, mask: np.ndarray) -> None:
        """Keep only the batch rows selected by the 1-D boolean `mask`."""
        assert mask.ndim == 1 and mask.shape[0] == self.shape[0]
        self._cache = self._cache[mask]
        self._n = self._cache.shape[0]

    def get(self) -> torch.Tensor:
        """Return a view of the valid prefix, shape (B, size, E)."""
        return self._cache[:, :self._size, :]

    def update(self, x: torch.Tensor) -> None:
        """Append `x` of shape (B, t, E) along the time axis.

        Uses AssignWithoutInplaceCheck to write into the pre-allocated buffer
        without tripping autograd's in-place version check.
        """
        assert (x.ndim == self._cache.ndim) and all(x.size(i) == self._cache.size(i) for i in (0, 2))
        assert self._size + x.size(1) <= self._cache.shape[1]
        self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 1, self._size, self._size + x.size(1))
        self._size += x.size(1)
class KVCache:
    """Key/value cache pair for a single attention layer.

    Both underlying Cache buffers are always kept in lockstep: every
    operation is applied to the keys and the values together.
    """

    def __init__(self, n: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        self._k_cache = Cache(n, max_tokens, embed_dim, device)
        self._v_cache = Cache(n, max_tokens, embed_dim, device)

    @property
    def shape(self) -> Tuple[int, int, int]:
        """(batch, valid_tokens, embed_dim); identical for keys and values."""
        return self._k_cache.shape

    def reset(self) -> None:
        """Empty both the key and the value buffers."""
        for cache in (self._k_cache, self._v_cache):
            cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        """Drop batch rows not selected by the boolean `mask` from both buffers."""
        for cache in (self._k_cache, self._v_cache):
            cache.prune(mask)

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (keys, values), each of shape (B, size, E)."""
        return self._k_cache.get(), self._v_cache.get()

    def update(self, k: torch.Tensor, v: torch.Tensor):
        """Append new key and value tokens along the time axis."""
        self._k_cache.update(k)
        self._v_cache.update(v)
class KeysValues:
    """Collection of per-layer KV caches for a multi-layer transformer."""

    def __init__(self, n: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None:
        # One independent KVCache per transformer layer.
        self._keys_values = tuple(KVCache(n, max_tokens, embed_dim, device) for _ in range(num_layers))

    def __getitem__(self, key: int) -> KVCache:
        """Return the cache of layer `key`."""
        return self._keys_values[key]

    def __len__(self):
        """Number of layers."""
        return len(self._keys_values)

    @property
    def size(self):
        """Number of valid tokens currently cached (read from layer 0; all layers agree)."""
        return self._keys_values[0].shape[1]

    def reset(self) -> None:
        """Empty every layer's cache."""
        for layer_cache in self._keys_values:
            layer_cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        """Drop unselected batch rows from every layer's cache."""
        for layer_cache in self._keys_values:
            layer_cache.prune(mask)
class AssignWithoutInplaceCheck(torch.autograd.Function):
    """
    Write `value` into `input` over [start:stop) along dimension `dim`, bypassing
    autograd's in-place version check by mutating `input.data` directly.

    Inspired from : https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4
    Warning : do not use it to overwrite a slice twice — backward assumes each
    slice was written exactly once, so the gradient of an earlier write would be wrong.
    """

    @staticmethod
    def get_slice(dim: int, start: int, stop: int) -> Tuple[slice, ...]:
        """Index tuple selecting [start:stop) along `dim` (full slices on the dims before it).

        Fixed annotation: the tuple has `dim + 1` slices, so `Tuple[slice, ...]`,
        not `Tuple[slice]` (which would mean a 1-tuple).
        """
        return tuple([slice(None), ] * dim + [slice(start, stop)])

    @staticmethod
    def forward(ctx, input: torch.Tensor, value: torch.Tensor, dim: int, start: int, stop: int) -> torch.Tensor:
        # Remember the slice coordinates for backward.
        ctx.dim = dim
        ctx.start = start
        ctx.stop = stop
        # Mutate raw storage via .data: sidesteps the autograd version counter.
        input.data[AssignWithoutInplaceCheck.get_slice(dim, start, stop)] = value
        return input

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        # d/d(input) is the identity; d/d(value) is the written slice;
        # the three int arguments (dim, start, stop) get no gradient.
        # Fixed annotation: a 5-tuple containing Nones, not Tuple[torch.Tensor].
        return grad_out, grad_out[AssignWithoutInplaceCheck.get_slice(ctx.dim, ctx.start, ctx.stop)], None, None, None