from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import DynamicCache, GenerationConfig
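
# NOTE: `UNSUPPORTED_GENERATION_ARGS` is referenced by `generate` below but was absent from
# this listing; the entries here are reconstructed from the sink_cache template cited in
# `generate`'s docstring and should be treated as an assumption.
UNSUPPORTED_GENERATION_ARGS = [
    "cache_implementation",  # the cache is always a `LagKVCache` in this function
    "cache_config",
    "return_legacy_cache",  # `LagKVCache` has no legacy-format equivalent
]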


class LagKVCache(DynamicCache):
    """
    A KV cache compression algorithm as described in the [LagKV paper](https://arxiv.org/abs/2504.04704).
    Like SinkCache, it keeps attention sink tokens and a sliding window, but it additionally retains selected
    tokens from the middle of the sequence. This lets the model generate with less memory and decode faster,
    while preserving most of its information-retrieval capability, in contrast to the complete loss incurred
    by SinkCache.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each
    tensor is `[batch_size, num_heads, seq_len, head_dim]`.

    For chunked prefilling, see https://github.com/AI-Lab-China-Merchants-Bank/LagKV.

    Parameters:
        _distributed_cache_data:
            Inherited from `DynamicCache`.
        ratio (`float`):
            The ratio of tokens to retain in the middle chunks.
        sink_size (`int`):
            The number of sink tokens.
        lag_size (`int`):
            The size of each partition. The subsequent partition serves as a reference for the prior one.
        score_v_ratio (`float`):
            The ratio by which the score of the Value states is multiplied.
        skip_layer_idx (`Optional[List[int]]`):
            A list of layer indices for which compression is skipped.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, LagKVCache

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = LagKVCache(ratio=0.25, lag_size=128)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values  # access cache filled with key/values from generation
        LagKVCache()
        ```
    """

    def __init__(
        self,
        _distributed_cache_data=None,
        ratio: float = 0.25,
        sink_size: int = 16,
        lag_size: int = 1024,
        score_v_ratio: float = 1.0,
        skip_layer_idx: Optional[List[int]] = None,
    ):
        super().__init__(_distributed_cache_data)
        self.ratio: float = ratio
        self.sink_size: int = sink_size
        self.lag_size: int = lag_size
        self.score_v_ratio: float = score_v_ratio
        self.skip_layer_idx: List[int] = skip_layer_idx if skip_layer_idx is not None else []
        self._compressed_len: List[int] = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments for the cache subclass. No additional arguments are used in `LagKVCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # The number of seen tokens is tracked once per forward pass, on the first layer
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # Backfill placeholders for any skipped layers so the per-layer lists stay aligned
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append([])
                    self.value_cache.append([])
                    self._compressed_len.append(self.sink_size)
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
                self._compressed_len.append(self.sink_size)
            elif len(self.key_cache[layer_idx]) == 0:
                # This layer was skipped earlier: fill in its placeholder
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        if layer_idx not in self.skip_layer_idx:
            return self._compress_kv_by_lag(layer_idx)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def _get_states_score(self, base_len, in_size, end_idx, value):
        """Partition the states, then calculate the score of each token."""
        target_v = value[:, :, base_len:end_idx]
        # Split the middle of the sequence into partitions of `lag_size` tokens
        target_v = target_v.view(in_size[0], in_size[1], -1, self.lag_size, in_size[-1])
        # Each partition is scored against the next one, which serves as its lagging reference
        ref = target_v[:, :, 1:, :, :]
        v = target_v[:, :, :-1, :, :]
        min_r = ref.min(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
        max_r = ref.max(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
        # Min-max normalize against the reference, score each token by its standard deviation
        # over the head dimension, then softmax-normalize the scores within each partition
        score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1)
        return score

    def _modify_kv(self, value, base_len, end_idx, selected_idx, tail_len):
        # Keep the already-compressed prefix, the selected middle tokens, and the uncompressed tail
        selected_value = torch.gather(value[:, :, base_len:end_idx], -2, selected_idx)
        value = torch.cat((value[:, :, :base_len], selected_value, value[:, :, -tail_len:]), dim=-2)
        return value

    def _compress_algo(self, layer_idx, base_len):
        """
        Calculate the scores of the KV tokens in each head and partition. See the paper for details.
        Partitioning significantly reduces the computational overhead of the top-k selection.
        """
        in_size = self.key_cache[layer_idx].size()
        # Only whole partitions are compressed; the remainder stays in the uncompressed tail
        end_idx = base_len + ((in_size[-2] - base_len) // self.lag_size) * self.lag_size

        key_score = self._get_states_score(base_len, in_size, end_idx, self.key_cache[layer_idx])
        value_score = self._get_states_score(base_len, in_size, end_idx, self.value_cache[layer_idx])
        score = key_score + value_score * self.score_v_ratio

        # Keep the top `ratio * lag_size` tokens of each partition, per head
        selected_idx = torch.topk(score, int(self.ratio * self.lag_size), dim=-1).indices
        # Shift the per-partition indices to positions relative to `base_len`
        for i in range(1, selected_idx.size()[2]):
            selected_idx[:, :, i] += i * self.lag_size
        selected_idx = selected_idx.reshape(in_size[0], in_size[1], -1).unsqueeze(-1).expand(-1, -1, -1, in_size[-1])
        new_base_len = base_len + selected_idx.size()[-2]
        # The last partition has no reference yet, so it is kept whole as part of the tail
        tail_len = self.lag_size + in_size[-2] - end_idx
        self.key_cache[layer_idx] = self._modify_kv(
            self.key_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
        )
        self.value_cache[layer_idx] = self._modify_kv(
            self.value_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
        )
        self._compressed_len[layer_idx] = new_base_len

    def _compress_kv_by_lag(self, layer_idx):
        """Return the current KV states for this forward pass, then compress them in place."""
        kv_size = self.key_cache[layer_idx].size()
        base_len = self._compressed_len[layer_idx]

        # Compression only triggers once at least two full partitions have accumulated past
        # the compressed prefix: one to be compressed and one to serve as its reference
        keys_to_return, values_to_return = self.key_cache[layer_idx], self.value_cache[layer_idx]
        if kv_size[-2] >= base_len + 2 * self.lag_size:
            self._compress_algo(layer_idx, base_len)
        return keys_to_return, values_to_return


def generate(model, lag_ratio=0.5, lag_sink_size=16, lag_size=128, **kwargs):
    """Custom generate function for LagKVCache.
    (template from https://huggingface.co/transformers-community/sink_cache)

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        lag_ratio (`float`):
            The ratio of tokens to retain in the middle chunks.
        lag_sink_size (`int`):
            The number of sink tokens.
        lag_size (`int`):
            The size of each partition. See the original paper for more information.
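
    Example (a minimal sketch; it reuses the toy checkpoint from the class docstring above,
    and `max_new_tokens=20` is an arbitrary choice):

        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> inputs = tokenizer("My name is Qwen2", return_tensors="pt")
        >>> output_ids = generate(model, lag_ratio=0.25, lag_size=128, **inputs, max_new_tokens=20)
        >>> tokenizer.batch_decode(output_ids, skip_special_tokens=True)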
| """ |
| |
| |
| generation_config = kwargs.get("generation_config") |
| default_global_generation_config = GenerationConfig() |
| default_model_generation_config = model.generation_config |
| for arg in UNSUPPORTED_GENERATION_ARGS: |
| has_custom_gen_config_arg = ( |
| generation_config is not None |
| |
| and not ( |
| getattr(default_model_generation_config, arg) == getattr(generation_config, arg) |
| or getattr(default_global_generation_config, arg) == getattr(generation_config, arg) |
| ) |
| ) |
| kwargs_has_arg = arg in kwargs and kwargs[arg] is not None |
| if kwargs_has_arg or has_custom_gen_config_arg: |
| raise ValueError( |
| f"`{arg}` is set, but it's not supported in this custom generate function. List of " |
| f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}" |
| ) |

    # This cache is only compatible with decoder-only models
    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # Pop `custom_generate` so the `model.generate` call below doesn't recurse into this function
    kwargs.pop("custom_generate", None)

    # Initialize the LagKV cache, if the user didn't pass one in
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = LagKVCache(ratio=lag_ratio, sink_size=lag_sink_size, lag_size=lag_size)
    elif not isinstance(past_key_values, LagKVCache):
        raise ValueError(f"`past_key_values` must be a `LagKVCache` instance, got a {type(past_key_values)} instance")

    # Run the default decoding loop with the LagKV cache
    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
    return generation_outputs