from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import DynamicCache, GenerationConfig

UNSUPPORTED_GENERATION_ARGS = [
    "cache_implementation",
    "cache_config",
    "return_legacy_cache",
    "num_beams",
    "compile_config",
    "assistant_model",
]

class LagKVCache(DynamicCache):
    """
    A KV cache compression algorithm, as described in the [LagKV paper](https://arxiv.org/abs/2504.04704).
    Like `SinkCache`, the algorithm keeps attention-sink tokens and a sliding window, but it additionally
    retains selected tokens from the middle of the sequence. It allows the model to generate with less
    memory and faster decoding. Under compression, the model retains most of its information-retrieval
    capability, in contrast to the complete loss incurred by `SinkCache`.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    For chunked prefilling, see https://github.com/AI-Lab-China-Merchants-Bank/LagKV.

    Parameters:
        _distributed_cache_data:
            Inherited from `DynamicCache`.
        ratio (`float`):
            The retain ratio of tokens in the middle chunks.
        sink_size (`int`):
            The number of sink tokens.
        lag_size (`int`):
            The size of each partition. The subsequent partition serves as the reference for the prior one.
        score_v_ratio (`float`):
            The ratio multiplied to the score of the Value states.
        skip_layer_idx (`Optional[List[int]]`):
            A list of layer indices that skip the compression.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

    >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

    >>> # Prepare a cache class (defined in this module) and pass it to the model's forward
    >>> past_key_values = LagKVCache(ratio=0.25, lag_size=128)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> outputs.past_key_values # access cache filled with key/values from generation
    LagKVCache()
    ```
    """

    def __init__(
        self,
        _distributed_cache_data=None,
        ratio: float = 0.25,
        sink_size: int = 16,
        lag_size: int = 1024,
        score_v_ratio: float = 1.0,
        skip_layer_idx: Optional[List[int]] = None,
    ):
        super().__init__(_distributed_cache_data)
        self.ratio: float = ratio
        self.sink_size: int = sink_size
        self.lag_size: int = lag_size
        self.score_v_ratio: float = score_v_ratio
        self.skip_layer_idx: List[int] = skip_layer_idx if skip_layer_idx is not None else []
        # Per-layer length of the already-compressed prefix (sink tokens + kept middle tokens)
        self._compressed_len: List[int] = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments for the cache subclass. No additional arguments are used in `LagKVCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # There may be skipped layers, fill them with empty lists
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append([])
                    self.value_cache.append([])
                    self._compressed_len.append(self.sink_size)
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
                self._compressed_len.append(self.sink_size)
            elif len(self.key_cache[layer_idx]) == 0:
                # A previously skipped layer holds an empty list; replace it with the real states
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        # Compress after the states are appended, unless this layer opts out
        if layer_idx not in self.skip_layer_idx:
            return self._compress_kv_by_lag(layer_idx)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def _get_states_score(self, base_len, in_size, end_idx, value):
        """Partition the states, then calculate the score of each token within its partition."""
        # Middle span only: [batch, heads, n_partitions * lag_size, head_dim]
        target_v = value[:, :, base_len:end_idx]
        # Reshape into partitions: [batch, heads, n_partitions, lag_size, head_dim]
        target_v = target_v.view(in_size[0], in_size[1], -1, self.lag_size, in_size[-1])
        # Each partition is scored against the partition lagging behind it (its successor)
        ref = target_v[:, :, 1:, :, :]
        v = target_v[:, :, :-1, :, :]

        # Min/max over the token axis of the reference partition, broadcast back to lag_size
        min_r = ref.min(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
        max_r = ref.max(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)

        # Min-max normalize, take the std over head_dim, then softmax over the tokens of each partition
        score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1)

        return score

    def _modify_kv(self, value, base_len, end_idx, selected_idx, tail_len):
        """Keep the compressed prefix, the selected middle tokens, and the tail; drop the rest."""
        selected_value = torch.gather(value[:, :, base_len:end_idx], -2, selected_idx)
        value = torch.cat((value[:, :, :base_len], selected_value, value[:, :, -tail_len:]), dim=-2)
        return value

    def _compress_algo(self, layer_idx, base_len):
        """
        Calculate the scores of the KV tokens in each head and partition. See the paper for details.
        Partitioning significantly reduces the computation overhead of the top-k selection.
        """
        in_size = self.key_cache[layer_idx].size()
        # Only full partitions are scored; the remainder stays in the tail
        end_idx = base_len + ((in_size[-2] - base_len) // self.lag_size) * self.lag_size

        key_score = self._get_states_score(base_len, in_size, end_idx, self.key_cache[layer_idx])
        value_score = self._get_states_score(base_len, in_size, end_idx, self.value_cache[layer_idx])
        score = key_score + value_score * self.score_v_ratio

        # Keep the top `ratio * lag_size` tokens of each scored partition, shifting the
        # per-partition indices so they address the flattened middle span
        selected_idx = torch.topk(score, int(self.ratio * self.lag_size), dim=-1).indices
        for i in range(1, selected_idx.size(2)):
            selected_idx[:, :, i] += i * self.lag_size
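        # Example with hypothetical numbers: with lag_size=4 and two scored partitions,
        # per-partition top-k indices [[1, 3], [0, 2]] become [[1, 3], [4, 6]], i.e.
        # direct indices into the flattened span value[:, :, base_len:end_idx]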
        selected_idx = selected_idx.reshape(in_size[0], in_size[1], -1).unsqueeze(-1).expand(-1, -1, -1, in_size[-1])
        new_base_len = base_len + selected_idx.size(-2)

        # The tail keeps the last (reference) partition plus any leftover tokens past `end_idx`
        tail_len = self.lag_size + in_size[-2] - end_idx
        self.key_cache[layer_idx] = self._modify_kv(
            self.key_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
        )
        self.value_cache[layer_idx] = self._modify_kv(
            self.value_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
        )
        self._compressed_len[layer_idx] = new_base_len

    def _compress_kv_by_lag(self, layer_idx):
        """Return the current KV states, then compress them in place for subsequent steps."""
        kv_size = self.key_cache[layer_idx].size()
        base_len = self._compressed_len[layer_idx]

        # Compression only triggers once there is at least one full partition beyond the
        # compressed prefix plus a full reference partition behind it
        keys_to_return, values_to_return = self.key_cache[layer_idx], self.value_cache[layer_idx]
        if kv_size[-2] >= base_len + 2 * self.lag_size:
            self._compress_algo(layer_idx, base_len)
        return keys_to_return, values_to_return
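

# A minimal, self-contained sketch of the lag-based scoring above on toy tensors. The
# function name, shapes, and sizes here are illustrative assumptions only; the real
# logic lives in `LagKVCache._get_states_score` and `LagKVCache._compress_algo`.
def _demo_lag_scoring():
    batch, heads, head_dim, lag_size, ratio = 1, 2, 8, 4, 0.25
    # Three full partitions of `lag_size` tokens each
    states = torch.randn(batch, heads, 3 * lag_size, head_dim)
    chunks = states.view(batch, heads, 3, lag_size, head_dim)
    v, ref = chunks[:, :, :-1], chunks[:, :, 1:]  # each partition vs. its successor
    min_r = ref.min(dim=-2, keepdim=True).values
    max_r = ref.max(dim=-2, keepdim=True).values
    # Normalize by the reference range, take the std over head_dim, softmax over tokens
    score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1)
    # Keep the `ratio * lag_size` highest-scoring tokens per scored partition
    keep = torch.topk(score, int(ratio * lag_size), dim=-1).indices
    print("score shape:", tuple(score.shape))  # (1, 2, 2, 4)
    print("kept indices shape:", tuple(keep.shape))  # (1, 2, 2, 1)
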
def generate(model, lag_ratio=0.5, lag_sink_size=16, lag_size=128, **kwargs):
    """Custom generate function for `LagKVCache`.
    (template from https://huggingface.co/transformers-community/sink_cache)

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        lag_ratio (`float`):
            The retain ratio of tokens in the middle chunks.
        lag_sink_size (`int`):
            The number of sink tokens.
        lag_size (`int`):
            The size of the partition. See the original paper for more information.
    """
    # 1. General sanity checks
    # 1.a. A few arguments are not allowed, especially arguments that control caches
    generation_config = kwargs.get("generation_config")
    default_global_generation_config = GenerationConfig()
    default_model_generation_config = model.generation_config
    for arg in UNSUPPORTED_GENERATION_ARGS:
        has_custom_gen_config_arg = (
            generation_config is not None
            # the arg only counts as custom if it differs from both defaults
            and not (
                getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
                or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
            )
        )
        kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
        if kwargs_has_arg or has_custom_gen_config_arg:
            raise ValueError(
                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
            )

    # 1.b. The model must be decoder-only
    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # 1.c. Pop `custom_generate` so the inner `model.generate` call below doesn't dispatch
    # back into this function
    kwargs.pop("custom_generate", None)

    # 2. Generate with the custom cache
    # 2.a. Prepare the cache, if it was not passed
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = LagKVCache(ratio=lag_ratio, sink_size=lag_sink_size, lag_size=lag_size)
    elif not isinstance(past_key_values, LagKVCache):
        raise ValueError(f"`past_key_values` must be a `LagKVCache` instance, got a {type(past_key_values)} instance")

    # 2.b. Run generation with the prepared cache
    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
    return generation_outputs
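

# A minimal usage sketch, not part of the library code. The model name is taken from the
# docstring example above; the generation settings are illustrative assumptions.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

    inputs = tokenizer("My name is Qwen2", return_tensors="pt")
    outputs = generate(model, lag_ratio=0.5, lag_sink_size=16, lag_size=128, max_new_tokens=32, **inputs)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))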