Added CacheLayerMixin and DynamicLayer
Browse files- modeling_minicpm.py +141 -1
modeling_minicpm.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
| 16 |
import math
|
| 17 |
import re
|
| 18 |
import warnings
|
|
|
|
| 19 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 20 |
|
| 21 |
import torch
|
|
@@ -24,7 +25,7 @@ import torch.utils.checkpoint
|
|
| 24 |
from torch import nn
|
| 25 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 26 |
from transformers.activations import ACT2FN
|
| 27 |
-
from transformers.cache_utils import Cache, DynamicCache
|
| 28 |
from transformers.modeling_attn_mask_utils import (
|
| 29 |
AttentionMaskConverter,
|
| 30 |
_prepare_4d_attention_mask,
|
|
@@ -233,6 +234,145 @@ class CompressK(torch.nn.Module):
|
|
| 233 |
return compressed_k, cu_seqlens_compressed
|
| 234 |
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
class InfLLMv2CacheLayer(DynamicLayer):
|
| 238 |
def __init__(self):
|
|
|
|
| 16 |
import math
|
| 17 |
import re
|
| 18 |
import warnings
|
| 19 |
+
from abc import ABC, abstractmethod
|
| 20 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 21 |
|
| 22 |
import torch
|
|
|
|
| 25 |
from torch import nn
|
| 26 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 27 |
from transformers.activations import ACT2FN
|
| 28 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 29 |
from transformers.modeling_attn_mask_utils import (
|
| 30 |
AttentionMaskConverter,
|
| 31 |
_prepare_4d_attention_mask,
|
|
|
|
| 234 |
return compressed_k, cu_seqlens_compressed
|
| 235 |
|
| 236 |
|
| 237 |
+
class CacheLayerMixin(ABC):
    """Abstract base for the key/value cache of a single attention layer.

    Subclasses decide how the key/value tensors are allocated and grown; this
    mixin only holds the tensors and supplies the device-movement, reset and
    beam-search helpers shared by every layer type.
    """

    # Set by subclasses that are safe to use under torch.compile.
    is_compileable = False

    def __init__(self):
        # Buffers are allocated lazily on the first `update` call.
        self.keys: torch.Tensor | None = None
        self.values: torch.Tensor | None = None
        self.is_initialized = False

    def __repr__(self):
        return f"{self.__class__.__name__}"

    @abstractmethod
    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: ...

    @abstractmethod
    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]: ...

    @abstractmethod
    def get_mask_sizes(self, query_length: int) -> tuple[int, int]: ...

    @abstractmethod
    def get_seq_length(self) -> int: ...

    @abstractmethod
    def get_max_cache_shape(self) -> int: ...

    def offload(self):
        """Offload this layer's data to CPU device."""
        if not self.is_initialized:
            return
        self.keys = self.keys.to("cpu", non_blocking=True)
        self.values = self.values.to("cpu", non_blocking=True)

    def prefetch(self):
        """In case of layer offloading, this allows to move the data back to the layer's device ahead of time."""
        if not self.is_initialized:
            return
        if self.keys.device == self.device:
            return
        self.keys = self.keys.to(self.device, non_blocking=True)
        self.values = self.values.to(self.device, non_blocking=True)

    def reset(self) -> None:
        """Resets the cache values while preserving the objects"""
        if self.is_initialized:
            self.keys.zero_()
            self.values.zero_()
        # Several layer types carry a `cumulative_length` counter as well.
        if hasattr(self, "cumulative_length"):
            # It is an int on dynamic layers and a tensor on static layers.
            if isinstance(self.cumulative_length, int):
                self.cumulative_length = 0
            else:
                self.cumulative_length.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorders this layer's cache for beam search."""
        if self.get_seq_length() <= 0:
            return
        self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
        self.values = self.values.index_select(0, beam_idx.to(self.values.device))
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class DynamicLayer(CacheLayerMixin):
    """
    A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
    It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`.
    """

    is_sliding = False

    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """Allocate empty buffers matching the dtype and device of the first states seen."""
        self.dtype = key_states.dtype
        self.device = key_states.device
        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
        self.is_initialized = True

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        if not self.is_initialized:
            # First call: set up empty buffers on the right dtype/device.
            self.lazy_initialization(key_states, value_states)
        self.keys = torch.cat((self.keys, key_states), dim=-2)
        self.values = torch.cat((self.values, value_states), dim=-2)
        return self.keys, self.values

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the mask"""
        # Dynamic layers have no windowing, so the offset is always 0.
        return self.get_seq_length() + query_length, 0

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        if self.is_initialized and self.keys.numel() != 0:
            return self.keys.shape[-2]
        return 0

    def get_max_cache_shape(self) -> int:
        """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
        return -1

    def crop(self, max_length: int) -> None:
        """
        Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative
        to remove `max_length` tokens.
        """
        current_length = self.get_seq_length()
        if max_length < 0:
            # Negative values mean "drop that many tokens from the end".
            max_length = current_length + max_length
        if current_length > max_length:
            self.keys = self.keys[..., :max_length, :]
            self.values = self.values[..., :max_length, :]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Repeat the cache `repeats` times in the batch dimension."""
        if self.get_seq_length() == 0:
            return
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Only keep the `indices` in the batch dimension of the cache."""
        if self.get_seq_length() == 0:
            return
        self.keys = self.keys[indices, ...]
        self.values = self.values[indices, ...]
|
| 375 |
+
|
| 376 |
|
| 377 |
class InfLLMv2CacheLayer(DynamicLayer):
|
| 378 |
def __init__(self):
|