ligongh commited on
Commit
5a96d7b
·
verified ·
1 Parent(s): 2482e1b

Upload generate.py

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +447 -0
custom_generate/generate.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import importlib.metadata
3
+ import json
4
+ import os
5
+ import warnings
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ from packaging import version
11
+
12
+ from transformers.utils import is_hqq_available, is_optimum_quanto_available, logging
13
+
14
+ from transformers.cache_utils import CacheConfig, QuantizedCacheConfig, QuantizedCache
15
+
16
+ if is_hqq_available():
17
+ from hqq.core.quantize import Quantizer as HQQQuantizer
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+ @dataclass
22
+ class SQuatCacheConfig(QuantizedCacheConfig):
23
+ """
24
+ Configuration class for SQuat cache settings.
25
+ """
26
+ def __init__(self,
27
+ quant_group_size: Optional[int] = 64,
28
+ squat_lambda: Optional[float] = 0.0001,
29
+ subspace_dim: Optional[int] = 5,
30
+ shared_svd: Optional[bool] = True,
31
+ **kwargs,
32
+ ):
33
+ super().__init__(**kwargs)
34
+ self.cache_implementation = "squat"
35
+ self.quant_group_size = quant_group_size
36
+ self.squat_lambda = squat_lambda
37
+ self.subspace_dim = subspace_dim
38
+ self.shared_svd = shared_svd
39
+
40
+
41
+ class SQuatCache(QuantizedCache):
42
+ """
43
+ Quantized Cache class that uses `SQuat` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.
44
+
45
+ Parameters:
46
+ cache_config (`SQuatCacheConfig`):
47
+ A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
48
+
49
+ Example:
50
+
51
+ ```python
52
+ >>> # Run pip install quanto first if you don't have it yet
53
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SQuatCache, SQuatCacheConfig
54
+
55
+ >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
56
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
57
+
58
+ >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
59
+
60
+ >>> # Prepare a cache class and pass it to model's forward
61
+ >>> cache_config = SQuatCacheConfig(nbits=4)
62
+ >>> past_key_values = SQuatCache(cache_config=cache_config)
63
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
64
+ >>> outputs.past_key_values # access cache filled with key/values from generation
65
+ SQuatCache()
66
+ ```
67
+ """
68
+
69
+ def __init__(self, cache_config: CacheConfig) -> None:
70
+ super().__init__(cache_config)
71
+
72
+ if is_optimum_quanto_available():
73
+ optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto"))
74
+ if optimum_quanto_version <= version.parse("0.2.5"):
75
+ raise ImportError(
76
+ f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
77
+ )
78
+ from optimum.quanto import MaxOptimizer, qint2, qint4
79
+
80
+ if self.nbits not in [2, 4]:
81
+ raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
82
+
83
+ if self.axis_key not in [0, -1]:
84
+ raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
85
+
86
+ if self.axis_value not in [0, -1]:
87
+ raise ValueError(
88
+ f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
89
+ )
90
+
91
+ self.qtype = qint4 if self.nbits == 4 else qint2
92
+ self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization
93
+
94
+ self.auxiliary_matrices_A = []
95
+ self.auxiliary_matrices_P = []
96
+ self.squat_lambda = getattr(cache_config, "squat_lambda", 0.0005)
97
+ self.squat_q_group_size = getattr(cache_config, "quant_group_size", 64)
98
+ self.squat_subspace_dim = getattr(cache_config, "subspace_dim", 20)
99
+ self.squat_shared_svd = getattr(cache_config, "shared_svd", True)
100
+
101
+ def update(
102
+ self,
103
+ key_states: torch.Tensor,
104
+ value_states: torch.Tensor,
105
+ layer_idx: int,
106
+ cache_kwargs: Optional[Dict[str, Any]] = None,
107
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
108
+ # Update the number of seen tokens
109
+ if layer_idx == 0:
110
+ self._seen_tokens += key_states.shape[-2]
111
+
112
+ if len(self.key_cache) < layer_idx:
113
+ raise ValueError("SQuatCache does not support model usage where layers are skipped. Use DynamicCache.")
114
+ elif len(self.key_cache) == layer_idx: # prefilling
115
+ if len(self.auxiliary_matrices_A) == layer_idx:
116
+ Ainv_t, P_inv = self._get_query_subspace(key_states, cache_kwargs["query_states"], cache_kwargs["attention_mask"])
117
+ self.auxiliary_matrices_A.append(Ainv_t)
118
+ self.auxiliary_matrices_P.append(P_inv)
119
+
120
+ if key_states.shape[-2] % self.residual_length != 0:
121
+ if key_states.shape[-2] < self.residual_length:
122
+ key_states_quant = None
123
+ key_states_full = key_states
124
+ value_states_quant = None
125
+ value_states_full = value_states
126
+ else:
127
+ key_states_quant = key_states[:, :, :-(key_states.shape[-2] % self.residual_length), :].contiguous()
128
+ key_states_full = key_states[:, :, -(key_states.shape[-2] % self.residual_length):, :].contiguous()
129
+ value_states_quant = value_states[:, :, :-(value_states.shape[-2] % self.residual_length), :].contiguous()
130
+ value_states_full = value_states[:, :, -(value_states.shape[-2] % self.residual_length):, :].contiguous()
131
+ else:
132
+ key_states_quant = key_states
133
+ key_states_full = None
134
+ value_states_quant = value_states
135
+ value_states_full = None
136
+ if key_states_quant is not None:
137
+ self._quantized_key_cache.append(self.squat_quantize_key(key_states_quant, self.squat_q_group_size, Ainv_t, P_inv))
138
+ self._quantized_value_cache.append(self._quantize(value_states_quant, axis=self.axis_value))
139
+ else:
140
+ self._quantized_key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
141
+ self._quantized_value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
142
+ if key_states_full is not None:
143
+ self.key_cache.append(key_states_full)
144
+ self.value_cache.append(value_states_full)
145
+ else:
146
+ self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
147
+ self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
148
+
149
+ keys_to_return, values_to_return = key_states, value_states
150
+
151
+ else: # decoding
152
+ if len(self._quantized_key_cache[layer_idx]) == 0:
153
+ dequant_key = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
154
+ else:
155
+ dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
156
+ if len(self._quantized_value_cache[layer_idx]) == 0:
157
+ dequant_value = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
158
+ else:
159
+ dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
160
+ keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
161
+ values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]
162
+
163
+ keys_to_return = torch.cat(keys_to_return, dim=-2)
164
+ values_to_return = torch.cat(values_to_return, dim=-2)
165
+ if (
166
+ self.key_cache[layer_idx].dim() == 4
167
+ and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
168
+ ):
169
+ keys_to_quantize = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
170
+ quantized_key = self.squat_quantize_key(
171
+ keys_to_quantize, self.squat_q_group_size, self.auxiliary_matrices_A[layer_idx],
172
+ self.auxiliary_matrices_P[layer_idx]
173
+ )
174
+ self._quantized_key_cache[layer_idx] = self._quantize(
175
+ torch.cat([dequant_key, self._dequantize(quantized_key)], dim=2), axis=self.axis_key
176
+ )
177
+ self._quantized_value_cache[layer_idx] = self._quantize(
178
+ values_to_return.contiguous(), axis=self.axis_value
179
+ )
180
+ self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
181
+ self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
182
+ else:
183
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
184
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
185
+
186
+ return keys_to_return, values_to_return
187
+
188
+ def _get_query_subspace(self, key_states, query_states, attention_mask=None):
189
+ bsz = query_states.shape[0]
190
+ kv_nh = key_states.shape[1]
191
+ head_dim = query_states.shape[3]
192
+ num_key_value_groups = query_states.shape[1] // key_states.shape[1]
193
+ subspace_dim = min(self.squat_subspace_dim, num_key_value_groups*key_states.shape[2])
194
+
195
+ # Get valid tokens from attention mask
196
+ if attention_mask is not None:
197
+ if attention_mask.shape[2] == attention_mask.shape[3]-1:
198
+ attention_mask = attention_mask[:,:,:,:attention_mask.shape[2]]
199
+ # Get last row of attention mask [bs, 1, seq_len]
200
+ last_row_mask = attention_mask[:, :, -1, :]
201
+ # Find valid token positions (where mask is 0)
202
+ valid_tokens = (last_row_mask == 0).squeeze(1) # [bs, seq_len]
203
+
204
+ # Only keep valid tokens for each batch
205
+ query_subspace = []
206
+ for b in range(bsz):
207
+ # Get valid tokens for this batch
208
+ batch_valid = valid_tokens[b] # [seq_len]
209
+ # Select valid tokens from query states
210
+ batch_query = query_states[b] # [kv_nh, seq_len, head_dim]
211
+ batch_valid_query = batch_query[:, batch_valid, :] # [kv_nh, valid_len, head_dim]
212
+
213
+ valid_query_states_matrix = batch_valid_query.reshape(kv_nh, -1, head_dim)
214
+ U, S, Vh = torch.linalg.svd(valid_query_states_matrix.float(), full_matrices=False)
215
+ S_subspace = torch.diag_embed(S[:, :subspace_dim]).to(valid_query_states_matrix.dtype)
216
+ Vh_subspace = Vh[:, :subspace_dim, :].to(valid_query_states_matrix.dtype)
217
+ batch_query_subspace = torch.matmul(S_subspace, Vh_subspace)
218
+
219
+ query_subspace.append(batch_query_subspace)
220
+ if self.squat_shared_svd:
221
+ break
222
+
223
+ # Stack back into tensor
224
+ query_subspace = torch.stack(query_subspace) # [bs, kv_nh, valid_len, head_dim]
225
+ else:
226
+ query_states_matrix = query_states.reshape(bsz, kv_nh, -1, head_dim)
227
+ U, S, Vh = torch.linalg.svd(query_states_matrix.float(), full_matrices=False) #!!! float here might be suboptimal
228
+ S_subspace = torch.diag_embed(S[:, :, :subspace_dim]).to(query_states_matrix.dtype)
229
+ Vh_subspace = Vh[:, :, :subspace_dim, :].to(query_states_matrix.dtype)
230
+
231
+ # dimension: [bs, nh, subspace_dim, head_dim]
232
+ query_subspace = torch.matmul(S_subspace, Vh_subspace)
233
+
234
+ if self.squat_shared_svd:
235
+ query_subspace = query_subspace[0:1, ...]
236
+
237
+ # Ainv_t is a list of matrices
238
+ Ainv_t = self._generate_At_inv(self.squat_q_group_size, query_subspace.float(), lamb=self.squat_lambda)
239
+ P_inv = torch.inverse(Ainv_t[-1])
240
+
241
+ return Ainv_t, P_inv
242
+
243
+ def _generate_At_inv(self, quant_group_size, my_Qhat, lamb=1, tol=1e-7):
244
+ """
245
+ Generate a list of T matrices where the t-th matrix has dimension (t*g, t*g).
246
+
247
+ Parameters:
248
+ - quant_group_size (int): Factor for matrix dimension scaling
249
+ - lamb (float): Scaling factor for the final term
250
+ - my_Qhat (torch.Tensor): A matrix of size (d, d)
251
+
252
+ Returns:
253
+ - List[torch.Tensor]: List of int(head_dim/quant_group_size) matrices
254
+ """
255
+
256
+ bs, kv_nh, subspace_dim, head_dim = my_Qhat.shape
257
+ T = (head_dim+quant_group_size-1)//quant_group_size
258
+ matrices = [None] * T
259
+ device = my_Qhat.device
260
+ I = torch.eye(head_dim, device=device)
261
+ # Initialize A_T
262
+ A_T = I.expand(bs, kv_nh, head_dim, head_dim) + lamb * torch.matmul(
263
+ my_Qhat.transpose(-1, -2), my_Qhat
264
+ )
265
+ matrices[T - 1] = A_T
266
+
267
+ for t in range(T - 1, 0, -1): # Recursive computation of A_{t} from A_{t+1}
268
+ current_dim = t * quant_group_size
269
+
270
+ # Extract M_{t+1}, N_{t+1}, and O_{t+1}
271
+ M_t1 = A_T[:, :, :current_dim, :current_dim] # Top-left square matrix
272
+ N_t1 = A_T[:, :, current_dim : current_dim + quant_group_size, :current_dim] # Bottom-left matrix
273
+ O_t1 = A_T[:, :, current_dim : current_dim + quant_group_size, current_dim : current_dim + quant_group_size] # Bottom-right square matrix
274
+
275
+ # Compute A_t
276
+ I_mat = torch.eye(quant_group_size, device=device)
277
+ O_t1_inv = torch.inverse(O_t1 + tol * I_mat.expand(bs, kv_nh, quant_group_size, quant_group_size))
278
+ A_t = M_t1 - torch.matmul(N_t1.transpose(-1, -2), torch.matmul(O_t1_inv, N_t1))
279
+ matrices[t - 1] = A_t[:, :, :, -quant_group_size:]
280
+
281
+ # Update A_T for the next iteration
282
+ A_T = A_t
283
+ return matrices
284
+
285
+ def squat_quantize_key(self, key_states, quant_group_size, Ainv_t, P_inv):
286
+
287
+ bsz, nh, seq_len, hidden_dim = key_states.shape
288
+ dtype = key_states.dtype
289
+ T = (hidden_dim+quant_group_size-1)//quant_group_size
290
+ key_states_dequant = []
291
+ group = key_states # Extract the group
292
+ for i in range(T):
293
+ key_states_quant_this_quant_group = self._quantize(
294
+ group[:, :, :, i * quant_group_size : (i + 1) * quant_group_size].contiguous(),
295
+ axis=self.axis_key
296
+ )
297
+ dequantized = self._dequantize(key_states_quant_this_quant_group)
298
+
299
+ if i < T - 1:
300
+ d_vec = (
301
+ dequantized
302
+ - group[:, :, :, i * quant_group_size : (i + 1) * quant_group_size]
303
+ ).float()
304
+ H_t = Ainv_t[i]
305
+ B_t = P_inv[
306
+ :, :, (i + 1) * quant_group_size :, : (i + 1) * quant_group_size
307
+ ]
308
+ update = torch.matmul(
309
+ torch.matmul(d_vec, H_t.transpose(-2, -1)), B_t.transpose(-2, -1)
310
+ )
311
+ group[:, :, :, (i + 1) * quant_group_size :] = (
312
+ group[:, :, :, (i + 1) * quant_group_size :] + update
313
+ )
314
+
315
+ key_states_dequant.append(dequantized)
316
+
317
+ key_states_dequant = torch.cat(key_states_dequant, dim=3)
318
+ key_states_quant = self._quantize(key_states_dequant, axis=self.axis_key)
319
+ return key_states_quant
320
+
321
+
322
+ class QuantoSQuatCache(SQuatCache):
323
+
324
+ def __init__(self, cache_config: CacheConfig) -> None:
325
+ super().__init__(cache_config)
326
+
327
+ if is_optimum_quanto_available():
328
+ optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto"))
329
+ if optimum_quanto_version <= version.parse("0.2.5"):
330
+ raise ImportError(
331
+ f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
332
+ )
333
+ from optimum.quanto import MaxOptimizer, qint2, qint4
334
+
335
+ if self.nbits not in [2, 4]:
336
+ raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
337
+
338
+ if self.axis_key not in [0, -1]:
339
+ raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
340
+
341
+ if self.axis_value not in [0, -1]:
342
+ raise ValueError(
343
+ f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
344
+ )
345
+
346
+ self.qtype = qint4 if self.nbits == 4 else qint2
347
+ self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization
348
+
349
+ def _quantize(self, tensor, axis):
350
+ # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore
351
+ if is_optimum_quanto_available():
352
+ from optimum.quanto import quantize_weight
353
+
354
+ scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
355
+ qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
356
+ return qtensor
357
+
358
+ def _dequantize(self, qtensor):
359
+ return qtensor.dequantize()
360
+
361
+
362
+ class HQQSQuatCache(SQuatCache):
363
+
364
+ def __init__(self, cache_config: CacheConfig) -> None:
365
+ super().__init__(cache_config)
366
+ if self.nbits not in [1, 2, 3, 4, 8]:
367
+ raise ValueError(
368
+ f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
369
+ )
370
+
371
+ if self.axis_key not in [0, 1]:
372
+ raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
373
+
374
+ if self.axis_value not in [0, 1]:
375
+ raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")
376
+
377
+ self.quantizer = HQQQuantizer
378
+
379
+ def _quantize(self, tensor, axis):
380
+ qtensor, meta = self.quantizer.quantize(
381
+ tensor,
382
+ axis=axis,
383
+ device=self.device,
384
+ compute_dtype=self.compute_dtype,
385
+ nbits=self.nbits,
386
+ group_size=self.q_group_size,
387
+ )
388
+ meta["compute_dtype"] = self.compute_dtype
389
+ self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype
390
+ meta["scale"] = meta["scale"].to(qtensor.device)
391
+ meta["zero"] = meta["zero"].to(qtensor.device)
392
+ return qtensor, meta
393
+
394
+ def _dequantize(self, qtensor):
395
+ quant_tensor, meta = qtensor
396
+ tensor = self.quantizer.dequantize(quant_tensor, meta)
397
+ return tensor
398
+
399
+
400
+ SQUAT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoSQuatCache, "HQQ": HQQSQuatCache}
401
+
402
+ def generate(model, generation_config=None, backend="quanto", nbits=2, quant_group_size=64, residual_length=32, squat_lambda=0.001, subspace_dim=20, shared_svd=True, **kwargs):
403
+ """Custom generate function for SinkCache.
404
+ Args:
405
+ model (`PreTrainedModel`):
406
+ The model to generate from.
407
+ """
408
+
409
+ cache_config = SQuatCacheConfig(
410
+ backend=backend,
411
+ nbits=nbits,
412
+ quant_group_size=quant_group_size,
413
+ residual_length=residual_length,
414
+ squat_lambda=squat_lambda,
415
+ subspace_dim=subspace_dim,
416
+ shared_svd=shared_svd,
417
+ )
418
+ cache_class = SQUAT_BACKEND_CLASSES_MAPPING[cache_config.backend]
419
+
420
+ if cache_config.backend == "quanto" and not is_optimum_quanto_available():
421
+ raise ImportError(
422
+ "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
423
+ "Please install it via with `pip install optimum-quanto`"
424
+ )
425
+ elif cache_config.backend == "HQQ" and not is_hqq_available():
426
+ raise ImportError(
427
+ "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
428
+ "Please install it via with `pip install hqq`"
429
+ )
430
+
431
+ # 1.b. The model must be decoder-only
432
+ if model.config.is_encoder_decoder:
433
+ raise ValueError("This custom generate function only works with decoder-only models")
434
+
435
+ # 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
436
+ # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
437
+ kwargs.pop("custom_generate", None)
438
+
439
+ # 2. Generate with SinkCache
440
+ # 2.a. prepare the cache, if it was not passed.
441
+ past_key_values = kwargs.pop("past_key_values", None)
442
+ if past_key_values is None:
443
+ past_key_values = cache_class(cache_config=cache_config)
444
+
445
+ # 2.b. generate with the cache
446
+ generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
447
+ return generation_outputs