| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import inspect |
| | from typing import Any, Callable, Dict, List, Optional, Tuple |
| | from weakref import ReferenceType, ref |
| |
|
| | from compressed_tensors.quantization.lifecycle.forward import forward_quantize |
| | from compressed_tensors.utils import getattr_chain |
| | from compressed_tensors.utils.internal import InternalModule |
| | from torch import Tensor |
| | from torch.nn import Module |
| | from torch.utils.hooks import RemovableHandle |
| | from transformers import Cache, PretrainedConfig, PreTrainedModel |
| |
|
| |
|
__all__ = [
    "QuantizedKVCache",
    "initialize_hooked_kv_cache",
    "register_key_hook",
    "register_value_hook",
    "KV_CACHE_ATTR",
]


# attribute name under which the QuantizedKVCache submodule is registered
# on each attention module (see `initialize_hooked_kv_cache`)
KV_CACHE_ATTR = "kv_cache"
| |
|
| |
|
class QuantizedKVCache(InternalModule):
    """
    QuantizedKVCache module which wraps the functionality of any existing kvcache args.
    Unlike transformers Cache instances, this cache is a `torch.nn.Module` which can be
    hooked to trigger transforms and calibration hooks.

    This module works by being registered as a submodule to attention modules via
    `initialize_hooked_kv_cache`, then adding a hook which replaces `past_key_values`
    kwargs with this module. This module adopts the functionality of the replaced
    cache, preserving caching functionality such as sliding window attention, etc.

    :param config: config of the parent model
    :param attn_module: parent attention module
    """

    def __init__(self, config: PretrainedConfig, attn_module: Module):
        super().__init__()
        self.config = config
        # weakref avoids a reference cycle between the attention module and
        # this submodule registered on it
        self.attn_module = ref(attn_module)
        # weakref to the original cache; installed before each forward by the
        # attention pre-hook, cleared again at the end of `forward`
        self.past_key_values: Optional[ReferenceType[Cache]] = None

    def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        """Mimic the `Cache.update` interface by delegating to `forward` (fires hooks)."""
        return self(*args, **kwargs)

    def forward(
        self,
        key_states: Tensor,
        value_states: Tensor,
        *args,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        """
        Optionally quantize the incoming key/value states, then delegate to the
        wrapped original cache (if one was attached) so its caching behavior is
        preserved.

        :param key_states: key states for this attention call
        :param value_states: value states for this attention call
        :return: (key, value) states, post-quantization and post-cache-update
        """
        attn = self.attn_module()
        quant_args = getattr_chain(attn, "quantization_scheme.input_activations", None)
        if quant_args is not None and getattr(attn, "quantization_enabled", True):
            key_states = forward_quantize(attn, key_states, "k", quant_args)
            value_states = forward_quantize(attn, value_states, "v", quant_args)

        wrapped = self.past_key_values
        ret = (
            (key_states, value_states)
            if wrapped is None
            else wrapped().update(key_states, value_states, *args, **kwargs)
        )
        # drop the weakref after each call; the attention hook re-installs it
        # on the next forward pass
        self.past_key_values = None
        return ret

    def add_past_key_values(self, past_key_values: Optional[Cache]):
        """Attach a weakref to the original cache, or clear it when given None."""
        self.past_key_values = (
            ref(past_key_values) if past_key_values is not None else None
        )
| |
|
| |
|
| | |
| |
|
| |
|
def _kv_cache_attention_hook(
    module: Module, args: List[Any], kwargs: Dict[str, Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    """
    Hook which should be called before each quantized attention forward pass.
    This hook dynamically replaces the `past_key_values` kwarg to the attention
    forward function.

    The original kvcache object is assigned to QuantizedKVCache().past_key_values
    as a weakref to maintain original cache functionality and compute savings

    :param module: attention module being called
    :param args: positional arguments to the attention forward
    :param kwargs: keyword arguments to the attention forward
    :return: (args, kwargs) with the cache kwarg replaced by the quantized cache
    """
    # some attention signatures use the singular "past_key_value" name instead
    forward_params = inspect.signature(module.forward).parameters
    kv_name = (
        "past_key_values" if "past_key_values" in forward_params else "past_key_value"
    )

    quantized_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    quantized_cache.add_past_key_values(kwargs.get(kv_name))
    kwargs[kv_name] = quantized_cache

    return args, kwargs
| |
|
| |
|
def initialize_hooked_kv_cache(model: PreTrainedModel, module: Module):
    """
    Initialize a `QuantizedKVCache` instance attached to attention

    :param model: parent model of attention module
    :param module: attention module to initialize with
    """
    if hasattr(module, KV_CACHE_ATTR):
        # already initialized; keep this function idempotent
        return

    module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module))
    module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)
| |
|
| |
|
| | |
| |
|
| |
|
def register_key_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes post-rope key states as an argument and
    returns the modified key states or `None`

    :param module: attention module to add hook to
    :param hook: key hook function
    :return: handle which can be used to remove the hook
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    # hoisted out of `_hook`: the forward signature never changes between calls,
    # so compute it once at registration time rather than on every forward pass
    signature = inspect.signature(kv_cache.forward)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        # bind to named parameters so the hook works whether `key_states`
        # arrives positionally or as a keyword
        bound = signature.bind(*args, **kwargs)
        value = hook(module, bound.arguments["key_states"])
        if value is not None:
            bound.arguments["key_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
| |
|
| |
|
def register_value_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes value states as an argument and
    returns the modified value states or `None`

    :param module: attention module to add hook to
    :param hook: value hook function
    :return: handle which can be used to remove the hook
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    # hoisted out of `_hook`: the forward signature never changes between calls,
    # so compute it once at registration time rather than on every forward pass
    signature = inspect.signature(kv_cache.forward)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        # bind to named parameters so the hook works whether `value_states`
        # arrives positionally or as a keyword
        bound = signature.bind(*args, **kwargs)
        value = hook(module, bound.arguments["value_states"])
        if value is not None:
            bound.arguments["value_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
| |
|