from __future__ import annotations

from typing import Any, Optional

import torch
import transformers

__all__ = ["JetNemotronCache"]


class JetNemotronCache(transformers.cache_utils.Cache):
    """Per-layer cache of attention key/value, convolution, FFN, and recurrent states for Jet-Nemotron models."""

    def __init__(
        self,
        seen_tokens: int = 0
    ) -> None:
        self.states: list[dict[str, Any]] = []
        self.layer_wise_states: dict[str, Any] = {}

        self._base_seen_tokens = seen_tokens
        self._seen_tokens: list[int] = []

    def __getitem__(self, layer_idx: int) -> dict[str, Any]:
        if layer_idx < len(self):
            return self.states[layer_idx]
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        for state in self.states:
            yield state

    def __len__(self):
        return len(self.states)

    def update(
        self,
        recurrent_state: Optional[torch.Tensor] = None,
        attn_state: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        conv_state: Optional[tuple[torch.Tensor]] = None,
        ffn_state: Optional[torch.Tensor] = None,
        layer_idx: int = 0,
        offset: Optional[int] = 1,
        increase_seen_tokens: bool = True,
        cache_kwargs: dict[str, Any] = {},
    ) -> dict[str, Any]:
        """
        Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.

        Args:
            recurrent_state (`torch.Tensor`, `optional`):
                The new recurrent state to cache.
            attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
                The new attention key/value states to cache.
            conv_state (`Tuple[torch.Tensor]`, `optional`):
                The new convolution state to cache.
            ffn_state (`torch.Tensor`, `optional`):
                The new FFN state to cache.
            layer_idx (`int`, defaults to 0):
                The index of the layer to cache the states for.
            offset (`int`, `optional`, defaults to 1):
                The number of new tokens being processed.
            increase_seen_tokens (`bool`, defaults to `True`):
                Whether to increase the layer's seen-token count by `offset`.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass, e.g. `window_size` for sliding-window attention layers.

        Return:
            Dictionary of the updated state.
        """
        if len(self._seen_tokens) <= layer_idx:
            self._seen_tokens.append(self._base_seen_tokens)

        if increase_seen_tokens:
            self.increase_seen_tokens(layer_idx, offset)

        if attn_state is not None:
            if not isinstance(attn_state, tuple) or len(attn_state) != 2:
                raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
            input_size = attn_state[0].shape[-2]
            window_size = cache_kwargs.get('window_size', None)
        if len(self.states) <= layer_idx:
            # First update for this layer: create its state entry.
            state = dict(
                recurrent_state=recurrent_state,
                attn_state=attn_state,
                conv_state=conv_state,
                ffn_state=ffn_state
            )
            if attn_state is not None and window_size is not None:
                # Only the last `window_size` tokens of the key/value states are
                # kept in the cache; the full states are still returned to the caller.
                _key_state = attn_state[0][..., -window_size:, :]
                _value_state = attn_state[1][..., -window_size:, :]

                _attn_state = (_key_state, _value_state)
                _state = dict(
                    recurrent_state=recurrent_state,
                    attn_state=_attn_state,
                    conv_state=conv_state,
                    ffn_state=ffn_state
                )
                self.states.append(_state)
            else:
                self.states.append(state)
        else:
            state = self.states[layer_idx]
            if recurrent_state is not None:
                state['recurrent_state'] = recurrent_state
            if attn_state is not None:
                key_state, value_state = state['attn_state']
                assert window_size is None or key_state.shape[-2] <= window_size
                if window_size is not None and key_state.shape[-2] == window_size and input_size == 1:
                    # The cached window is full and a single token is being decoded:
                    # shift the window left by one position and overwrite the freed
                    # slot instead of concatenating, so the cache never grows beyond
                    # `window_size`.
                    key_state = key_state.roll(-input_size, -2)
                    value_state = value_state.roll(-input_size, -2)

                    # Write the new key/value states into the last slot.
                    key_state[..., -input_size:, :] = attn_state[0]
                    value_state[..., -input_size:, :] = attn_state[1]

                    attn_state = (key_state, value_state)
                else:
                    # Otherwise, append the new key/value states along the sequence dimension.
                    attn_state = (torch.cat([key_state, attn_state[0]], -2),
                                  torch.cat([value_state, attn_state[1]], -2),)
                state['attn_state'] = attn_state
            if conv_state is not None:
                state['conv_state'] = conv_state
            if ffn_state is not None:
                state['ffn_state'] = ffn_state

        assert len(self.states) == len(self._seen_tokens)

        return state
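
    # Usage sketch (illustrative only; `past_key_values`, the tensor names, and the
    # `window_size` value are hypothetical, not taken from the Jet-Nemotron code).
    # During decoding, a sliding-window attention layer would typically call:
    #
    #     state = past_key_values.update(
    #         attn_state=(key, value),              # new key/value for this step
    #         layer_idx=layer_idx,
    #         offset=key.shape[-2],                 # number of new tokens
    #         cache_kwargs={"window_size": 1024},
    #     )
    #     key_cache, value_cache = state["attn_state"]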

    def trim_attn_state(self, layer_idx: int, window_size: int) -> None:
        """Trims the cached attention state of layer `layer_idx` to the last `window_size` tokens."""
        assert layer_idx < len(self.states), f"Layer index {layer_idx} out of range for states with length {len(self.states)}"
        state = self.states[layer_idx]
        assert state["attn_state"] is not None, f"Layer {layer_idx} does not have an attention state"
        key_state, value_state = state["attn_state"]
        if key_state.shape[-2] > window_size:
            state["attn_state"] = (
                key_state[..., -window_size:, :],
                value_state[..., -window_size:, :],
            )

    def increase_seen_tokens(self, layer_idx: int, offset: int = 1) -> None:
        """Increases the number of seen tokens for the layer `layer_idx` by `offset`."""
        self._seen_tokens[layer_idx] += offset

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self._seen_tokens) <= layer_idx:
            return self._base_seen_tokens
        return self._seen_tokens[layer_idx]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> tuple:
        return tuple(self.states)

    def print_kv_sizes(self) -> None:
        """Prints the size of the cached states for each layer."""
        for layer_idx, state in enumerate(self.states):
            if state.get("attn_state", None) is not None:
                key_state, value_state = state["attn_state"]
                # Size of the cached key/value tensors in MB.
                key_size = key_state.element_size() * key_state.nelement() / (1024**2)
                value_size = value_state.element_size() * value_state.nelement() / (1024**2)
                print(key_state.shape, value_state.shape)
                print(f"Layer {layer_idx}: Attention. cache size: {key_size + value_size:.2f} MB")
            if state.get("conv_state", None) is not None:
                conv_state = state["conv_state"]
                # Sum the sizes of all convolution state tensors in MB.
                conv_sizes = []
                for conv in conv_state:
                    conv_size = conv.element_size() * conv.nelement() / (1024**2)
                    conv_sizes.append(conv_size)
                conv_size = sum(conv_sizes)
                print(f"Layer {layer_idx}: Convolution. cache size: {conv_size:.2f} MB")
            if state.get("ffn_state", None) is not None:
                ffn_state = state["ffn_state"]
                # Size of the cached FFN state in MB.
                ffn_size = ffn_state.element_size() * ffn_state.nelement() / (1024**2)
                print(f"Layer {layer_idx}: FFN. cache size: {ffn_size:.2f} MB")
            if state.get("recurrent_state", None) is not None:
                recurrent_state = state["recurrent_state"]
                # Size of the cached recurrent state in MB.
                recurrent_size = recurrent_state.element_size() * recurrent_state.nelement() / (1024**2)
                print(f"Layer {layer_idx}: Recurrent. cache size: {recurrent_size:.2f} MB")
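

# Minimal smoke test (illustrative sketch only; the tensor shapes and the
# `window_size` of 4 are arbitrary and not tied to any Jet-Nemotron configuration).
if __name__ == "__main__":
    cache = JetNemotronCache()

    # Prefill: cache an 8-token prompt for layer 0, keeping only the last 4 tokens
    # of the key/value states in the cache.
    key = torch.randn(1, 2, 8, 16)    # (batch, heads, seq_len, head_dim)
    value = torch.randn(1, 2, 8, 16)
    cache.update(
        attn_state=(key, value),
        layer_idx=0,
        offset=key.shape[-2],
        cache_kwargs={"window_size": 4},
    )
    print(len(cache), cache.get_seq_length(0))  # 1 cached layer, 8 tokens seen

    # Decode one token: the cached window is rolled and the new key/value written
    # into the last slot, so the cache stays at 4 tokens.
    new_key = torch.randn(1, 2, 1, 16)
    new_value = torch.randn(1, 2, 1, 16)
    cache.update(
        attn_state=(new_key, new_value),
        layer_idx=0,
        offset=1,
        cache_kwargs={"window_size": 4},
    )
    cache.print_kv_sizes()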