Delete folder custom_generate/.ipynb_checkpoints with huggingface_hub
custom_generate/.ipynb_checkpoints/generate-checkpoint.py
DELETED
@@ -1,245 +0,0 @@
# Copyright 2025 China Merchants Bank. All rights reserved.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://mit-license.org
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers import GenerationConfig  # needed by `generate` below
from transformers.cache_utils import DynamicCache
from typing import Any, Dict, List, Optional, Tuple


class LagKVCache(DynamicCache):
    """
    The KV compression algorithm described in the [LagKV paper](https://arxiv.org/abs/2504.04704).
    Like SinkCache, it combines attention sinks with a sliding window, but it additionally keeps selected
    tokens from the middle of the sequence. This lets the model generate with less memory and faster decoding.
    The model retains most of its information-retrieval capability under compression, whereas SinkCache
    loses it completely.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    For chunked prefilling, see https://github.com/AI-Lab-China-Merchants-Bank/LagKV.

    Parameters:
        _distributed_cache_data:
            Inherited from DynamicCache.
        ratio (`float`):
            The retention ratio of tokens in the middle chunks.
        sink_size (`int`):
            The number of sink tokens.
        lag_size (`int`):
            The size of each partition. Each subsequent partition serves as a reference for the prior one.
        score_v_ratio (`float`):
            The factor applied to the scores of the Value states.
        skip_layer_idx (`Optional[List[int]]`):
            A list of layer indices for which compression is skipped.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

    >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> past_key_values = LagKVCache(ratio=0.25, lag_size=128)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> outputs.past_key_values # access cache filled with key/values from generation
    LagKVCache()
    ```
    """
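    # Layout of one layer's cache after compression (descriptive comment, not
    # from the original file):
    #   [ compressed prefix (sinks + previously selected tokens) | uncompressed tail ]
    # `_compressed_len[layer_idx]` marks the end of the compressed prefix;
    # compression only ever runs on tokens after this boundary.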
    def __init__(
        self,
        _distributed_cache_data=None,
        ratio: float = 0.25,
        sink_size: int = 16,
        lag_size: int = 1024,
        score_v_ratio: float = 1.0,
        skip_layer_idx: Optional[List[int]] = None,
    ):
        super().__init__(_distributed_cache_data)
        self.ratio = ratio
        self.sink_size: int = sink_size
        self.lag_size: int = lag_size
        self.score_v_ratio: float = score_v_ratio
        self.skip_layer_idx: List[int] = skip_layer_idx if skip_layer_idx is not None else []
        self._compressed_len: List[int] = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs=None,
    ):
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # There may be skipped layers, fill them with empty lists
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append([])
                    self.value_cache.append([])
                    self._compressed_len.append(self.sink_size)
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
                self._compressed_len.append(self.sink_size)
            elif (
                len(self.key_cache[layer_idx]) == 0
            ):  # fills previously skipped layers; checking for the tensor directly causes errors
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        if layer_idx not in self.skip_layer_idx:
            return self._compress_kv_by_lag(layer_idx)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def _get_states_score(self, base_len, in_size, end_idx, value):
        """Partition the states, then calculate the state scores."""
        # [batch_size, num_heads, seq_len, head_dim]
        target_v = value[:, :, base_len:end_idx]
        # [batch_size, num_heads, partition_num, lag_size, head_dim]
        target_v = target_v.view(in_size[0], in_size[1], -1, self.lag_size, in_size[-1])
        ref = target_v[:, :, 1:, :, :]
        v = target_v[:, :, :-1, :, :]

        min_r = ref.min(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
        max_r = ref.max(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)

        score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1)

        return score

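    # Shape walk-through for `_get_states_score` (illustrative values, not
    # from the original file): with batch_size=1, num_heads=2, head_dim=8,
    # lag_size=4, base_len=16 and end_idx=40, `value[:, :, 16:40]` has shape
    # [1, 2, 24, 8] and is viewed as [1, 2, 6, 4, 8] (6 partitions). Each
    # partition i in 0..4 is min-max normalized against the min/max of the
    # *next* partition i+1 (the lagged reference), then scored by the standard
    # deviation over head_dim followed by a softmax over the lag_size axis,
    # giving a score tensor of shape [1, 2, 5, 4].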
    def _modify_kv(self, value, base_len, end_idx, selected_idx, tail_len):
        # idx is offset by base_len
        selected_value = torch.gather(value[:, :, base_len:end_idx], -2, selected_idx)
        value = torch.cat((value[:, :, :base_len], selected_value, value[:, :, -tail_len:]), dim=-2)
        return value

    def _compress_algo(self, layer_idx, base_len):
        """
        Calculate the scores of KV tokens in each head and partition. See the paper.
        The computational overhead of top-k is significantly reduced by partitioning.
        """
        in_size = self.key_cache[layer_idx].size()
        end_idx = base_len + ((in_size[-2] - base_len) // self.lag_size) * self.lag_size
        # [batch_size, num_heads, partition_num - 1, lag_size]
        key_score = self._get_states_score(base_len, in_size, end_idx, self.key_cache[layer_idx])
        value_score = self._get_states_score(base_len, in_size, end_idx, self.value_cache[layer_idx])
        score = key_score + value_score * self.score_v_ratio
        # the indices may need to be sorted for some use cases
        selected_idx = torch.topk(score, int(self.ratio * self.lag_size), dim=-1).indices
        for i in range(1, selected_idx.size()[2], 1):
            selected_idx[:, :, i] += i * self.lag_size
        selected_idx = selected_idx.reshape(in_size[0], in_size[1], -1).unsqueeze(-1).expand(-1, -1, -1, in_size[-1])
        new_base_len = base_len + selected_idx.size()[-2]
        # always keep the last window
        tail_len = self.lag_size + in_size[-2] - end_idx
        self.key_cache[layer_idx] = self._modify_kv(
            self.key_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
        )
        self.value_cache[layer_idx] = self._modify_kv(
            self.value_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
        )
        self._compressed_len[layer_idx] = new_base_len

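    # Index walk-through for `_compress_algo` (illustrative values): with
    # lag_size=4, ratio=0.5 and 6 partitions in [base_len, end_idx), `score`
    # has shape [B, H, 5, 4] and top-k keeps k=2 positions per partition.
    # Partition i's local indices (0..3) are shifted by i * lag_size so they
    # address positions relative to base_len, ready for `torch.gather`. The
    # last partition is never scored (it only serves as a reference) and is
    # kept whole inside `tail_len`, together with any remainder beyond end_idx.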
    def _compress_kv_by_lag(self, layer_idx):
        """The KV cache is used first, then compressed."""
        kv_size = self.key_cache[layer_idx].size()
        base_len = self._compressed_len[layer_idx]

        keys_to_return, values_to_return = self.key_cache[layer_idx], self.value_cache[layer_idx]
        if kv_size[-2] >= base_len + 2 * self.lag_size:
            self._compress_algo(layer_idx, base_len)
        return keys_to_return, values_to_return


# `generate` arguments that are not supported here because they would conflict
# with the externally managed cache. NOTE: this list is an assumption, adapted
# from the sink_cache template referenced below; the original definition was
# not part of this checkpoint file.
UNSUPPORTED_GENERATION_ARGS = [
    "cache_implementation",
    "cache_config",
    "return_legacy_cache",
]


def generate(model, lag_ratio=0.5, lag_sink_size=16, lag_size=128, **kwargs):
    """Custom generate function for LagKVCache.
    (template from https://huggingface.co/transformers-community/sink_cache)

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        lag_ratio (`float`):
            The retention ratio of tokens in the middle chunks.
        lag_sink_size (`int`):
            The number of sink tokens.
        lag_size (`int`):
            The size of each partition. See the original paper for more information.
    """
    # 1. General sanity checks
    # 1.a. A few arguments are not allowed, especially arguments that control caches.
    generation_config = kwargs.get("generation_config")
    default_global_generation_config = GenerationConfig()
    default_model_generation_config = model.generation_config
    for arg in UNSUPPORTED_GENERATION_ARGS:
        has_custom_gen_config_arg = (
            generation_config is not None
            # i.e. matches neither the global default nor the model-specific default
            and not (
                getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
                or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
            )
        )
        kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
        if kwargs_has_arg or has_custom_gen_config_arg:
            raise ValueError(
                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
            )

    # 1.b. The model must be decoder-only
    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # 1.c. Compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
    # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
    kwargs.pop("custom_generate", None)

    # 2. Generate with LagKVCache
    # 2.a. Prepare the cache, if it was not passed.
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = LagKVCache(ratio=lag_ratio, sink_size=lag_sink_size, lag_size=lag_size)
    elif not isinstance(past_key_values, LagKVCache):
        raise ValueError(f"`past_key_values` must be a `LagKVCache` instance, got a {type(past_key_values)} instance")

    # 2.b. Generate with the cache
    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
    return generation_outputs
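For reference, a minimal end-to-end sketch of calling this custom `generate` function through the Hub-based `custom_generate` mechanism of `transformers` (4.53 or newer), which the comments above allude to. The repo id passed to `custom_generate` is a placeholder for wherever this file is hosted, and the sampling arguments are illustrative:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
inputs = tokenizer("My name is Qwen2", return_tensors="pt")

# Route generation through the `generate` function defined above; the repo id
# below is a placeholder, not a confirmed location of this file.
outputs = model.generate(
    **inputs,
    custom_generate="<namespace>/<repo-with-this-file>",
    trust_remote_code=True,
    lag_ratio=0.25,
    lag_sink_size=16,
    lag_size=128,
    max_new_tokens=32,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```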