File size: 9,628 Bytes
f62ec09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# Copyright 2025 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import Any, Optional

import torch
import transformers

# Export the class actually defined in this module; the previous value
# ("JetCache") named a non-existent attribute and broke `import *`.
__all__ = ["JetNemotronCache"]


class JetNemotronCache(transformers.cache_utils.Cache):
    """Per-layer cache for Jet-Nemotron hybrid models.

    Each layer owns a dict with up to four entries:

    - ``recurrent_state``: linear-attention recurrent state tensor.
    - ``attn_state``: ``(key, value)`` tuple for (sliding-window) attention.
    - ``conv_state``: tuple of convolution state tensors.
    - ``ffn_state``: FFN state tensor.

    Token counts are tracked *per layer* in ``_seen_tokens`` because layers
    may be updated independently during chunked prefill / decoding.
    """

    def __init__(
        self,
        seen_tokens: int = 0
    ) -> None:
        # One state dict per layer, appended lazily during prefill.
        self.states: list[dict[str, Any]] = []
        # Extra states shared across layers (populated by callers as needed).
        self.layer_wise_states: dict[str, Any] = {}

        # Starting token count assigned to each newly-seen layer.
        self._base_seen_tokens = seen_tokens
        # Per-layer tally of how many tokens the cache has seen.
        # Used in `generate` to keep track of decoding progress.
        self._seen_tokens: list[int] = []

    def __getitem__(self, layer_idx: int) -> dict[str, Any]:
        """Return the state dict for `layer_idx`, raising KeyError if absent."""
        if layer_idx < len(self):
            return self.states[layer_idx]
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """Iterate over the per-layer state dicts in layer order."""
        for state in self.states:
            yield state

    def __len__(self):
        """Number of layers currently cached."""
        return len(self.states)

    def update(
        self,
        recurrent_state: Optional[torch.Tensor] = None,
        attn_state: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        conv_state: Optional[tuple[torch.Tensor]] = None,
        ffn_state: Optional[torch.Tensor] = None,
        layer_idx: int = 0,
        offset: Optional[int] = 1,
        increase_seen_tokens: bool = True,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """
        Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.

        Args:
            recurrent_state (`torch.Tensor`, `optional`):
                The new recurrent state to cache.
            attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
                The new attention key/value states to cache.
            conv_state (`Tuple[torch.Tensor]`, `optional`):
                The new convolution state to cache.
            ffn_state (`torch.Tensor`, `optional`):
                The new FFN state to cache.
            layer_idx (`int`, defaults to 0):
                The index of the layer to cache the states for.
            offset (`int`, `optional`, defaults to 1):
                The number of new tokens being processed.
            increase_seen_tokens (`bool`, defaults to True):
                Whether to bump this layer's seen-token counter by `offset`.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass (e.g. `window_size`
                for sliding-window attention).

        Return:
            Dictionary of the updated state.
        """
        # NOTE: `cache_kwargs` previously used a mutable default `{}`;
        # a None sentinel avoids the shared-mutable-default pitfall.
        if cache_kwargs is None:
            cache_kwargs = {}

        if len(self._seen_tokens) <= layer_idx:
            self._seen_tokens.append(self._base_seen_tokens)

        # Update the number of seen tokens
        if increase_seen_tokens:
            self.increase_seen_tokens(layer_idx, offset)

        window_size = None
        if attn_state is not None:
            # Validate the shape of `attn_state` BEFORE indexing into it, so
            # malformed input raises the intended ValueError rather than an
            # opaque TypeError/IndexError.
            if not isinstance(attn_state, tuple) or len(attn_state) != 2:
                raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
            input_size = attn_state[0].shape[-2]
            window_size = cache_kwargs.get('window_size', None)
        if len(self.states) <= layer_idx:
            # in prefilling stage
            state = dict(
                recurrent_state=recurrent_state,
                attn_state=attn_state,
                conv_state=conv_state,
                ffn_state=ffn_state
            )
            if attn_state is not None and window_size is not None:
                # in prefilling stage, the cached and returned key/value states are different
                # original key/value states are returned, but the cached states are the last `window_size` tokens
                _key_state = attn_state[0][..., -window_size:, :]
                _value_state = attn_state[1][..., -window_size:, :]

                _attn_state = (_key_state, _value_state)
                _state = dict(
                    recurrent_state=recurrent_state,
                    attn_state=_attn_state,
                    conv_state=conv_state,
                    ffn_state=ffn_state
                )
                self.states.append(_state)
            else:
                self.states.append(state)
        else:
            state = self.states[layer_idx]
            if recurrent_state is not None:
                state['recurrent_state'] = recurrent_state
            if attn_state is not None:
                key_state, value_state = state['attn_state']
                assert window_size is None or key_state.shape[-2] <= window_size
                if window_size is not None and key_state.shape[-2] == window_size and input_size == 1:
                    # DO NOT allocate new memory if the cache is full
                    # only works in decoding stage
                    # roll the key/value states to the left by `input_size`

                    key_state = key_state.roll(-input_size, -2)
                    value_state = value_state.roll(-input_size, -2)

                    # replace the last `input_size` tokens with the new key/value states
                    key_state[..., -input_size:, :] = attn_state[0]
                    value_state[..., -input_size:, :] = attn_state[1]

                    attn_state = (key_state, value_state)
                else:
                    # <= window_size or not sliding window or chunk-prefilling (input_size > 1)
                    attn_state = (torch.cat([key_state, attn_state[0]], -2),
                                  torch.cat([value_state, attn_state[1]], -2),)
                state['attn_state'] = attn_state
            if conv_state is not None:
                state['conv_state'] = conv_state
            if ffn_state is not None:
                state['ffn_state'] = ffn_state

        assert len(self.states) == len(self._seen_tokens)

        return state

    def trim_attn_state(self, layer_idx: int, window_size: int) -> None:
        """Trim a layer's cached key/value states to the last `window_size` tokens.

        Handles the case when the input length of SWA > 1 and a cache already
        exists — especially the chunk-prefilling case. This function is called
        after attention is done.
        """
        assert layer_idx < len(self.states), f"Layer index {layer_idx} out of range for states with length {len(self.states)}"
        state = self.states[layer_idx]
        assert state["attn_state"] is not None, f"Layer {layer_idx} does not have an attention state"
        key_state, value_state = state["attn_state"]
        if key_state.shape[-2] > window_size:
            state["attn_state"] = (
                key_state[..., -window_size:, :],
                value_state[..., -window_size:, :],
            )

    def increase_seen_tokens(self, layer_idx: int, offset: int = 1) -> None:
        """Increases the number of seen tokens for the layer `layer_idx` by `offset`."""
        self._seen_tokens[layer_idx] += offset

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Honor the Optional annotation: treat None as layer 0 instead of
        # raising TypeError on the `<=` comparison below.
        if layer_idx is None:
            layer_idx = 0
        if len(self._seen_tokens) <= layer_idx:
            return self._base_seen_tokens
        return self._seen_tokens[layer_idx]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
        return None

    def to_legacy_cache(self) -> tuple:
        """Return the per-layer states as a tuple (legacy HF cache format)."""
        return tuple(self.states)

    def print_kv_sizes(self) -> None:
        """Prints the size (in MB) of each layer's cached states."""
        for layer_idx, state in enumerate(self.states):
            if state.get("attn_state", None) is not None:
                key_state, value_state = state["attn_state"]
                # compute state size in MB
                key_size = key_state.element_size() * key_state.nelement() / (1024**2)
                value_size = value_state.element_size() * value_state.nelement() / (1024**2)
                print(key_state.shape, value_state.shape)
                print(f"Layer {layer_idx}: Attention. cache size: {key_size + value_size:.2f} MB")
            if state.get("conv_state", None) is not None:
                conv_state = state["conv_state"]
                # compute state size in MB
                conv_sizes = []
                for conv in conv_state:
                    conv_size = conv.element_size() * conv.nelement() / (1024**2)
                    conv_sizes.append(conv_size)
                conv_size = sum(conv_sizes)
                print(f"Layer {layer_idx}: Convolution. cache size: {conv_size:.2f} MB")
            if state.get("ffn_state", None) is not None:
                ffn_state = state["ffn_state"]
                # compute state size in MB
                ffn_size = ffn_state.element_size() * ffn_state.nelement() / (1024**2)
                print(f"Layer {layer_idx}: FFN. cache size: {ffn_size:.2f} MB")
            if state.get("recurrent_state", None) is not None:
                recurrent_state = state["recurrent_state"]
                # compute state size in MB
                recurrent_size = recurrent_state.element_size() * recurrent_state.nelement() / (1024**2)
                print(f"Layer {layer_idx}: Recurrent. cache size: {recurrent_size:.2f} MB")