from typing import List, Tuple, Optional, Any, Dict

import torch


class FgateDynamicCache:
    """
    A cache that grows dynamically as more tokens are generated.

    Custom cache for the Forgetting Transformer that does not inherit from
    `transformers.Cache`. Besides key and value states, it stores per-layer
    log forget-gate values and optional key/value shift states.
    """

    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
        # `num_hidden_layers` is accepted for API compatibility but unused.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.log_fgate_cache: List[torch.Tensor] = []
        self.key_shift_cache: List[torch.Tensor] = []
        self.value_shift_cache: List[torch.Tensor] = []
        self._seen_tokens = 0

    def update_shift_cache(
        self,
        key_shift_state: torch.Tensor,
        value_shift_state: torch.Tensor,
        layer_idx: int,
    ) -> None:
        # Shift states are appended once per layer, in layer order.
        assert layer_idx == len(self.key_shift_cache) == len(self.value_shift_cache)
        self.key_shift_cache.append(key_shift_state)
        self.value_shift_cache.append(value_shift_state)

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if layer_idx < len(self):
            return (
                self.key_cache[layer_idx],
                self.value_cache[layer_idx],
                self.log_fgate_cache[layer_idx],
            )
        else:
            raise KeyError(
                f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}"
            )

    def __iter__(self):
        for layer_idx in range(len(self)):
            yield (
                self.key_cache[layer_idx],
                self.value_cache[layer_idx],
                self.log_fgate_cache[layer_idx],
            )

    def __len__(self):
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        log_fgate_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert log_fgate_states.ndim == 3, (
            f"log_fgate must be (B, H, T), but got {log_fgate_states.size()}"
        )
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self.log_fgate_cache.append(log_fgate_states)
        else:
            # Keys/values are (B, H, T, D): concatenate along the time axis (-2).
            # Log forget gates are (B, H, T): concatenate along the last axis.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
            self.log_fgate_cache[layer_idx] = torch.cat(
                [self.log_fgate_cache[layer_idx], log_fgate_states], dim=-1
            )

        return (
            self.key_cache[layer_idx],
            self.value_cache[layer_idx],
            self.log_fgate_cache[layer_idx],
        )

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        # This cache grows without bound, so there is no maximum length.
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, ...], ...]:
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += (
                (
                    self.key_cache[layer_idx],
                    self.value_cache[layer_idx],
                    self.log_fgate_cache[layer_idx],
                ),
            )
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        num_layers: Optional[int] = None,
    ) -> "FgateDynamicCache":
        """
        Converts a cache in the legacy cache format into an equivalent FgateDynamicCache.

        Args:
            past_key_values: Optional legacy cache format.
            num_layers: Not used in this implementation.

        Returns:
            FgateDynamicCache instance.
        """
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states, log_fgate_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, log_fgate_states, layer_idx)
        return cache

    def crop(self, max_length: int):
        # A negative max_length means cropping that many tokens off the end.
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)
        if self.get_seq_length() <= max_length:
            return
        self._seen_tokens = max_length
        for idx in range(len(self.key_cache)):
            self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
            self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
            self.log_fgate_cache[idx] = self.log_fgate_cache[idx][..., :max_length]

    def batch_split(self, full_batch_size: int, split_size: int) -> List["FgateDynamicCache"]:
        out = []
        for i in range(0, full_batch_size, split_size):
            current_split = FgateDynamicCache()
            current_split._seen_tokens = self._seen_tokens
            current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
            current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
            current_split.log_fgate_cache = [tensor[i : i + split_size] for tensor in self.log_fgate_cache]
            out.append(current_split)
        return out

    @classmethod
    def from_batch_splits(cls, splits: List["FgateDynamicCache"]) -> "FgateDynamicCache":
        cache = cls()
        for idx in range(len(splits[0])):
            layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
            layer_log_fgates = torch.cat([current.log_fgate_cache[idx] for current in splits], dim=0)
            cache.update(layer_keys, layer_values, layer_log_fgates, idx)
        return cache

    def batch_repeat_interleave(self, repeats: int):
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
            self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
            self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx].repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor):
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
            self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx][indices, ...]
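

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the cache implementation). It simulates a
# prefill step followed by one decode step for a toy 2-layer model. The tensor
# layout -- (B, H, T, D) for keys/values and (B, H, T) for log forget gates --
# is inferred from the asserts and `dim` arguments above; the concrete sizes
# below are illustrative assumptions, not values from a real model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, H, D = 2, 4, 16
    num_layers = 2
    cache = FgateDynamicCache()

    # Prefill with a prompt of 8 tokens.
    T_prompt = 8
    for layer_idx in range(num_layers):
        k = torch.randn(B, H, T_prompt, D)
        v = torch.randn(B, H, T_prompt, D)
        # Log forget gates are logs of values in (0, 1], hence <= 0.
        log_f = -torch.rand(B, H, T_prompt)
        cache.update(k, v, log_f, layer_idx)
    assert cache.get_seq_length() == T_prompt

    # Decode a single new token: each layer appends one timestep.
    for layer_idx in range(num_layers):
        k = torch.randn(B, H, 1, D)
        v = torch.randn(B, H, 1, D)
        log_f = -torch.rand(B, H, 1)
        keys, values, log_fgates = cache.update(k, v, log_f, layer_idx)
        assert keys.shape == (B, H, T_prompt + 1, D)
        assert log_fgates.shape == (B, H, T_prompt + 1)

    # Round-trip through the legacy format and the batch utilities.
    rebuilt = FgateDynamicCache.from_legacy_cache(cache.to_legacy_cache())
    assert rebuilt.get_seq_length() == cache.get_seq_length()
    splits = cache.batch_split(full_batch_size=B, split_size=1)
    merged = FgateDynamicCache.from_batch_splits(splits)
    assert merged.key_cache[0].shape == cache.key_cache[0].shape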