kyleliang commited on
Commit
603586d
·
verified ·
1 Parent(s): 820a56d

Delete folder custom_generate/.ipynb_checkpoints with huggingface_hub

Browse files
custom_generate/.ipynb_checkpoints/generate-checkpoint.py DELETED
@@ -1,245 +0,0 @@
1
- # Copyright 2025 China Merchants Bank. All rights reserved.
2
- #
3
- # Licensed under the MIT License (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://mit-license.org
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import torch
16
- from transformers.cache_utils import DynamicCache
17
- from typing import Any, Dict, List, Optional, Tuple
18
-
19
-
20
- class LagKVCache(DynamicCache):
21
- """
22
- A KV compression algorithm that as described in the [LagKV paper](https://arxiv.org/abs/2504.04704).
23
- The algorithm equips Sink Attention and SlidingWindow like SinkCache but with additional selective tokens in the middle.
24
- It allows the model to generate with fewer memory resource and faster decoding speed.
25
- The model will hold the main part of information retrieval capbility during the compression, compared to a completed loss
26
- of the SinkCache.
27
-
28
- It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
29
- `[batch_size, num_heads, seq_len, head_dim]`.
30
-
31
- For the chunked prefilling, see https://github.com/AI-Lab-China-Merchants-Bank/LagKV.
32
-
33
- Parameters:
34
- _distributed_cache_data:
35
- Inherited from DynamicCache.
36
- ratio (`float`):
37
- The retrain ratio of tokens in the middle chunks.
38
- sink_size (`int`):
39
- The number of sink tokens.
40
- lag_size (`int`):
41
- The size of the partition. The subsequent partion will serve as a reference for the prior one.
42
- score_v_ratio (`float`):
43
- The ratio multiplied to the score of Value states.
44
- skip_layer_idx (`Optional[List[int]]`):
45
- A list of layer indices will skip the compression.
46
-
47
- Example:
48
-
49
- ```python
50
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, LagKVCache
51
-
52
- >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
53
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
54
-
55
- >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
56
-
57
- >>> # Prepare a cache class and pass it to model's forward
58
- >>> past_key_values = LagKVCache(ratio=0.25, lag_size=128)
59
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
60
- >>> outputs.past_key_values # access cache filled with key/values from generation
61
- LagKVCache()
62
- ```
63
- """
64
-
65
- def __init__(
66
- self,
67
- _distributed_cache_data=None,
68
- ratio: float = 0.25,
69
- sink_size: int = 16,
70
- lag_size: int = 1024,
71
- score_v_ratio: float = 1.0,
72
- skip_layer_idx: Optional[List[int]] = None,
73
- ):
74
- super().__init__(_distributed_cache_data)
75
- self.ratio = ratio
76
- self.sink_size: int = sink_size
77
- self.lag_size: int = lag_size
78
- self.score_v_ratio: float = score_v_ratio
79
- self.skip_layer_idx: List[int] = skip_layer_idx if skip_layer_idx is not None else []
80
- self._compressed_len: List[int] = []
81
-
82
- def update(
83
- self,
84
- key_states: torch.Tensor,
85
- value_states: torch.Tensor,
86
- layer_idx: int,
87
- cache_kwargs=None,
88
- ):
89
- """
90
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
91
-
92
- Parameters:
93
- key_states (`torch.Tensor`):
94
- The new key states to cache.
95
- value_states (`torch.Tensor`):
96
- The new value states to cache.
97
- layer_idx (`int`):
98
- The index of the layer to cache the states for.
99
- cache_kwargs (`Dict[str, Any]`, `optional`):
100
- Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
101
-
102
- Return:
103
- A tuple containing the updated key and value states.
104
- """
105
- # Update the number of seen tokens
106
- if layer_idx == 0:
107
- self._seen_tokens += key_states.shape[-2]
108
-
109
- # Update the cache
110
- if key_states is not None:
111
- if len(self.key_cache) <= layer_idx:
112
- # There may be skipped layers, fill them with empty lists
113
- for _ in range(len(self.key_cache), layer_idx):
114
- self.key_cache.append([])
115
- self.value_cache.append([])
116
- self._compressed_len.append(self.sink_size)
117
- self.key_cache.append(key_states)
118
- self.value_cache.append(value_states)
119
- self._compressed_len.append(self.sink_size)
120
- elif (
121
- len(self.key_cache[layer_idx]) == 0
122
- ): # fills previously skipped layers; checking for tensor causes errors
123
- self.key_cache[layer_idx] = key_states
124
- self.value_cache[layer_idx] = value_states
125
- else:
126
- self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
127
- self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
128
-
129
- if layer_idx not in self.skip_layer_idx:
130
- return self._compress_kv_by_lag(layer_idx)
131
-
132
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
133
-
134
- def _get_states_score(self, base_len, in_size, end_idx, value):
135
- """Partition the states then calculate the state scores"""
136
- # [batch_size, num_heads, seq_len, head_dim]
137
- target_v = value[:, :, base_len:end_idx]
138
- # [batch_size, num_heads, partition_num, lag_size, head_dim]
139
- target_v = target_v.view(in_size[0], in_size[1], -1, self.lag_size, in_size[-1])
140
- ref = target_v[:, :, 1:, :, :]
141
- v = target_v[:, :, :-1, :, :]
142
-
143
- min_r = ref.min(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
144
- max_r = ref.max(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
145
-
146
- score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1)
147
-
148
- return score
149
-
150
- def _modify_kv(self, value, base_len, end_idx, selected_idx, tail_len):
151
- # idx is offset by base_len
152
- selected_value = torch.gather(value[:, :, base_len:end_idx], -2, selected_idx)
153
- value = torch.cat((value[:, :, :base_len], selected_value, value[:, :, -tail_len:]), dim=-2)
154
- return value
155
-
156
- def _compress_algo(self, layer_idx, base_len):
157
- """
158
- Calculate the scores of KV tokens in each head and partition. See the paper.
159
- The computation overhead of top-k is significantly reduced by partitioning.
160
- """
161
- in_size = self.key_cache[layer_idx].size()
162
- end_idx = base_len + ((in_size[-2] - base_len) // self.lag_size) * self.lag_size
163
- # [batch_size, num_heads, partition_num - 1, lag_size, head_dim]
164
- key_score = self._get_states_score(base_len, in_size, end_idx, self.key_cache[layer_idx])
165
- value_score = self._get_states_score(base_len, in_size, end_idx, self.value_cache[layer_idx])
166
- score = key_score + value_score * self.score_v_ratio
167
- # you may need to sort the index for some cases
168
- selected_idx = torch.topk(score, int(self.ratio * self.lag_size), dim=-1).indices
169
- for i in range(1, selected_idx.size()[2], 1):
170
- selected_idx[:, :, i] += i * self.lag_size
171
- selected_idx = selected_idx.reshape(in_size[0], in_size[1], -1).unsqueeze(-1).expand(-1, -1, -1, in_size[-1])
172
- new_base_len = base_len + selected_idx.size()[-2]
173
- # alwarys keep the last window
174
- tail_len = self.lag_size + in_size[-2] - end_idx
175
- self.key_cache[layer_idx] = self._modify_kv(
176
- self.key_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
177
- )
178
- self.value_cache[layer_idx] = self._modify_kv(
179
- self.value_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
180
- )
181
- self._compressed_len[layer_idx] = new_base_len
182
-
183
- def _compress_kv_by_lag(self, layer_idx):
184
- """the KV cache will be used then compressed"""
185
- kv_size = self.key_cache[layer_idx].size()
186
- base_len = self._compressed_len[layer_idx]
187
-
188
- keys_to_return, values_to_return = self.key_cache[layer_idx], self.value_cache[layer_idx]
189
- if kv_size[-2] >= base_len + 2 * self.lag_size:
190
- self._compress_algo(layer_idx, base_len)
191
- return keys_to_return, values_to_return
192
-
193
- def generate(model, lag_ratio=0.5, lag_sink_size=16, lag_size=128, **kwargs):
194
- """Custom generate function for LagKVCache.
195
- (template from https://huggingface.co/transformers-community/sink_cache)
196
- Args:
197
- model (`PreTrainedModel`):
198
- The model to generate from.
199
- lag_ratio (`float`):
200
- The retrain ratio of tokens in the middle chunks.
201
- lag_sink_size (`int`):
202
- The number of sink tokens.
203
- lag_size (`int`):
204
- The size of the partition. See the original paper for more information.
205
- """
206
- # 1. General sanity checks
207
- # 1.a. A few arguments are not allowed, especially arguments that control caches.
208
- generation_config = kwargs.get("generation_config")
209
- default_global_generation_config = GenerationConfig()
210
- default_model_generation_config = model.generation_config
211
- for arg in UNSUPPORTED_GENERATION_ARGS:
212
- has_custom_gen_config_arg = (
213
- generation_config is not None
214
- # = and not (match global default or match model-specific default)
215
- and not (
216
- getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
217
- or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
218
- )
219
- )
220
- kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
221
- if kwargs_has_arg or has_custom_gen_config_arg:
222
- raise ValueError(
223
- f"`{arg}` is set, but it's not supported in this custom generate function. List of "
224
- f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
225
- )
226
-
227
- # 1.b. The model must be decoder-only
228
- if model.config.is_encoder_decoder:
229
- raise ValueError("This custom generate function only works with decoder-only models")
230
-
231
- # 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
232
- # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
233
- kwargs.pop("custom_generate", None)
234
-
235
- # 2. Generate with LagKVCache
236
- # 2.a. prepare the cache, if it was not passed.
237
- past_key_values = kwargs.pop("past_key_values", None)
238
- if past_key_values is None:
239
- past_key_values = LagKVCache(ratio=lag_ratio, sink_size=lag_sink_size, lag_size=lag_size)
240
- elif not isinstance(past_key_values, LagKVCache):
241
- raise ValueError(f"`past_key_values` must be a `LagKVCache` instance, got a {type(past_key_values)} instance")
242
-
243
- # 2.b. generate with the cache
244
- generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
245
- return generation_outputs