Fix: resolve transformers version compatibility for DynamicLayer and cache initialization

#18
by FALcon6 - opened
Files changed (1) hide show
  1. modeling_minicpm.py +153 -10
modeling_minicpm.py CHANGED
@@ -16,6 +16,7 @@
16
  import math
17
  import re
18
  import warnings
 
19
  from typing import Any, Dict, List, Optional, Tuple, Union
20
 
21
  import torch
@@ -24,7 +25,7 @@ import torch.utils.checkpoint
24
  from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
  from transformers.activations import ACT2FN
27
- from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer
28
  from transformers.modeling_attn_mask_utils import (
29
  AttentionMaskConverter,
30
  _prepare_4d_attention_mask,
@@ -233,6 +234,145 @@ class CompressK(torch.nn.Module):
233
  return compressed_k, cu_seqlens_compressed
234
 
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
  class InfLLMv2CacheLayer(DynamicLayer):
238
  def __init__(self):
@@ -1814,18 +1954,21 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
1814
  past_key_values_length = 0
1815
 
1816
  if use_cache:
1817
- use_legacy_cache = not isinstance(past_key_values, Cache)
1818
- if use_legacy_cache:
1819
  raise ValueError(
1820
  'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
1821
  )
1822
-
 
 
 
 
 
 
 
1823
  # Calculate the usable length of past key values
1824
- past_key_values_length = past_key_values.get_seq_length() if isinstance(past_key_values, InfLLMv2Cache) else 0
1825
-
1826
- # Initialize InfLLMv2Cache if needed
1827
- if self.config.sparse_config is not None and torch.cuda.is_available() and past_key_values_length == 0:
1828
- past_key_values = InfLLMv2Cache(config = self.config, num_hidden_layers=self.config.num_hidden_layers)
1829
 
1830
  if position_ids is None:
1831
  device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1907,7 +2050,7 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
1907
 
1908
  next_cache = None
1909
  if use_cache:
1910
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1911
  if not return_dict:
1912
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1913
  return BaseModelOutputWithPast(
 
16
  import math
17
  import re
18
  import warnings
19
+ from abc import ABC, abstractmethod
20
  from typing import Any, Dict, List, Optional, Tuple, Union
21
 
22
  import torch
 
25
  from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
  from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache
29
  from transformers.modeling_attn_mask_utils import (
30
  AttentionMaskConverter,
31
  _prepare_4d_attention_mask,
 
234
  return compressed_k, cu_seqlens_compressed
235
 
236
 
237
+ class CacheLayerMixin(ABC):
238
+ """Base, abstract class for a single layer's cache."""
239
+
240
+ is_compileable = False
241
+
242
+ def __init__(self):
243
+ self.keys: torch.Tensor | None = None
244
+ self.values: torch.Tensor | None = None
245
+ self.is_initialized = False
246
+
247
+ def __repr__(self):
248
+ return f"{self.__class__.__name__}"
249
+
250
+ @abstractmethod
251
+ def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: ...
252
+
253
+ @abstractmethod
254
+ def update(
255
+ self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
256
+ ) -> tuple[torch.Tensor, torch.Tensor]: ...
257
+
258
+ @abstractmethod
259
+ def get_mask_sizes(self, query_length: int) -> tuple[int, int]: ...
260
+
261
+ @abstractmethod
262
+ def get_seq_length(self) -> int: ...
263
+
264
+ @abstractmethod
265
+ def get_max_cache_shape(self) -> int: ...
266
+
267
+ def offload(self):
268
+ """Offload this layer's data to CPU device."""
269
+ if self.is_initialized:
270
+ self.keys = self.keys.to("cpu", non_blocking=True)
271
+ self.values = self.values.to("cpu", non_blocking=True)
272
+
273
+ def prefetch(self):
274
+ """In case of layer offloading, this allows to move the data back to the layer's device ahead of time."""
275
+ if self.is_initialized and self.keys.device != self.device:
276
+ self.keys = self.keys.to(self.device, non_blocking=True)
277
+ self.values = self.values.to(self.device, non_blocking=True)
278
+
279
+ def reset(self) -> None:
280
+ """Resets the cache values while preserving the objects"""
281
+ if self.is_initialized:
282
+ self.keys.zero_()
283
+ self.values.zero_()
284
+ # This attribute is set on several Layers
285
+ if hasattr(self, "cumulative_length"):
286
+ # It can either be an int for dynamic layers, or a tensor for static layers
287
+ if isinstance(self.cumulative_length, int):
288
+ self.cumulative_length = 0
289
+ else:
290
+ self.cumulative_length.zero_()
291
+
292
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
293
+ """Reorders this layer's cache for beam search."""
294
+ if self.get_seq_length() > 0:
295
+ self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
296
+ self.values = self.values.index_select(0, beam_idx.to(self.values.device))
297
+
298
+
299
class DynamicLayer(CacheLayerMixin):
    """
    A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
    It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`.

    Backported from ``transformers.cache_utils`` for transformers versions
    that do not ship ``DynamicLayer``. Signatures use ``typing.Tuple`` (not
    ``tuple[...]``, evaluated at ``def`` time) so the class still imports on
    Python 3.8, matching the file's ``typing`` imports.
    """

    # A dynamic layer keeps the full history; it never uses a sliding window.
    is_sliding = False

    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """Adopt dtype/device from the first incoming states and allocate empty buffers."""
        self.dtype, self.device = key_states.dtype, key_states.device
        # Empty 1-D tensors: `torch.cat` special-cases them, so the first
        # `update` simply takes on the incoming states' shape.
        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
        self.is_initialized = True

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            Tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)

        # Append along the sequence axis (dim -2).
        self.keys = torch.cat([self.keys, key_states], dim=-2)
        self.values = torch.cat([self.values, value_states], dim=-2)
        return self.keys, self.values

    def get_mask_sizes(self, query_length: int) -> Tuple[int, int]:
        """Return the length and offset of the cache, used to generate the mask."""
        kv_offset = 0
        kv_length = self.get_seq_length() + query_length
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        if not self.is_initialized or self.keys.numel() == 0:
            return 0
        return self.keys.shape[-2]

    def get_max_cache_shape(self) -> int:
        """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
        return -1

    def crop(self, max_length: int) -> None:
        """
        Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative
        to remove `max_length` tokens.
        """
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)

        # Nothing to do when the cache is already short enough.
        if self.get_seq_length() <= max_length:
            return

        self.keys = self.keys[..., :max_length, :]
        self.values = self.values[..., :max_length, :]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Repeat the cache `repeats` times in the batch dimension."""
        if self.get_seq_length() > 0:
            self.keys = self.keys.repeat_interleave(repeats, dim=0)
            self.values = self.values.repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Only keep the `indices` in the batch dimension of the cache."""
        if self.get_seq_length() > 0:
            self.keys = self.keys[indices, ...]
            self.values = self.values[indices, ...]
375
+
376
 
377
  class InfLLMv2CacheLayer(DynamicLayer):
378
  def __init__(self):
 
1954
  past_key_values_length = 0
1955
 
1956
  if use_cache:
1957
+ # Reject old tuple-style cache, but allow None (first forward pass)
1958
+ if past_key_values is not None and not isinstance(past_key_values, Cache):
1959
  raise ValueError(
1960
  'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
1961
  )
1962
+
1963
+ # Initialize cache if None (first forward pass)
1964
+ if past_key_values is None:
1965
+ if self.config.sparse_config is not None and torch.cuda.is_available():
1966
+ past_key_values = InfLLMv2Cache(config=self.config, num_hidden_layers=self.config.num_hidden_layers)
1967
+ else:
1968
+ past_key_values = DynamicCache()
1969
+
1970
  # Calculate the usable length of past key values
1971
+ past_key_values_length = past_key_values.get_seq_length()
 
 
 
 
1972
 
1973
  if position_ids is None:
1974
  device = input_ids.device if input_ids is not None else inputs_embeds.device
 
2050
 
2051
  next_cache = None
2052
  if use_cache:
2053
+ next_cache = next_decoder_cache
2054
  if not return_dict:
2055
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
2056
  return BaseModelOutputWithPast(