# forgetting_pile_2layer/fgate_cache.py
from typing import List, Tuple, Optional, Any, Dict
import torch
class FgateDynamicCache:
"""
A cache that grows dynamically as more tokens are generated.
Custom cache for Forgetting Transformer that does not inherit from transformers.Cache.
"""
    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
        # `num_hidden_layers` is accepted for interface compatibility but is not
        # used: the per-layer lists below grow lazily as each layer calls `update`.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.log_fgate_cache: List[torch.Tensor] = []
        self.key_shift_cache: List[torch.Tensor] = []
        self.value_shift_cache: List[torch.Tensor] = []
        # Number of tokens the model has seen (updated once per forward, at layer 0).
        self._seen_tokens = 0
    def update_shift_cache(
        self,
        key_shift_state: torch.Tensor,
        value_shift_state: torch.Tensor,
        layer_idx: int,
    ) -> None:
        """Append the shift states for `layer_idx`; layers must arrive in order."""
        assert layer_idx == len(self.key_shift_cache) == len(self.value_shift_cache), (
            f"Shift states must be appended in layer order; expected layer "
            f"{len(self.key_shift_cache)}, got {layer_idx}"
        )
        self.key_shift_cache.append(key_shift_state)
        self.value_shift_cache.append(value_shift_state)
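    # Hedged usage sketch: shift states are appended once per layer, in layer
    # order (the assertion above enforces this protocol). `model.layers` and
    # `compute_shift_states` below are hypothetical names:
    #
    #     for layer_idx, layer in enumerate(model.layers):
    #         k_shift, v_shift = layer.compute_shift_states(hidden_states)
    #         cache.update_shift_cache(k_shift, v_shift, layer_idx)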
    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the (key, value, log_fgate) triple cached for `layer_idx`."""
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
def __iter__(self):
for layer_idx in range(len(self)):
yield (self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx])
def __len__(self):
return len(self.key_cache)
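    # Example: the cache behaves like a sequence of per-layer triples.
    #
    #     num_layers = len(cache)                  # via __len__
    #     k0, v0, log_f0 = cache[0]                # via __getitem__
    #     for k, v, log_f in cache:                # via __iter__
    #         ...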
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
log_fgate_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert log_fgate_states.ndim == 3, f"log_fgate must be (B, H, T), but got {log_fgate_states.size()}"
        # Count tokens once per forward pass, so only layer 0 advances the counter.
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]
        if len(self.key_cache) <= layer_idx:
            # First tokens for this layer: start its per-layer caches.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self.log_fgate_cache.append(log_fgate_states)
        else:
            # Append along the token dimension: -2 for (B, H, T, D), -1 for (B, H, T).
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
            self.log_fgate_cache[layer_idx] = torch.cat([self.log_fgate_cache[layer_idx], log_fgate_states], dim=-1)
return self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]
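    # Hedged decode-step sketch. Shapes follow the assertion and concatenation
    # dims above: keys/values are (B, H, T, D), log forget gates are (B, H, T),
    # and T == 1 for a single generated token:
    #
    #     # key_states, value_states: (batch, num_heads, 1, head_dim)
    #     # log_fgate_states:         (batch, num_heads, 1)
    #     keys, values, log_fgates = cache.update(
    #         key_states, value_states, log_fgate_states, layer_idx
    #     )
    #     # The returned tensors cover all tokens cached so far for this layer.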
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the number of cached tokens for `layer_idx` (0 if the layer is empty)."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]
    def get_max_length(self) -> Optional[int]:
        """Return the maximum cache length; None, since this cache is unbounded."""
        return None
    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
        """Convert to the legacy format: a tuple of per-layer (key, value, log_fgate) triples."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]),)
return legacy_cache
@classmethod
    def from_legacy_cache(
        cls,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None,
        num_layers: Optional[int] = None,
    ) -> "FgateDynamicCache":
        """
        Convert a cache in the legacy format into an equivalent FgateDynamicCache.

        Args:
            past_key_values: Optional legacy cache, a tuple of per-layer
                (key, value, log_fgate) triples.
            num_layers: Unused; kept for interface compatibility.

        Returns:
            A populated FgateDynamicCache instance.
        """
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states, log_fgate_states = past_key_values[layer_idx]
cache.update(key_states, value_states, log_fgate_states, layer_idx)
return cache
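    # Round-trip example: `to_legacy_cache` and `from_legacy_cache` are
    # inverses; `_seen_tokens` is rebuilt from the layer-0 key length because
    # `from_legacy_cache` replays the triples through `update`:
    #
    #     legacy = cache.to_legacy_cache()
    #     restored = FgateDynamicCache.from_legacy_cache(legacy)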
    def crop(self, max_length: int):
        """Crop the cache down to `max_length` tokens; negative values crop from the end."""
if max_length < 0:
max_length = self.get_seq_length() - abs(max_length)
if self.get_seq_length() <= max_length:
return
self._seen_tokens = max_length
for idx in range(len(self.key_cache)):
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
self.log_fgate_cache[idx] = self.log_fgate_cache[idx][..., :max_length]
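    # Example: roll back cached tokens, e.g. when discarding rejected draft
    # tokens during assisted decoding (a typical use; an assumption here):
    #
    #     cache.crop(-1)   # drop the last cached token from every layer
    #     cache.crop(128)  # or keep at most the first 128 tokens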
    def batch_split(self, full_batch_size: int, split_size: int) -> List["FgateDynamicCache"]:
        """Split the cache into `split_size`-sized chunks along the batch dimension."""
out = []
for i in range(0, full_batch_size, split_size):
current_split = FgateDynamicCache()
current_split._seen_tokens = self._seen_tokens
current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
current_split.log_fgate_cache = [tensor[i : i + split_size] for tensor in self.log_fgate_cache]
out.append(current_split)
return out
@classmethod
    def from_batch_splits(cls, splits: List["FgateDynamicCache"]) -> "FgateDynamicCache":
        """Rebuild a single cache by concatenating per-layer splits along the batch dimension."""
        cache = cls()
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
layer_log_fgates = torch.cat([current.log_fgate_cache[idx] for current in splits], dim=0)
cache.update(layer_keys, layer_values, layer_log_fgates, idx)
return cache
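    # Hedged sketch: process an oversized batch in chunks, then stitch the
    # caches back together (`run_chunk` is a hypothetical per-chunk forward):
    #
    #     splits = cache.batch_split(full_batch_size=8, split_size=2)
    #     splits = [run_chunk(split) for split in splits]
    #     cache = FgateDynamicCache.from_batch_splits(splits)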
    def batch_repeat_interleave(self, repeats: int):
        """Repeat each batch entry `repeats` times along the batch dimension, in place."""
for layer_idx in range(len(self)):
self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx].repeat_interleave(repeats, dim=0)
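    # Example (an assumed use, in the style of candidate-expanding decoders):
    #
    #     cache.batch_repeat_interleave(num_candidates)  # grow batch for scoring
    #     cache.batch_select_indices(best_indices)       # then keep the winners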
    def batch_select_indices(self, indices: torch.Tensor):
        """Keep (or reorder) the batch entries selected by `indices`, in place."""
for layer_idx in range(len(self)):
self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx][indices, ...]
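

if __name__ == "__main__":
    # Minimal smoke test, a sketch with arbitrary example shapes: build a
    # 2-layer cache, append a 5-token prefill and a 1-token decode step per
    # layer, then check the bookkeeping.
    B, H, D = 2, 4, 8
    cache = FgateDynamicCache()
    for step_len in (5, 1):
        for layer_idx in range(2):
            k = torch.randn(B, H, step_len, D)
            v = torch.randn(B, H, step_len, D)
            log_f = torch.zeros(B, H, step_len)  # log(1.0): no forgetting
            keys, values, log_fgates = cache.update(k, v, log_f, layer_idx)
    assert len(cache) == 2
    assert cache.get_seq_length() == 6
    assert keys.shape == (B, H, 6, D) and log_fgates.shape == (B, H, 6)
    print(f"ok: {cache.get_seq_length()} tokens cached across {len(cache)} layers")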