Fix: resolve transformers version compatibility for DynamicLayer and cache initialization
#18
by FALcon6 - opened
- modeling_minicpm.py +153 -10
modeling_minicpm.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
| 16 |
import math
|
| 17 |
import re
|
| 18 |
import warnings
|
|
|
|
| 19 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 20 |
|
| 21 |
import torch
|
|
@@ -24,7 +25,7 @@ import torch.utils.checkpoint
|
|
| 24 |
from torch import nn
|
| 25 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 26 |
from transformers.activations import ACT2FN
|
| 27 |
-
from transformers.cache_utils import Cache, DynamicCache
|
| 28 |
from transformers.modeling_attn_mask_utils import (
|
| 29 |
AttentionMaskConverter,
|
| 30 |
_prepare_4d_attention_mask,
|
|
@@ -233,6 +234,145 @@ class CompressK(torch.nn.Module):
|
|
| 233 |
return compressed_k, cu_seqlens_compressed
|
| 234 |
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
class InfLLMv2CacheLayer(DynamicLayer):
|
| 238 |
def __init__(self):
|
|
@@ -1814,18 +1954,21 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
|
|
| 1814 |
past_key_values_length = 0
|
| 1815 |
|
| 1816 |
if use_cache:
|
| 1817 |
-
|
| 1818 |
-
if
|
| 1819 |
raise ValueError(
|
| 1820 |
'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
|
| 1821 |
)
|
| 1822 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1823 |
# Calculate the usable length of past key values
|
| 1824 |
-
past_key_values_length = past_key_values.get_seq_length()
|
| 1825 |
-
|
| 1826 |
-
# Initialize InfLLMv2Cache if needed
|
| 1827 |
-
if self.config.sparse_config is not None and torch.cuda.is_available() and past_key_values_length == 0:
|
| 1828 |
-
past_key_values = InfLLMv2Cache(config = self.config, num_hidden_layers=self.config.num_hidden_layers)
|
| 1829 |
|
| 1830 |
if position_ids is None:
|
| 1831 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
@@ -1907,7 +2050,7 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
|
|
| 1907 |
|
| 1908 |
next_cache = None
|
| 1909 |
if use_cache:
|
| 1910 |
-
next_cache = next_decoder_cache
|
| 1911 |
if not return_dict:
|
| 1912 |
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 1913 |
return BaseModelOutputWithPast(
|
|
|
|
| 16 |
import math
|
| 17 |
import re
|
| 18 |
import warnings
|
| 19 |
+
from abc import ABC, abstractmethod
|
| 20 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 21 |
|
| 22 |
import torch
|
|
|
|
| 25 |
from torch import nn
|
| 26 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 27 |
from transformers.activations import ACT2FN
|
| 28 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 29 |
from transformers.modeling_attn_mask_utils import (
|
| 30 |
AttentionMaskConverter,
|
| 31 |
_prepare_4d_attention_mask,
|
|
|
|
| 234 |
return compressed_k, cu_seqlens_compressed
|
| 235 |
|
| 236 |
|
| 237 |
+
class CacheLayerMixin(ABC):
    """Base, abstract class for a single layer's cache."""

    # Subclasses that support torch.compile flip this to True.
    is_compileable = False

    def __init__(self):
        # Key/value tensors are allocated lazily by lazy_initialization().
        self.keys: torch.Tensor | None = None
        self.values: torch.Tensor | None = None
        self.is_initialized = False

    def __repr__(self):
        return f"{type(self).__name__}"

    @abstractmethod
    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: ...

    @abstractmethod
    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]: ...

    @abstractmethod
    def get_mask_sizes(self, query_length: int) -> tuple[int, int]: ...

    @abstractmethod
    def get_seq_length(self) -> int: ...

    @abstractmethod
    def get_max_cache_shape(self) -> int: ...

    def offload(self):
        """Offload this layer's data to CPU device."""
        if not self.is_initialized:
            return
        self.keys = self.keys.to("cpu", non_blocking=True)
        self.values = self.values.to("cpu", non_blocking=True)

    def prefetch(self):
        """In case of layer offloading, this allows to move the data back to the layer's device ahead of time."""
        # assumes self.device was set by the subclass during lazy_initialization — TODO confirm
        if not (self.is_initialized and self.keys.device != self.device):
            return
        self.keys = self.keys.to(self.device, non_blocking=True)
        self.values = self.values.to(self.device, non_blocking=True)

    def reset(self) -> None:
        """Resets the cache values while preserving the objects"""
        if self.is_initialized:
            self.keys.zero_()
            self.values.zero_()
        # This attribute is set on several Layers
        if hasattr(self, "cumulative_length"):
            # It can either be an int for dynamic layers, or a tensor for static layers
            counter = self.cumulative_length
            if isinstance(counter, int):
                self.cumulative_length = 0
            else:
                counter.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorders this layer's cache for beam search."""
        if self.get_seq_length() == 0:
            return
        self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
        self.values = self.values.index_select(0, beam_idx.to(self.values.device))
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class DynamicLayer(CacheLayerMixin):
    """
    A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
    It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`.
    """

    is_sliding = False

    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """Adopt dtype/device from the first incoming states and allocate empty, growable buffers."""
        self.dtype = key_states.dtype
        self.device = key_states.device
        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
        self.is_initialized = True

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)

        # Append the new states along the sequence axis.
        self.keys = torch.cat((self.keys, key_states), dim=-2)
        self.values = torch.cat((self.values, value_states), dim=-2)
        return self.keys, self.values

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the mask"""
        return self.get_seq_length() + query_length, 0

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        if self.is_initialized and self.keys.numel() != 0:
            return self.keys.shape[-2]
        return 0

    def get_max_cache_shape(self) -> int:
        """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
        return -1

    def crop(self, max_length: int) -> None:
        """
        Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative
        to remove `max_length` tokens.
        """
        current_length = self.get_seq_length()
        if max_length < 0:
            # A negative value means "drop that many tokens from the end".
            max_length = current_length + max_length
        if current_length <= max_length:
            return
        self.keys = self.keys[..., :max_length, :]
        self.values = self.values[..., :max_length, :]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Repeat the cache `repeats` times in the batch dimension."""
        if self.get_seq_length() == 0:
            return
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Only keep the `indices` in the batch dimension of the cache."""
        if self.get_seq_length() == 0:
            return
        self.keys = self.keys[indices, ...]
        self.values = self.values[indices, ...]
|
| 375 |
+
|
| 376 |
|
| 377 |
class InfLLMv2CacheLayer(DynamicLayer):
|
| 378 |
def __init__(self):
|
|
|
|
| 1954 |
past_key_values_length = 0
|
| 1955 |
|
| 1956 |
if use_cache:
|
| 1957 |
+
# Reject old tuple-style cache, but allow None (first forward pass)
|
| 1958 |
+
if past_key_values is not None and not isinstance(past_key_values, Cache):
|
| 1959 |
raise ValueError(
|
| 1960 |
'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
|
| 1961 |
)
|
| 1962 |
+
|
| 1963 |
+
# Initialize cache if None (first forward pass)
|
| 1964 |
+
if past_key_values is None:
|
| 1965 |
+
if self.config.sparse_config is not None and torch.cuda.is_available():
|
| 1966 |
+
past_key_values = InfLLMv2Cache(config=self.config, num_hidden_layers=self.config.num_hidden_layers)
|
| 1967 |
+
else:
|
| 1968 |
+
past_key_values = DynamicCache()
|
| 1969 |
+
|
| 1970 |
# Calculate the usable length of past key values
|
| 1971 |
+
past_key_values_length = past_key_values.get_seq_length()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1972 |
|
| 1973 |
if position_ids is None:
|
| 1974 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
|
|
| 2050 |
|
| 2051 |
next_cache = None
|
| 2052 |
if use_cache:
|
| 2053 |
+
next_cache = next_decoder_cache
|
| 2054 |
if not return_dict:
|
| 2055 |
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 2056 |
return BaseModelOutputWithPast(
|