Jet-Nemotron-4B / kv_cache.py
t1101675's picture
Upload folder using huggingface_hub
f62ec09 verified
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import Any, Optional
import torch
import transformers
# Public API of this module: must name the class actually defined below.
__all__ = ["JetNemotronCache"]
class JetNemotronCache(transformers.cache_utils.Cache):
    """Layer-wise cache holding recurrent, attention, convolution and FFN states.

    Every layer owns one dictionary with the keys `recurrent_state`, `attn_state`,
    `conv_state` and `ffn_state`. Layers are registered by appending to `self.states`,
    so the first `update` call for layer `i` is expected to arrive after layers
    `0..i-1` have been registered.
    """

    def __init__(self, seen_tokens: int = 0) -> None:
        """
        Args:
            seen_tokens (`int`, defaults to 0):
                Starting value of the per-layer seen-token counters.
        """
        # NOTE(review): `super().__init__()` is not called, exactly as in the original code.
        # Confirm against the installed `transformers` version before adding it.
        self.states: list[dict[str, Any]] = []
        self.layer_wise_states: dict[str, Any] = {}
        self._base_seen_tokens = seen_tokens
        # One counter per layer. Used in `generate` to keep tally of how many tokens the cache has seen.
        self._seen_tokens: list[int] = []

    def __getitem__(self, layer_idx: int) -> dict[str, Any]:
        """Returns the state dictionary of layer `layer_idx`; raises `KeyError` when out of range."""
        if layer_idx < len(self):
            return self.states[layer_idx]
        raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """Iterates over the per-layer state dictionaries in layer order."""
        yield from self.states

    def __len__(self) -> int:
        """Returns the number of layers that currently have a cached state."""
        return len(self.states)

    def update(
        self,
        recurrent_state: Optional[torch.Tensor] = None,
        attn_state: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        conv_state: Optional[tuple[torch.Tensor]] = None,
        ffn_state: Optional[torch.Tensor] = None,
        layer_idx: int = 0,
        offset: Optional[int] = 1,
        increase_seen_tokens: bool = True,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """
        Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state`/`ffn_state`
        for the layer `layer_idx`.

        Args:
            recurrent_state (`torch.Tensor`, `optional`):
                The new recurrent state to cache.
            attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
                The new attention key/value states to cache. The second-to-last dimension
                is treated as the token dimension.
            conv_state (`Tuple[torch.Tensor]`, `optional`):
                The new convolution state to cache.
            ffn_state (`torch.Tensor`, `optional`):
                The new FFN state to cache.
            layer_idx (`int`, defaults to 0):
                The index of the layer to cache the states for.
            offset (`int`, `optional`, defaults to 1):
                The number of new tokens being processed.
            increase_seen_tokens (`bool`, defaults to `True`):
                Whether to add `offset` to the seen-token counter of the layer.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. Only `window_size` is read here,
                and only when `attn_state` is given.

        Return:
            Dictionary of the updated state. When a layer is first registered with a
            `window_size`, the returned dictionary carries the full key/value states while
            the cached copy only keeps the last `window_size` tokens.

        Raises:
            ValueError: If `attn_state` is given but is not a tuple of two elements.
        """
        # Validate before touching any bookkeeping so a bad call leaves the cache unchanged.
        if attn_state is not None and (not isinstance(attn_state, tuple) or len(attn_state) != 2):
            raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
        if cache_kwargs is None:
            cache_kwargs = {}

        if len(self._seen_tokens) <= layer_idx:
            self._seen_tokens.append(self._base_seen_tokens)

        # Update the number of seen tokens
        if increase_seen_tokens:
            self.increase_seen_tokens(layer_idx, offset)

        input_size = None
        window_size = None
        if attn_state is not None:
            input_size = attn_state[0].shape[-2]
            window_size = cache_kwargs.get("window_size", None)

        if len(self.states) <= layer_idx:
            # in prefilling stage: the layer has no cached state yet
            state = dict(
                recurrent_state=recurrent_state,
                attn_state=attn_state,
                conv_state=conv_state,
                ffn_state=ffn_state,
            )
            if attn_state is not None and window_size is not None:
                # The cached and returned key/value states differ here: the original
                # key/value states are returned, but only the last `window_size` tokens are cached.
                windowed_attn_state = (
                    attn_state[0][..., -window_size:, :],
                    attn_state[1][..., -window_size:, :],
                )
                self.states.append(dict(state, attn_state=windowed_attn_state))
            else:
                self.states.append(state)
        else:
            state = self.states[layer_idx]
            if recurrent_state is not None:
                state["recurrent_state"] = recurrent_state
            if attn_state is not None:
                key_state, value_state = state["attn_state"]
                assert window_size is None or key_state.shape[-2] <= window_size
                if window_size is not None and key_state.shape[-2] == window_size and input_size == 1:
                    # The window is full and a single token is being decoded: shift the
                    # cached tokens left by `input_size` and overwrite the freed tail slot,
                    # so the cached length stays at `window_size`.
                    key_state = key_state.roll(-input_size, -2)
                    value_state = value_state.roll(-input_size, -2)
                    key_state[..., -input_size:, :] = attn_state[0]
                    value_state[..., -input_size:, :] = attn_state[1]
                    attn_state = (key_state, value_state)
                else:
                    # <= window_size or not sliding window or chunk-prefilling (input_size > 1).
                    # The result may exceed `window_size`; see `trim_attn_state`.
                    attn_state = (
                        torch.cat([key_state, attn_state[0]], -2),
                        torch.cat([value_state, attn_state[1]], -2),
                    )
                state["attn_state"] = attn_state
            if conv_state is not None:
                state["conv_state"] = conv_state
            if ffn_state is not None:
                state["ffn_state"] = ffn_state

        assert len(self.states) == len(self._seen_tokens)
        return state

    def trim_attn_state(self, layer_idx: int, window_size: int) -> None:
        """Keeps only the last `window_size` cached key/value tokens of layer `layer_idx`.

        Handles the case when the input length of SWA > 1 and the layer already has a cache,
        especially the chunk-prefilling case. This function is called after attention is done.
        """
        assert layer_idx < len(self.states), f"Layer index {layer_idx} out of range for states with length {len(self.states)}"
        state = self.states[layer_idx]
        assert state["attn_state"] is not None, f"Layer {layer_idx} does not have an attention state"
        key_state, value_state = state["attn_state"]
        if key_state.shape[-2] > window_size:
            state["attn_state"] = (
                key_state[..., -window_size:, :],
                value_state[..., -window_size:, :],
            )

    def increase_seen_tokens(self, layer_idx: int, offset: int = 1) -> None:
        """Increases the number of seen tokens for the layer `layer_idx` by `offset`."""
        self._seen_tokens[layer_idx] += offset

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed.

        `None` is treated like the default layer 0. Layers that have not been updated yet
        report the base `seen_tokens` value given at construction.
        """
        if layer_idx is None:
            layer_idx = 0
        if len(self._seen_tokens) <= layer_idx:
            return self._base_seen_tokens
        return self._seen_tokens[layer_idx]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> tuple:
        """Returns the per-layer state dictionaries as a tuple."""
        return tuple(self.states)

    @staticmethod
    def _size_in_mb(tensor: torch.Tensor) -> float:
        """Returns the storage taken by the elements of `tensor`, in MB (1024**2 bytes)."""
        return tensor.element_size() * tensor.nelement() / (1024**2)

    def print_kv_sizes(self) -> None:
        """Prints, for every layer, the size in MB of each cached state that is not `None`."""
        for layer_idx, state in enumerate(self.states):
            if state.get("attn_state", None) is not None:
                key_state, value_state = state["attn_state"]
                key_size = self._size_in_mb(key_state)
                value_size = self._size_in_mb(value_state)
                print(key_state.shape, value_state.shape)
                print(f"Layer {layer_idx}: Attention. cache size: {key_size + value_size:.2f} MB")
            if state.get("conv_state", None) is not None:
                # `conv_state` is a collection of tensors; report their combined size
                conv_size = sum(self._size_in_mb(conv) for conv in state["conv_state"])
                print(f"Layer {layer_idx}: Convolution. cache size: {conv_size:.2f} MB")
            if state.get("ffn_state", None) is not None:
                ffn_size = self._size_in_mb(state["ffn_state"])
                print(f"Layer {layer_idx}: FFN. cache size: {ffn_size:.2f} MB")
            if state.get("recurrent_state", None) is not None:
                recurrent_size = self._size_in_mb(state["recurrent_state"])
                print(f"Layer {layer_idx}: Recurrent. cache size: {recurrent_size:.2f} MB")