File size: 3,726 Bytes
091400d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from typing import Any, Dict, Optional, Tuple
from transformers.cache_utils import DynamicCache
import torch
class DynamicCacheWithQuery(DynamicCache):
    """
    Cache class used for QRRetriever.

    Extends ``DynamicCache`` so that query states can be stored alongside the
    key/value caches: callers pass ``query_states`` inside ``cache_kwargs``,
    keeping the same ``update`` signature as ``DynamicCache``.
    """
    def __init__(self, query_indices=None) -> None:
        super().__init__()
        # BUGFIX: the default used to be a mutable ``[]`` shared by every
        # instance; use a None sentinel and build a fresh list per instance.
        self._query_indices = [] if query_indices is None else query_indices  # indices for query vectors to save
        self.query_cache = []  # per-layer list of cached query states
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. May contain
                ``query_states``, which is accumulated into ``self.query_cache``.
        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens (counted once, on layer 0 only)
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]
        # Update the key/value cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # There may be skipped layers, fill them with empty tensors
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append(torch.tensor([]))
                    self.value_cache.append(torch.tensor([]))
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            elif (
                not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
            ):  # fills previously skipped layers; checking for tensor causes errors
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                # Subsequent decoding steps: append along the sequence axis
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        # Optionally accumulate query states smuggled in via cache_kwargs
        query_states = cache_kwargs.get("query_states") if cache_kwargs is not None else None
        if query_states is not None:
            if len(self.query_cache) <= layer_idx:
                # NOTE(review): unlike key_cache above, skipped layers are not
                # back-filled here, so query_cache entries could misalign with
                # layer_idx if layers are ever skipped — confirm callers never
                # skip layers when passing query_states.
                self.query_cache.append(query_states)
            else:
                self.query_cache[layer_idx] = torch.cat([self.query_cache[layer_idx], query_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]
    @classmethod
    def from_legacy_cache_with_query_indices(
        cls,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        query_indices=None,
    ) -> "DynamicCacheWithQuery":
        """Converts a cache in the legacy cache format into an equivalent
        `DynamicCacheWithQuery`. Used for backward compatibility.

        Parameters:
            past_key_values: legacy per-layer tuple of ``(key, value)`` tensors.
            query_indices: forwarded to the constructor; ``None`` (the new
                default, replacing a shared mutable ``[]``) yields a fresh list.
        Return:
            A `DynamicCacheWithQuery` populated layer by layer via `update`.
        """
        cache = cls(query_indices=query_indices)
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache
|