# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple
from weakref import ReferenceType, ref
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import Cache, PretrainedConfig, PreTrainedModel
# public API of this module
__all__ = [
"QuantizedKVCache",
"initialize_hooked_kv_cache",
"register_key_hook",
"register_value_hook",
"KV_CACHE_ATTR",
]
# submodule attribute name under which the QuantizedKVCache is registered
# on each attention module (see `initialize_hooked_kv_cache`)
KV_CACHE_ATTR = "kv_cache"
class QuantizedKVCache(InternalModule):
    """
    Module that wraps an existing transformers ``Cache`` and quantizes key/value
    states on their way into it.

    Unlike a plain transformers ``Cache``, this is a ``torch.nn.Module``, so
    forward pre-hooks (for calibration and transforms) can be attached to it.
    It is registered as a submodule of an attention module via
    `initialize_hooked_kv_cache`; a forward pre-hook on the attention module then
    swaps the ``past_key_values`` kwarg for this instance. The replaced cache is
    kept as a weakref and delegated to, so original caching behavior such as
    sliding-window attention is preserved.

    :param config: config of the parent model
    :param attn_module: parent attention module
    """

    def __init__(self, config: PretrainedConfig, attn_module: Module):
        super().__init__()
        self.config = config
        # weakref: the attention module owns this cache, so a strong reference
        # back would create a reference cycle
        self.attn_module = ref(attn_module)
        self.past_key_values: Optional[ReferenceType[Cache]] = None

    def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        """Mirror the ``Cache.update`` API by delegating to ``forward``."""
        return self(*args, **kwargs)

    def forward(
        self,
        key_states: Tensor,
        value_states: Tensor,
        *args,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        """
        Quantize the incoming key/value states (when the parent attention module
        has an input-activation quantization scheme and quantization is enabled),
        then delegate to the wrapped cache's ``update`` if one was attached.

        :param key_states: post-rope key states
        :param value_states: value states
        :return: (key_states, value_states) as returned by the wrapped cache,
            or the (possibly quantized) inputs when no cache is attached
        """
        parent = self.attn_module()
        scheme_args = getattr_chain(
            parent, "quantization_scheme.input_activations", None
        )
        if scheme_args is not None and getattr(parent, "quantization_enabled", True):
            key_states = forward_quantize(parent, key_states, "k", scheme_args)
            value_states = forward_quantize(parent, value_states, "v", scheme_args)

        # delegate to the original cache when one was attached for this pass
        wrapped = self.past_key_values
        if wrapped is None:
            outputs = (key_states, value_states)
        else:
            outputs = wrapped().update(key_states, value_states, *args, **kwargs)

        # drop the weakref so the original cache is re-attached every forward pass
        self.past_key_values = None
        return outputs

    def add_past_key_values(self, past_key_values: Optional[Cache]):
        """Attach (as a weakref) or clear the original cache for the next pass."""
        self.past_key_values = (
            ref(past_key_values) if past_key_values is not None else None
        )
# ----- initialize ----- #
def _kv_cache_attention_hook(
    module: Module, args: List[Any], kwargs: Dict[str, Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    """
    Forward pre-hook for quantized attention modules.

    Replaces the incoming ``past_key_values`` kwarg with the module's
    `QuantizedKVCache`. The original cache object is handed to the
    `QuantizedKVCache` as a weakref, so its caching behavior (and compute
    savings) are preserved.

    :param module: attention module being called
    :param args: positional args of the attention forward
    :param kwargs: keyword args of the attention forward
    :return: (args, kwargs) with the cache kwarg swapped
    """
    # transformers renamed `past_key_value` -> `past_key_values`; probe the
    # forward signature to support both (transformers#39956)
    params = inspect.signature(module.forward).parameters
    kwarg_name = "past_key_values" if "past_key_values" in params else "past_key_value"

    cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    cache.add_past_key_values(kwargs.get(kwarg_name))
    kwargs[kwarg_name] = cache
    return args, kwargs
def initialize_hooked_kv_cache(model: PreTrainedModel, module: Module):
    """
    Attach a `QuantizedKVCache` submodule to an attention module and register
    the pre-hook which injects it into each forward call.

    :param model: parent model of attention module
    :param module: attention module to initialize with
    """
    if hasattr(module, KV_CACHE_ATTR):
        # already initialized; registering again would duplicate the hook
        return

    module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module))
    module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)
# ----- hooks ----- #
def register_key_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes post-rope key states as an argument and
    returns the modified key states or `None`

    :param module: attention module to add hook to
    :param hook: key hook function
    :return: handle which removes the hook when deleted
    """
    cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _key_pre_hook(cache_module: QuantizedKVCache, args, kwargs):
        # bind so key_states is addressable whether passed by position or name
        bound = inspect.signature(cache_module.forward).bind(*args, **kwargs)
        replacement = hook(module, bound.arguments["key_states"])
        if replacement is not None:
            bound.arguments["key_states"] = replacement
        return bound.args, bound.kwargs

    return cache.register_forward_pre_hook(_key_pre_hook, with_kwargs=True)
def register_value_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes value states as an argument and
    returns the modified value states or `None`

    :param module: attention module to add hook to
    :param hook: value hook function
    :return: handle which removes the hook when deleted
    """
    cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _value_pre_hook(cache_module: QuantizedKVCache, args, kwargs):
        # bind so value_states is addressable whether passed by position or name
        bound = inspect.signature(cache_module.forward).bind(*args, **kwargs)
        replacement = hook(module, bound.arguments["value_states"])
        if replacement is not None:
            bound.arguments["value_states"] = replacement
        return bound.args, bound.kwargs

    return cache.register_forward_pre_hook(_value_pre_hook, with_kwargs=True)
|