from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple

import torch
import transformers


class Cache(transformers.cache_utils.Cache):
    """
    A cache used for storing hidden states produced by flash linear attention models.

    It stores the states of each layer as a tuple of tensors, each of shape `[batch_size, key_dim, value_dim]`.
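
    Example:
        A minimal usage sketch (the shapes here are illustrative, not prescribed by this class):

        ```python
        >>> import torch
        >>> cache = Cache()
        >>> _ = cache.update(torch.zeros(2, 64, 64), layer_idx=0)
        >>> len(cache)
        1
        ```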
    """

    def __init__(
        self,
        seen_tokens: int = 0
    ) -> None:
        self.states: List[Tuple[torch.Tensor, ...]] = []
        # Number of tokens the cache has seen, kept for `generate`.
        self._seen_tokens = seen_tokens

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, ...]:
        if layer_idx < len(self):
            return self.states[layer_idx]
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        for state in self.states:
            yield state

    def __len__(self):
        return len(self.states)

    def update(
        self,
        state: Tuple[torch.Tensor],
        layer_idx: int,
        offset: Optional[int] = 1,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor]:
        """
        Updates the cache with the new `state` for the layer `layer_idx`.

        Parameters:
            state (`Tuple[torch.Tensor]`):
                The new state to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            offset (`int`, `optional`, defaults to 1):
                The number of new tokens fed in the current step.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass.

        Return:
            The updated state.
        """

        if isinstance(state, torch.Tensor):
            state = (state,)
        if len(self.states) <= layer_idx:
            # First forward pass for this layer: allocate its state.
            self.states.append(state)
        else:
            # Later passes update the allocated state in place.
            for i, s in enumerate(state):
                self.states[layer_idx][i].copy_(s)
            # Advance the token counter once per forward pass, i.e. when the last layer is updated.
            if layer_idx == len(self) - 1:
                self._seen_tokens += offset

        return state

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.states) <= layer_idx:
            return 0
        return self._seen_tokens

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.states)):
            # Each layer holds a tuple of tensors; reorder each along the batch dimension.
            self.states[layer_idx] = tuple(
                s.index_select(0, beam_idx.to(s.device)) for s in self.states[layer_idx]
            )

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, ...], ...]:
        return tuple(self.states)

    @classmethod
    def from_legacy_cache(
        cls,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        seen_tokens: int = 0
    ) -> Cache:
        """Converts a cache in the legacy cache format into an equivalent `Cache`."""

        cache = cls(seen_tokens)
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                cache.update(past_key_values[layer_idx], layer_idx)
        return cache
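

if __name__ == "__main__":
    # A small self-contained sanity check, not part of the original module: it sketches how a
    # two-layer model might drive this cache during decoding. The layer count and tensor shapes
    # are illustrative assumptions, not values prescribed by the class.
    cache = Cache(seen_tokens=4)  # e.g. 4 prompt tokens already processed

    # First pass: states are appended, one per layer; the token counter is untouched.
    for layer_idx in range(2):
        cache.update(torch.zeros(2, 4, 4), layer_idx)
    assert len(cache) == 2 and cache.get_seq_length() == 4

    # Later passes update states in place; the counter advances at the last layer.
    for layer_idx in range(2):
        cache.update(torch.ones(2, 4, 4), layer_idx)
    assert cache.get_seq_length() == 5

    # Round-trip through the legacy tuple format preserves the token count.
    restored = Cache.from_legacy_cache(cache.to_legacy_cache(), seen_tokens=cache.get_seq_length())
    assert restored.get_seq_length() == 5

    # Reorder along the batch dimension, as `generate` does for beam search.
    cache.reorder_cache(torch.tensor([1, 0]))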