FALcon6 committed on
Commit
756dc95
·
1 Parent(s): bb2ae14

Added CacheLayerMixin and DynamicLayer

Browse files
Files changed (1) hide show
  1. modeling_minicpm.py +141 -1
modeling_minicpm.py CHANGED
@@ -16,6 +16,7 @@
16
  import math
17
  import re
18
  import warnings
 
19
  from typing import Any, Dict, List, Optional, Tuple, Union
20
 
21
  import torch
@@ -24,7 +25,7 @@ import torch.utils.checkpoint
24
  from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
  from transformers.activations import ACT2FN
27
- from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer
28
  from transformers.modeling_attn_mask_utils import (
29
  AttentionMaskConverter,
30
  _prepare_4d_attention_mask,
@@ -233,6 +234,145 @@ class CompressK(torch.nn.Module):
233
  return compressed_k, cu_seqlens_compressed
234
 
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
  class InfLLMv2CacheLayer(DynamicLayer):
238
  def __init__(self):
 
16
  import math
17
  import re
18
  import warnings
19
+ from abc import ABC, abstractmethod
20
  from typing import Any, Dict, List, Optional, Tuple, Union
21
 
22
  import torch
 
25
  from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
  from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache
29
  from transformers.modeling_attn_mask_utils import (
30
  AttentionMaskConverter,
31
  _prepare_4d_attention_mask,
 
234
  return compressed_k, cu_seqlens_compressed
235
 
236
 
237
class CacheLayerMixin(ABC):
    """Base, abstract class for a single layer's cache."""

    is_compileable = False

    def __init__(self):
        # Buffers are allocated lazily: subclasses fill them in via `lazy_initialization`
        # on the first `update` call, once dtype/device of the incoming states are known.
        self.keys: torch.Tensor | None = None
        self.values: torch.Tensor | None = None
        self.is_initialized = False

    def __repr__(self):
        return f"{self.__class__.__name__}"

    @abstractmethod
    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: ...

    @abstractmethod
    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]: ...

    @abstractmethod
    def get_mask_sizes(self, query_length: int) -> tuple[int, int]: ...

    @abstractmethod
    def get_seq_length(self) -> int: ...

    @abstractmethod
    def get_max_cache_shape(self) -> int: ...

    def offload(self):
        """Offload this layer's data to CPU device."""
        if not self.is_initialized:
            return
        self.keys = self.keys.to("cpu", non_blocking=True)
        self.values = self.values.to("cpu", non_blocking=True)

    def prefetch(self):
        """In case of layer offloading, this allows to move the data back to the layer's device ahead of time."""
        # NOTE(review): `self.device` is only set by subclasses in `lazy_initialization` —
        # callers must not prefetch a layer type that never records its device.
        if not self.is_initialized:
            return
        if self.keys.device != self.device:
            self.keys = self.keys.to(self.device, non_blocking=True)
            self.values = self.values.to(self.device, non_blocking=True)

    def reset(self) -> None:
        """Resets the cache values while preserving the objects"""
        if self.is_initialized:
            self.keys.zero_()
            self.values.zero_()
        # Several layer types also track a cumulative length attribute; clear it too.
        if hasattr(self, "cumulative_length"):
            # It is an int for dynamic layers and a tensor for static layers.
            if isinstance(self.cumulative_length, int):
                self.cumulative_length = 0
            else:
                self.cumulative_length.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorders this layer's cache for beam search."""
        if self.get_seq_length() == 0:
            return
        self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
        self.values = self.values.index_select(0, beam_idx.to(self.values.device))
297
+
298
+
299
class DynamicLayer(CacheLayerMixin):
    """
    A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
    It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`.
    """

    is_sliding = False

    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """Allocate empty placeholder buffers matching the dtype/device of the first incoming states."""
        self.dtype = key_states.dtype
        self.device = key_states.device
        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
        self.is_initialized = True

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)
        # `torch.cat` accepts the initial 1-D empty placeholder alongside the 4-D states.
        self.keys = torch.cat((self.keys, key_states), dim=-2)
        self.values = torch.cat((self.values, value_states), dim=-2)
        return self.keys, self.values

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the mask"""
        # Dynamic layers never shift their window, so the offset is always 0.
        return self.get_seq_length() + query_length, 0

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        if self.is_initialized and self.keys.numel() > 0:
            return self.keys.shape[-2]
        return 0

    def get_max_cache_shape(self) -> int:
        """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
        return -1

    def crop(self, max_length: int) -> None:
        """
        Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative
        to remove `max_length` tokens.
        """
        current_length = self.get_seq_length()
        if max_length < 0:
            # A negative value means "drop that many tokens from the end".
            max_length = current_length - abs(max_length)
        if current_length <= max_length:
            return
        self.keys = self.keys[..., :max_length, :]
        self.values = self.values[..., :max_length, :]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Repeat the cache `repeats` times in the batch dimension."""
        if self.get_seq_length() == 0:
            return
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Only keep the `indices` in the batch dimension of the cache."""
        if self.get_seq_length() == 0:
            return
        self.keys = self.keys[indices, ...]
        self.values = self.values[indices, ...]
375
+
376
 
377
  class InfLLMv2CacheLayer(DynamicLayer):
378
  def __init__(self):