# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple
from weakref import ReferenceType, ref
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import Cache, PretrainedConfig, PreTrainedModel
# Public API of this module (hook registration + the cache wrapper itself)
__all__ = [
    "QuantizedKVCache",
    "initialize_hooked_kv_cache",
    "register_key_hook",
    "register_value_hook",
    "KV_CACHE_ATTR",
]
# Attribute name under which the QuantizedKVCache submodule is registered
# on attention modules (see `initialize_hooked_kv_cache`)
KV_CACHE_ATTR = "kv_cache"
class QuantizedKVCache(InternalModule):
    """
    QuantizedKVCache module which wraps the functionality of any existing kvcache args.

    Unlike transformers `Cache` instances, this cache is a `torch.nn.Module` which can
    be hooked to trigger transforms and calibration hooks.

    This module works by being registered as a submodule of an attention module via
    `initialize_hooked_kv_cache`; a forward pre-hook then substitutes this module for
    the `past_key_values` kwarg. The replaced cache's functionality (sliding window
    attention, etc.) is preserved by delegating `update` to it.

    :param config: model config of the parent model
    :param attn_module: parent attention module
    """

    def __init__(self, config: PretrainedConfig, attn_module: Module):
        super().__init__()
        self.config = config
        # keep only a weakref to the parent attention module to avoid a
        # parent <-> submodule reference cycle
        self.attn_module = ref(attn_module)
        self.past_key_values: Optional[ReferenceType[Cache]] = None

    def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        # mirror the transformers `Cache.update` entrypoint by delegating to forward,
        # so registered forward hooks fire on every cache update
        return self(*args, **kwargs)

    def forward(
        self,
        key_states: Tensor,
        value_states: Tensor,
        *args,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        # quantize k/v states if the parent attention module has input-activation
        # quantization configured and enabled
        attention = self.attn_module()
        scheme_args = getattr_chain(
            attention, "quantization_scheme.input_activations", None
        )
        if scheme_args is not None and getattr(
            attention, "quantization_enabled", True
        ):
            key_states = forward_quantize(attention, key_states, "k", scheme_args)
            value_states = forward_quantize(attention, value_states, "v", scheme_args)

        # delegate to the original cache when one was stashed by the attention
        # pre-hook; otherwise pass the states through unchanged
        if self.past_key_values is None:
            outputs = (key_states, value_states)
        else:
            outputs = self.past_key_values().update(
                key_states, value_states, *args, **kwargs
            )

        # drop the stale weakref so it cannot leak into the next forward pass
        self.past_key_values = None
        return outputs

    def add_past_key_values(self, past_key_values: Optional[Cache]):
        # store only a weakref so the wrapped cache's lifetime stays owned by
        # the caller (transformers generation loop)
        self.past_key_values = (
            None if past_key_values is None else ref(past_key_values)
        )
# ----- initialize ----- #
def _kv_cache_attention_hook(
    module: Module, args: List[Any], kwargs: Dict[str, Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    """
    Forward pre-hook run before each quantized attention forward pass.

    Replaces the cache kwarg passed to the attention forward with the module's
    `QuantizedKVCache`, stashing the original cache on it (as a weakref) so the
    original caching behavior and compute savings are preserved.
    """
    # transformers#39956 renamed `past_key_value` -> `past_key_values`;
    # support both by inspecting the attention forward's signature
    params = inspect.signature(module.forward).parameters
    kwarg_name = (
        "past_key_values" if "past_key_values" in params else "past_key_value"
    )

    cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    cache.add_past_key_values(kwargs.get(kwarg_name, None))
    kwargs[kwarg_name] = cache
    return args, kwargs
def initialize_hooked_kv_cache(model: PreTrainedModel, module: Module):
    """
    Initialize a `QuantizedKVCache` instance attached to attention.

    Idempotent: if the attention module already has a cache submodule,
    nothing is registered (avoids duplicate pre-hooks).

    :param model: parent model of attention module
    :param module: attention module to initialize with
    """
    if hasattr(module, KV_CACHE_ATTR):
        return

    module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module))
    module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)
# ----- hooks ----- #
def register_key_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes post-rope key states as an argument and
    returns the modified key states or `None`.

    :param module: attention module to add hook to; must already have a
        `QuantizedKVCache` attached via `initialize_hooked_kv_cache`
    :param hook: key hook function
    :return: handle whose `remove()` detaches the hook
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    # perf: `inspect.signature` is expensive and invariant across calls,
    # so compute it once at registration time rather than on every forward
    signature = inspect.signature(kv_cache.forward)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        # bind so key_states is found whether passed positionally or by keyword
        bound = signature.bind(*args, **kwargs)
        value = hook(module, bound.arguments["key_states"])
        if value is not None:
            bound.arguments["key_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
def register_value_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes value states as an argument and
    returns the modified value states or `None`.

    :param module: attention module to add hook to; must already have a
        `QuantizedKVCache` attached via `initialize_hooked_kv_cache`
    :param hook: value hook function
    :return: handle whose `remove()` detaches the hook
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    # perf: `inspect.signature` is expensive and invariant across calls,
    # so compute it once at registration time rather than on every forward
    signature = inspect.signature(kv_cache.forward)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        # bind so value_states is found whether passed positionally or by keyword
        bound = signature.bind(*args, **kwargs)
        value = hook(module, bound.arguments["value_states"])
        if value is not None:
            bound.arguments["value_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)