|
|
from typing import Any, Dict, Optional, Tuple |
|
|
from transformers.cache_utils import DynamicCache |
|
|
import torch |
|
|
|
|
|
|
|
|
class DynamicCacheWithQuery(DynamicCache):
    """
    Cache class used for QRRetriever.

    Query states are passed in through ``cache_kwargs["query_states"]`` so that
    :meth:`update` keeps the exact same signature as ``DynamicCache.update``;
    they are accumulated per layer in ``self.query_cache``.
    """

    def __init__(self, query_indices: Optional[list] = None) -> None:
        super().__init__()
        # Use None as the default and materialize a fresh list per instance:
        # the original `query_indices=[]` mutable default would be shared
        # across every instance constructed without an explicit argument.
        self._query_indices = [] if query_indices is None else query_indices
        # Per-layer list of cached query-state tensors, filled by update().
        self.query_cache = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. This subclass
                reads the optional ``"query_states"`` entry and, when present,
                accumulates it into ``self.query_cache`` for `layer_idx`.

        Return:
            A tuple containing the updated key and value states.
        """
        # The seen-token counter lives on layer 0 only, so it is incremented
        # exactly once per forward pass rather than once per layer.
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # Pad any skipped layer slots with empty tensors so that
                # indexing by layer_idx stays valid.
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append(torch.tensor([]))
                    self.value_cache.append(torch.tensor([]))
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            elif not self.key_cache[layer_idx].numel():
                # The slot exists but only holds an empty placeholder: fill it.
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                # Normal decoding step: append along the sequence dimension.
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        query_states = cache_kwargs.get("query_states", None) if cache_kwargs is not None else None
        if query_states is not None:
            if len(self.query_cache) <= layer_idx:
                # NOTE(review): unlike the key/value branch above, skipped
                # layers are NOT padded here, so this assumes query states
                # arrive in strictly increasing layer order — confirm callers.
                self.query_cache.append(query_states)
            else:
                self.query_cache[layer_idx] = torch.cat([self.query_cache[layer_idx], query_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    @classmethod
    def from_legacy_cache_with_query_indices(
        cls,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        query_indices: Optional[list] = None,
    ) -> "DynamicCacheWithQuery":
        """Converts a cache in the legacy cache format into an equivalent
        `DynamicCacheWithQuery`. Used for backward compatibility.

        Parameters:
            past_key_values:
                Legacy per-layer ``(key_states, value_states)`` tuples, or
                ``None`` for an empty cache.
            query_indices:
                Optional query indices forwarded to the constructor; defaults
                to an empty list (default changed from a shared mutable `[]`
                to `None` — behavior for callers is unchanged).
        """
        cache = cls(query_indices=[] if query_indices is None else query_indices)
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache
|
|
|