from typing import Optional, Tuple, List, Dict, Any
import torch
from transformers.cache_utils import Cache
class HybridCache(Cache):
    """
    A cache for hybrid contextual states: some layers are attention layers
    (caching key/value tensors), others are RNN/convolution/FFN layers
    (caching their own recurrent states).

    Each layer's entry is a dict with the keys ``recurrent_state``,
    ``attn_state``, ``conv_state`` and ``ffn_state``; keys that do not apply
    to a given layer are ``None``.
    """

    def __init__(
        self,
        seen_tokens: int = 0,
    ):
        """
        Args:
            seen_tokens (`int`, defaults to 0):
                Number of tokens the cache has already processed.
        """
        # One state dict per layer, indexed by layer index.
        self.states: List[Dict[str, Any]] = []
        self._seen_tokens = seen_tokens  # Used in `generate` to keep tally of how many tokens the cache has seen

    @property
    def is_compileable(self) -> bool:
        # Per-layer state dicts are mutated in place, which is not compatible
        # with graph capture.
        return False

    @property
    def seen_tokens(self) -> int:
        """Total number of tokens the cache has seen so far."""
        return self._seen_tokens

    def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
        """Return the state dict of layer `layer_idx`; raise `KeyError` if absent."""
        if layer_idx < len(self):
            return self.states[layer_idx]
        raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """Iterate over the per-layer state dicts in layer order."""
        yield from self.states

    def __len__(self) -> int:
        """Number of layers currently present in the cache."""
        return len(self.states)

    def update(
        self,
        recurrent_state: torch.Tensor | None = None,
        attn_state: Tuple[torch.Tensor, torch.Tensor] | None = None,
        conv_state: Tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None = None,
        ffn_state: torch.Tensor | None = None,
        layer_idx: int = 0,
        offset: Optional[int] = 1,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.

        Args:
            recurrent_state (`torch.Tensor`, `optional`):
                The new recurrent state to cache.
            attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
                The new attention key/value states to cache.
            conv_state (`Tuple[torch.Tensor]`, `optional`):
                The new convolution state to cache.
            ffn_state (`torch.Tensor`, `optional`):
                The new FFN state to cache.
            layer_idx (`int`, defaults to 0):
                The index of the layer to cache the states for.
            offset (`int`, `optional`, defaults to 1):
                The number of new tokens being processed.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass; `window_size`
                (if present) caps the cached attention length.

        Return:
            Dictionary of the updated state.
        """
        # Update the number of seen tokens only once per forward pass.
        if layer_idx == 0:
            self._seen_tokens += offset

        window_size = (cache_kwargs or {}).get('window_size')
        input_size = 0
        if attn_state is not None:
            # Validate *before* indexing into `attn_state`, so a malformed
            # argument raises a clear ValueError rather than an obscure error
            # from the shape access below.
            if not isinstance(attn_state, tuple) or len(attn_state) != 2:
                raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
            input_size = attn_state[0].shape[-2]

        if len(self.states) <= layer_idx:
            # First time we see this layer: create its state dict.
            if attn_state is not None and window_size is not None and input_size > window_size:
                # Keep only the last `window_size` tokens of the prompt.
                attn_state = (attn_state[0][..., -window_size:, :].contiguous(),
                              attn_state[1][..., -window_size:, :].contiguous())
            state = dict(
                recurrent_state=recurrent_state,
                attn_state=attn_state,
                conv_state=conv_state,
                ffn_state=ffn_state
            )
            self.states.append(state)
        else:
            state = self.states[layer_idx]
            if recurrent_state is not None:
                state['recurrent_state'] = recurrent_state
            if attn_state is not None:
                key_state, value_state = state['attn_state']
                if window_size is not None and key_state.shape[-2] == window_size:
                    # DO NOT allocate new memory if the cache is full:
                    # roll the key/value states to the left by `input_size`
                    key_state = key_state.roll(-input_size, -2)
                    value_state = value_state.roll(-input_size, -2)
                    # replace the last `input_size` tokens with the new key/value states
                    key_state[..., -input_size:, :] = attn_state[0]
                    value_state[..., -input_size:, :] = attn_state[1]
                    attn_state = (key_state, value_state)
                else:
                    # Cache not yet full (or no window): grow by concatenation.
                    attn_state = (torch.cat([key_state, attn_state[0]], -2),
                                  torch.cat([value_state, attn_state[1]], -2),)
                state['attn_state'] = attn_state
            if conv_state is not None:
                state['conv_state'] = conv_state
            if ffn_state is not None:
                state['ffn_state'] = ffn_state
        return state

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Treat `None` as layer 0 so the `Optional[int]` annotation is honored.
        if len(self.states) <= (layer_idx if layer_idx is not None else 0):
            return 0
        return self._seen_tokens

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> Tuple:
        """Converts the cache into its legacy format: a tuple of per-layer state dicts."""
        return tuple(self.states)

    @classmethod
    def from_legacy_cache(
        cls,
        past_key_values: Optional[Tuple] = None,
        seen_tokens: int = 0
    ) -> "HybridCache":
        """Converts a cache in the legacy cache format into an equivalent `Cache`."""
        cache = cls(seen_tokens)
        if past_key_values is not None:
            cache.states.extend(past_key_values)
        return cache