from typing import List, Tuple, Optional, Any, Dict
import torch
class FgateDynamicCache:
"""
A cache that grows dynamically as more tokens are generated.
Custom cache for Forgetting Transformer that does not inherit from transformers.Cache.
"""
    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
        # Per-layer caches, indexed by layer_idx.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.log_fgate_cache: List[torch.Tensor] = []
        self.key_shift_cache: List[torch.Tensor] = []
        self.value_shift_cache: List[torch.Tensor] = []
        # Total number of tokens the cache has seen (updated at layer 0).
        self._seen_tokens = 0
    def update_shift_cache(
        self,
        key_shift_state: torch.Tensor,
        value_shift_state: torch.Tensor,
        layer_idx: int,
    ):
        # Shift states are registered once per layer, in layer order.
        assert layer_idx == len(self.key_shift_cache) == len(self.value_shift_cache)
        self.key_shift_cache.append(key_shift_state)
        self.value_shift_cache.append(value_shift_state)
    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
    def __iter__(self):
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx])
    def __len__(self):
        return len(self.key_cache)
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        log_fgate_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert log_fgate_states.ndim == 3, f"log_fgate must be (B, H, T), but got {log_fgate_states.size()}"
        # Only count new tokens once per forward pass (at the first layer).
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]
        if len(self.key_cache) <= layer_idx:
            # First update for this layer: initialize its cache entries.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self.log_fgate_cache.append(log_fgate_states)
        else:
            # Keys/values are (B, H, T, D) and log forget gates are (B, H, T);
            # concatenate along the time axis in both cases.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
            self.log_fgate_cache[layer_idx] = torch.cat([self.log_fgate_cache[layer_idx], log_fgate_states], dim=-1)
        return self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]
    def get_max_length(self) -> Optional[int]:
        return None
    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], ...]:
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]),)
        return legacy_cache
    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_layers: Optional[int] = None) -> "FgateDynamicCache":
        """
        Converts a cache in the legacy cache format into an equivalent FgateDynamicCache.

        Args:
            past_key_values: Optional legacy cache format
            num_layers: Not used in this implementation

        Returns:
            FgateDynamicCache instance
        """
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states, log_fgate_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, log_fgate_states, layer_idx)
        return cache
    def crop(self, max_length: int):
        # A negative max_length means "drop that many tokens from the end".
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)
        if self.get_seq_length() <= max_length:
            return
        self._seen_tokens = max_length
        for idx in range(len(self.key_cache)):
            self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
            self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
            self.log_fgate_cache[idx] = self.log_fgate_cache[idx][..., :max_length]
    def batch_split(self, full_batch_size: int, split_size: int) -> List["FgateDynamicCache"]:
        out = []
        for i in range(0, full_batch_size, split_size):
            current_split = FgateDynamicCache()
            current_split._seen_tokens = self._seen_tokens
            current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
            current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
            current_split.log_fgate_cache = [tensor[i : i + split_size] for tensor in self.log_fgate_cache]
            out.append(current_split)
        return out
    @classmethod
    def from_batch_splits(cls, splits: List["FgateDynamicCache"]) -> "FgateDynamicCache":
        cache = cls()
        for idx in range(len(splits[0])):
            layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
            layer_log_fgates = torch.cat([current.log_fgate_cache[idx] for current in splits], dim=0)
            cache.update(layer_keys, layer_values, layer_log_fgates, idx)
        return cache
    def batch_repeat_interleave(self, repeats: int):
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
            self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
            self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx].repeat_interleave(repeats, dim=0)
    def batch_select_indices(self, indices: torch.Tensor):
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
            self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx][indices, ...]
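# -----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original implementation).
# Assumes key/value states of shape (batch, heads, seq, head_dim) and log
# forget gates of shape (batch, heads, seq), as implied by the dim=-2 / dim=-1
# concatenation in `update`; all sizes below are arbitrary.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    batch, heads, head_dim, num_layers = 2, 4, 8, 3
    cache = FgateDynamicCache()

    # Prefill: register the prompt's states for every layer.
    for layer_idx in range(num_layers):
        k = torch.randn(batch, heads, 5, head_dim)
        v = torch.randn(batch, heads, 5, head_dim)
        log_fgate = torch.zeros(batch, heads, 5)
        cache.update(k, v, log_fgate, layer_idx)

    # Decode one token: new states are concatenated along the time axis.
    for layer_idx in range(num_layers):
        k = torch.randn(batch, heads, 1, head_dim)
        v = torch.randn(batch, heads, 1, head_dim)
        log_fgate = torch.zeros(batch, heads, 1)
        keys, values, log_fgates = cache.update(k, v, log_fgate, layer_idx)

    assert len(cache) == num_layers
    assert cache.get_seq_length() == 6

    # Round-trip through the legacy (tuple-of-tuples) format.
    legacy = cache.to_legacy_cache()
    restored = FgateDynamicCache.from_legacy_cache(legacy)
    assert restored.get_seq_length() == cache.get_seq_length()

    # Drop the last decoded token, e.g. to re-run generation from the prompt.
    cache.crop(5)
    assert cache.get_seq_length() == 5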