# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple
from weakref import ReferenceType, ref

from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import Cache, PretrainedConfig, PreTrainedModel


__all__ = [
    "QuantizedKVCache",
    "initialize_hooked_kv_cache",
    "register_key_hook",
    "register_value_hook",
    "KV_CACHE_ATTR",
]


# submodule name under which the QuantizedKVCache is registered on attention modules
KV_CACHE_ATTR = "kv_cache"


class QuantizedKVCache(InternalModule):
    """
    QuantizedKVCache module which wraps the functionality of any existing kv cache.
    Unlike transformers `Cache` instances, this cache is a `torch.nn.Module` which
    can be hooked to trigger transforms and calibration hooks.

    This module works by being registered as a submodule to attention modules via
    `initialize_hooked_kv_cache`, then adding a hook which replaces `past_key_values`
    kwargs with this module. This module adopts the functionality of the replaced
    cache, preserving caching functionality such as sliding window attention, etc.

    :param config: config of the parent model
    :param attn_module: parent attention module
    """

    def __init__(self, config: PretrainedConfig, attn_module: Module):
        super().__init__()
        self.config = config
        self.attn_module = ref(attn_module)  # avoid circular reference
        # weakref to the original cache being wrapped; attached per forward pass
        # via `add_past_key_values` and cleared at the end of each forward
        self.past_key_values: Optional[ReferenceType[Cache]] = None

    def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        """Mirror the `Cache.update` API by delegating to this module's forward"""
        return self(*args, **kwargs)

    def forward(
        self,
        key_states: Tensor,
        value_states: Tensor,
        *args,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        """
        Optionally quantize key and value states using the parent attention
        module's input activation quantization args, then delegate to the
        wrapped cache's `update` if one was attached for this forward pass.

        :param key_states: key states to quantize and cache
        :param value_states: value states to quantize and cache
        :return: (key, value) states after quantization and cache update
        """
        # quantization: only applies if the parent attention module has a
        # scheme with input activations and quantization is not disabled
        module = self.attn_module()
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled:
            key_states = forward_quantize(module, key_states, "k", quant_args)
            value_states = forward_quantize(module, value_states, "v", quant_args)

        # original cache: delegate to the wrapped cache if still alive,
        # otherwise pass the states through unchanged
        if self.past_key_values is not None:
            ret = self.past_key_values().update(
                key_states, value_states, *args, **kwargs
            )
        else:
            ret = (key_states, value_states)
        # drop the weakref so a stale cache is never reused on a later forward
        self.past_key_values = None

        return ret

    def add_past_key_values(self, past_key_values: Optional[Cache]):
        """
        Attach the original cache (as a weakref) for the upcoming forward pass

        :param past_key_values: original cache instance, or None to detach
        """
        if past_key_values is not None:
            # weakref preserves original cache functionality without
            # keeping the cache alive past its owner
            self.past_key_values = ref(past_key_values)
        else:
            self.past_key_values = None


# ----- initialize ----- #


def _kv_cache_attention_hook(
    module: Module, args: List[Any], kwargs: Dict[str, Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    """
    Forward pre-hook for quantized attention modules. Swaps the original
    `past_key_values` kwarg for the module's `QuantizedKVCache` submodule.

    The original kvcache object is handed to the `QuantizedKVCache` (held as a
    weakref) so that the original cache functionality is preserved without
    extra compute or reference cycles.

    :param module: attention module whose forward is about to run
    :param args: positional arguments to the attention forward
    :param kwargs: keyword arguments to the attention forward
    :return: possibly-modified (args, kwargs)
    """
    # transformers#39956 renamed `past_key_value` to `past_key_values`;
    # inspect the forward signature so both spellings are supported
    forward_params = inspect.signature(module.forward).parameters
    if "past_key_values" in forward_params:
        kwarg_name = "past_key_values"
    else:
        kwarg_name = "past_key_value"

    cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    original_cache: Optional[Cache] = kwargs.get(kwarg_name, None)
    cache.add_past_key_values(original_cache)
    kwargs[kwarg_name] = cache

    return args, kwargs


def initialize_hooked_kv_cache(model: PreTrainedModel, module: Module):
    """
    Initialize a `QuantizedKVCache` instance attached to attention

    :param model: parent model of attention module
    :param module: attention module to initialize with
    """
    # idempotent: skip modules that already carry a hooked kv cache
    if hasattr(module, KV_CACHE_ATTR):
        return

    kv_cache = QuantizedKVCache(model.config, module)
    module.register_module(KV_CACHE_ATTR, kv_cache)
    module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)


# ----- hooks ----- #


def register_key_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes post-rope key states as an argument and
    returns the modified key states or `None`

    :param module: attention module to add hook to
    :param hook: key hook function
    :return: handle which can be used to remove the hook
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _wrapped(cache: QuantizedKVCache, args, kwargs):
        # normalize positional/keyword args so key_states is reachable by name
        signature = inspect.signature(cache.forward)
        bound = signature.bind(*args, **kwargs)
        new_keys = hook(module, bound.arguments["key_states"])
        if new_keys is not None:
            bound.arguments["key_states"] = new_keys

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_wrapped, with_kwargs=True)


def register_value_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes value states as an argument and
    returns the modified value states or `None`

    :param module: attention module to add hook to
    :param hook: value hook function
    :return: handle which can be used to remove the hook
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _wrapped(cache: QuantizedKVCache, args, kwargs):
        # normalize positional/keyword args so value_states is reachable by name
        signature = inspect.signature(cache.forward)
        bound = signature.bind(*args, **kwargs)
        new_values = hook(module, bound.arguments["value_states"])
        if new_values is not None:
            bound.arguments["value_states"] = new_values

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_wrapped, with_kwargs=True)