|
|
import copy |
|
|
import importlib.metadata |
|
|
import json |
|
|
import os |
|
|
import warnings |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
from packaging import version |
|
|
|
|
|
from transformers.utils import is_hqq_available, is_optimum_quanto_available, logging |
|
|
|
|
|
from transformers.cache_utils import CacheConfig, QuantizedCacheConfig, QuantizedCache |
|
|
|
|
|
if is_hqq_available(): |
|
|
from hqq.core.quantize import Quantizer as HQQQuantizer |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
@dataclass
class SQuatCacheConfig(QuantizedCacheConfig):
    """
    Configuration for the SQuat quantized KV caches.

    Every keyword not listed below (for example `backend`, `nbits`, `residual_length`) is forwarded
    unchanged to `QuantizedCacheConfig`. After that, `cache_implementation` is set to `"squat"`.

    Args:
        quant_group_size: width of the head-dimension groups that SQuat quantizes one after another.
        squat_lambda: weight of the query-subspace term in `I + squat_lambda * Qhat^T Qhat`.
        subspace_dim: number of leading singular directions of the prefill queries kept in `Qhat`.
        shared_svd: when True, the subspace computed for the first batch element is used for the whole batch.
    """

    def __init__(
        self,
        quant_group_size: Optional[int] = 64,
        squat_lambda: Optional[float] = 0.0001,
        subspace_dim: Optional[int] = 5,
        shared_svd: Optional[bool] = True,
        **kwargs,
    ):
        # The base config must be initialised first; the SQuat marker then overrides its implementation tag.
        super().__init__(**kwargs)
        self.cache_implementation = "squat"

        squat_settings = {
            "quant_group_size": quant_group_size,
            "squat_lambda": squat_lambda,
            "subspace_dim": subspace_dim,
            "shared_svd": shared_svd,
        }
        for setting_name, setting_value in squat_settings.items():
            setattr(self, setting_name, setting_value)
|
|
|
|
|
|
|
|
class SQuatCache(QuantizedCache):
    """
    Quantized KV cache that quantizes keys with SQuat and values with the backend quantizer.

    How keys are quantized:
    - Keys are quantized group by group along the head dimension, in groups of `quant_group_size`.
    - After each group is quantized, its quantization error is propagated into the groups not yet quantized.
    - The propagation uses matrices derived from `I + squat_lambda * Qhat^T Qhat`.
    - `Qhat` is a rank-`subspace_dim` SVD approximation of each layer's prefill queries.

    This class contains only the backend-independent logic. It does not implement `_quantize` / `_dequantize`,
    so use a backend subclass (`QuantoSQuatCache` for optimum-quanto, `HQQSQuatCache` for HQQ).

    Requirement: the attention layer must pass `query_states` through `cache_kwargs` on its first `update`
    call for each layer. `attention_mask` is optional.

    Parameters:
        cache_config (`SQuatCacheConfig`):
            A configuration containing all the arguments to be used by the quantizer, including axis, qtype,
            group size and the SQuat settings (`squat_lambda`, `quant_group_size`, `subspace_dim`, `shared_svd`).

    Example:

        ```python
        >>> # Run `pip install optimum-quanto` first if you don't have it yet.
        >>> # NOTE: the SQuat classes live in this module (they are not exported by `transformers`), and the
        >>> # model's attention must forward `query_states` in `cache_kwargs`.
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

        >>> cache_config = SQuatCacheConfig(nbits=4)
        >>> past_key_values = QuantoSQuatCache(cache_config=cache_config)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        ```
    """

    def __init__(self, cache_config: CacheConfig) -> None:
        super().__init__(cache_config)

        # Backend-specific setup (qtype, optimizer, nbits/axis validation) is done by the backend subclasses.
        # Doing the quanto-only checks here rejected valid HQQ settings (e.g. `axis_key=1`, `nbits=8`).
        # It also failed with an unbound `qint4` when optimum-quanto was not installed.

        # Per-layer matrices built once from the prefill queries (see `_get_query_subspace`).
        self.auxiliary_matrices_A = []
        self.auxiliary_matrices_P = []

        # The getattr fallbacks only apply to configs without the SQuat fields (e.g. a plain QuantizedCacheConfig).
        self.squat_lambda = getattr(cache_config, "squat_lambda", 0.0005)
        self.squat_q_group_size = getattr(cache_config, "quant_group_size", 64)
        self.squat_subspace_dim = getattr(cache_config, "subspace_dim", 20)
        self.squat_shared_svd = getattr(cache_config, "shared_svd", True)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Store new key/value states for `layer_idx` and return the keys/values to attend over.

        First call for a layer (prefill):
        - Builds the layer's SQuat matrices from `cache_kwargs["query_states"]` and, if present,
          `cache_kwargs["attention_mask"]`.
        - Quantizes the largest leading multiple of `residual_length` tokens: keys with SQuat, values with
          the backend quantizer.
        - Keeps the remaining tokens in full precision.
        - Returns the inputs unchanged.

        Later calls (decode):
        - Returns dequantized cache + full-precision buffer + new states, concatenated along dim -2.
        - When the buffer length plus one reaches `residual_length`, the buffer and the new states are
          SQuat-quantized and merged into the quantized cache. The value cache is re-quantized from all
          returned values, and the full-precision buffers are emptied.

        Raises:
            ValueError: if layers are skipped, or if `query_states` is missing on a layer's first call.
        """
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if len(self.key_cache) < layer_idx:
            raise ValueError("SQuatCache does not support model usage where layers are skipped. Use DynamicCache.")
        elif len(self.key_cache) == layer_idx:
            if len(self.auxiliary_matrices_A) == layer_idx:
                if cache_kwargs is None or "query_states" not in cache_kwargs:
                    raise ValueError(
                        "SQuatCache requires the layer's `query_states` in `cache_kwargs` to build the query subspace."
                    )
                Ainv_t, P_inv = self._get_query_subspace(
                    key_states, cache_kwargs["query_states"], cache_kwargs.get("attention_mask")
                )
                self.auxiliary_matrices_A.append(Ainv_t)
                self.auxiliary_matrices_P.append(P_inv)
            # Always read the layer's matrices from storage: they may have been built by an earlier call,
            # in which case the locals above were never bound.
            Ainv_t = self.auxiliary_matrices_A[layer_idx]
            P_inv = self.auxiliary_matrices_P[layer_idx]

            remainder = key_states.shape[-2] % self.residual_length
            if remainder != 0:
                if key_states.shape[-2] < self.residual_length:
                    # Too short to quantize anything: keep the whole prompt in full precision.
                    key_states_quant, key_states_full = None, key_states
                    value_states_quant, value_states_full = None, value_states
                else:
                    # Quantize the leading multiple of residual_length; keep the tail in full precision.
                    key_states_quant = key_states[:, :, :-remainder, :].contiguous()
                    key_states_full = key_states[:, :, -remainder:, :].contiguous()
                    value_states_quant = value_states[:, :, :-remainder, :].contiguous()
                    value_states_full = value_states[:, :, -remainder:, :].contiguous()
            else:
                key_states_quant, key_states_full = key_states, None
                value_states_quant, value_states_full = value_states, None

            if key_states_quant is not None:
                self._quantized_key_cache.append(
                    self.squat_quantize_key(key_states_quant, self.squat_q_group_size, Ainv_t, P_inv)
                )
                self._quantized_value_cache.append(self._quantize(value_states_quant, axis=self.axis_value))
            else:
                self._quantized_key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
                self._quantized_value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))

            if key_states_full is not None:
                self.key_cache.append(key_states_full)
                self.value_cache.append(value_states_full)
            else:
                self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
                self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))

            keys_to_return, values_to_return = key_states, value_states

        else:
            # An empty 1-D tensor marks "nothing quantized yet"; torch.cat skips such tensors.
            if len(self._quantized_key_cache[layer_idx]) == 0:
                dequant_key = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
            else:
                dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
            if len(self._quantized_value_cache[layer_idx]) == 0:
                dequant_value = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
            else:
                dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])

            keys_to_return = torch.cat([dequant_key, self.key_cache[layer_idx], key_states], dim=-2)
            values_to_return = torch.cat([dequant_value, self.value_cache[layer_idx], value_states], dim=-2)

            if (
                self.key_cache[layer_idx].dim() == 4
                and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
            ):
                # Flush the full-precision buffer (plus the new states) into the quantized cache.
                keys_to_quantize = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                quantized_key = self.squat_quantize_key(
                    keys_to_quantize,
                    self.squat_q_group_size,
                    self.auxiliary_matrices_A[layer_idx],
                    self.auxiliary_matrices_P[layer_idx],
                )
                self._quantized_key_cache[layer_idx] = self._quantize(
                    torch.cat([dequant_key, self._dequantize(quantized_key)], dim=2), axis=self.axis_key
                )
                self._quantized_value_cache[layer_idx] = self._quantize(
                    values_to_return.contiguous(), axis=self.axis_value
                )
                self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
                self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return keys_to_return, values_to_return

    def _get_query_subspace(self, key_states, query_states, attention_mask=None):
        """
        Build the per-layer SQuat matrices from the prefill queries.

        Steps:
        1. Queries of the heads sharing one KV head are stacked by reshaping to `(kv_heads, -1, head_dim)`.
        2. A truncated SVD keeps `Qhat = diag(S[:k]) @ Vh[:k]`, where
           `k = min(subspace_dim, groups * key_len)`.
        3. `_generate_At_inv` turns `I + squat_lambda * Qhat^T Qhat` into the matrices that
           `squat_quantize_key` uses.

        Args:
            key_states: prefill keys; only `shape[1]` (KV heads) and `shape[2]` (length) are read.
            query_states: prefill queries, indexed as `(batch, heads, seq, head_dim)`.
            attention_mask: optional 4-D mask. Positions whose last mask row equals 0 are treated as valid tokens.
                NOTE(review): this assumes an additive float mask (0 = attend). A boolean mask would be
                interpreted inversely; confirm against the model.

        Returns:
            Tuple `(Ainv_t, P_inv)`: the list from `_generate_At_inv`, and the inverse of its last (full) entry.
        """
        bsz = query_states.shape[0]
        kv_nh = key_states.shape[1]
        head_dim = query_states.shape[3]
        num_key_value_groups = query_states.shape[1] // key_states.shape[1]
        subspace_dim = min(self.squat_subspace_dim, num_key_value_groups * key_states.shape[2])

        if attention_mask is not None:
            # If the mask has exactly one more key column than query rows, keep only the first shape[2] columns.
            if attention_mask.shape[2] == attention_mask.shape[3] - 1:
                attention_mask = attention_mask[:, :, :, : attention_mask.shape[2]]

            last_row_mask = attention_mask[:, :, -1, :]
            valid_tokens = (last_row_mask == 0).squeeze(1)

            per_batch_subspaces = []
            for b in range(bsz):
                batch_valid_query = query_states[b][:, valid_tokens[b], :]
                valid_query_states_matrix = batch_valid_query.reshape(kv_nh, -1, head_dim)
                _, S, Vh = torch.linalg.svd(valid_query_states_matrix.float(), full_matrices=False)
                S_subspace = torch.diag_embed(S[:, :subspace_dim]).to(valid_query_states_matrix.dtype)
                Vh_subspace = Vh[:, :subspace_dim, :].to(valid_query_states_matrix.dtype)
                per_batch_subspaces.append(torch.matmul(S_subspace, Vh_subspace))
                if self.squat_shared_svd:
                    break

            # A sequence with few valid tokens can yield fewer singular directions than the others.
            # Pad with zero rows so the subspaces can be stacked; zero rows leave `Qhat^T Qhat`
            # (the only quantity `_generate_At_inv` uses) unchanged.
            max_rank = max(sub.shape[-2] for sub in per_batch_subspaces)
            query_subspace = torch.stack(
                [torch.nn.functional.pad(sub, (0, 0, 0, max_rank - sub.shape[-2])) for sub in per_batch_subspaces]
            )
        else:
            query_states_matrix = query_states.reshape(bsz, kv_nh, -1, head_dim)
            _, S, Vh = torch.linalg.svd(query_states_matrix.float(), full_matrices=False)
            S_subspace = torch.diag_embed(S[:, :, :subspace_dim]).to(query_states_matrix.dtype)
            Vh_subspace = Vh[:, :, :subspace_dim, :].to(query_states_matrix.dtype)
            query_subspace = torch.matmul(S_subspace, Vh_subspace)

            if self.squat_shared_svd:
                # Keep only the first batch element; its matrices broadcast over the batch later.
                query_subspace = query_subspace[0:1, ...]

        Ainv_t = self._generate_At_inv(self.squat_q_group_size, query_subspace.float(), lamb=self.squat_lambda)
        P_inv = torch.inverse(Ainv_t[-1])
        return Ainv_t, P_inv

    def _generate_At_inv(self, quant_group_size, my_Qhat, lamb=1, tol=1e-7):
        """
        Precompute the matrices used to propagate per-group key quantization error.

        Start from `A = I + lamb * Qhat^T Qhat`, of shape `(bs, kv_nh, head_dim, head_dim)`. Then repeatedly
        eliminate the trailing group with the Schur complement `A_t = M - N^T (O + tol*I)^{-1} N`,
        where `M`, `N`, `O` are the leading, off-diagonal and trailing blocks of the current matrix.

        Parameters:
            - quant_group_size (int): width `g` of each head-dimension group.
            - my_Qhat (torch.Tensor): query subspace of shape `(bs, kv_nh, subspace_dim, head_dim)`.
            - lamb (float): weight of the `Qhat^T Qhat` term.
            - tol (float): diagonal jitter added before inverting each trailing block.

        Returns:
            - List[torch.Tensor]: `T = ceil(head_dim / g)` tensors.
              - Entry `j < T - 1` holds the last `g` columns of the `((j+1)*g, (j+1)*g)` complement.
              - Entry `T - 1` is the full `A`.
        """
        bs, kv_nh, _, head_dim = my_Qhat.shape
        T = (head_dim + quant_group_size - 1) // quant_group_size
        matrices = [None] * T
        device = my_Qhat.device

        identity = torch.eye(head_dim, device=device, dtype=my_Qhat.dtype)
        A_T = identity.expand(bs, kv_nh, head_dim, head_dim) + lamb * torch.matmul(
            my_Qhat.transpose(-1, -2), my_Qhat
        )
        matrices[T - 1] = A_T

        for t in range(T - 1, 0, -1):
            current_dim = t * quant_group_size

            M_t1 = A_T[:, :, :current_dim, :current_dim]
            N_t1 = A_T[:, :, current_dim : current_dim + quant_group_size, :current_dim]
            O_t1 = A_T[:, :, current_dim : current_dim + quant_group_size, current_dim : current_dim + quant_group_size]

            # The trailing block is narrower than `quant_group_size` when head_dim is not a multiple of it,
            # so size the jitter identity from the block itself (broadcast over batch/heads).
            block_eye = torch.eye(O_t1.shape[-1], device=device, dtype=O_t1.dtype)
            O_t1_inv = torch.inverse(O_t1 + tol * block_eye)

            A_t = M_t1 - torch.matmul(N_t1.transpose(-1, -2), torch.matmul(O_t1_inv, N_t1))
            matrices[t - 1] = A_t[:, :, :, -quant_group_size:]
            A_T = A_t

        return matrices

    def squat_quantize_key(self, key_states, quant_group_size, Ainv_t, P_inv):
        """
        Quantize `key_states` group by group along the last dimension with SQuat error propagation.

        For each group except the last:
        1. Quantize and dequantize the group.
        2. Map its error `d = dequant - current` through `Ainv_t[i]` and the matching block of `P_inv`.
        3. Add the result to all later groups before they are quantized.

        Finally, the concatenated dequantized groups are quantized once more with the backend quantizer.

        Args:
            key_states: keys indexed as `(batch, heads, seq, head_dim)`. Not modified.
            quant_group_size: group width along the last dimension.
            Ainv_t, P_inv: the matrices produced by `_get_query_subspace` for this layer.

        Returns:
            The backend's quantized representation of the SQuat-adjusted keys.
        """
        head_dim = key_states.shape[-1]
        T = (head_dim + quant_group_size - 1) // quant_group_size

        # Work on a copy. The loop writes corrections into not-yet-quantized groups, and `key_states` can be
        # (a view of) the tensor that `update` returns to the attention layer during prefill.
        group = key_states.clone()
        key_states_dequant = []
        for i in range(T):
            start, end = i * quant_group_size, (i + 1) * quant_group_size
            quantized_group = self._quantize(group[:, :, :, start:end].contiguous(), axis=self.axis_key)
            dequantized = self._dequantize(quantized_group)

            if i < T - 1:
                d_vec = (dequantized - group[:, :, :, start:end]).float()
                H_t = Ainv_t[i]
                B_t = P_inv[:, :, end:, :end]
                update = torch.matmul(torch.matmul(d_vec, H_t.transpose(-2, -1)), B_t.transpose(-2, -1))
                group[:, :, :, end:] = group[:, :, :, end:] + update

            key_states_dequant.append(dequantized)

        key_states_dequant = torch.cat(key_states_dequant, dim=3)
        return self._quantize(key_states_dequant, axis=self.axis_key)
|
|
|
|
|
|
|
|
class QuantoSQuatCache(SQuatCache):
    """
    SQuat cache whose quantization backend is `optimum-quanto`.

    Accepts `nbits` of 2 or 4, and `axis_key` / `axis_value` of 0 or -1.
    Scales and zero-points come from quanto's `MaxOptimizer`; tensors are packed with
    `optimum.quanto.quantize_weight`.
    """

    def __init__(self, cache_config: CacheConfig) -> None:
        super().__init__(cache_config)

        if is_optimum_quanto_available():
            installed_version = version.parse(importlib.metadata.version("optimum-quanto"))
            if installed_version <= version.parse("0.2.5"):
                raise ImportError(
                    f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {installed_version}."
                )
            from optimum.quanto import MaxOptimizer, qint2, qint4

        if self.nbits not in [2, 4]:
            raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")

        # Same rule for both axes; checked in the order key, then value.
        for axis_name, axis_setting in (("axis_key", self.axis_key), ("axis_value", self.axis_value)):
            if axis_setting not in [0, -1]:
                raise ValueError(
                    f"`{axis_name}` for `quanto` backend has to be one of [`0`, `-1`] but got {axis_setting}"
                )

        self.qtype = qint4 if self.nbits == 4 else qint2
        self.optimizer = MaxOptimizer()

    def _quantize(self, tensor, axis):
        """Quantize `tensor` along `axis` with quanto; returns None when optimum-quanto is unavailable."""
        if not is_optimum_quanto_available():
            return None

        from optimum.quanto import quantize_weight

        scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
        return quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)

    def _dequantize(self, qtensor):
        """Return the dequantized tensor of a quanto quantized tensor."""
        dequantized = qtensor.dequantize()
        return dequantized
|
|
|
|
|
|
|
|
class HQQSQuatCache(SQuatCache):
    """
    SQuat cache whose quantization backend is HQQ (`hqq.core.quantize.Quantizer`).

    Accepts `nbits` in {1, 2, 3, 4, 8}, and `axis_key` / `axis_value` of 0 or 1.
    Quantized tensors are stored as `(qtensor, meta)` tuples.
    """

    def __init__(self, cache_config: CacheConfig) -> None:
        super().__init__(cache_config)

        if self.nbits not in [1, 2, 3, 4, 8]:
            raise ValueError(
                f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
            )

        # Same rule for both axes; checked in the order key, then value.
        for axis_name, axis_setting in (("axis_key", self.axis_key), ("axis_value", self.axis_value)):
            if axis_setting not in [0, 1]:
                raise ValueError(
                    f"`{axis_name}` for `HQQ` backend has to be one of [`0`, `1`] but got {axis_setting}"
                )

        self.quantizer = HQQQuantizer

    def _quantize(self, tensor, axis):
        """Quantize `tensor` with HQQ and return the `(qtensor, meta)` pair."""
        hqq = self.quantizer
        qtensor, meta = hqq.quantize(
            tensor,
            axis=axis,
            device=self.device,
            compute_dtype=self.compute_dtype,
            nbits=self.nbits,
            group_size=self.q_group_size,
        )
        meta["compute_dtype"] = self.compute_dtype
        # HQQ helper; presumably moves the packed tensor to `self.device` (see HQQ docs).
        hqq.cuda(qtensor, meta=meta, device=self.device)
        # Keep the scale/zero tensors on the same device as the packed tensor.
        for meta_key in ("scale", "zero"):
            meta[meta_key] = meta[meta_key].to(qtensor.device)
        return qtensor, meta

    def _dequantize(self, qtensor):
        """Dequantize a `(qtensor, meta)` pair produced by `_quantize`."""
        packed, meta = qtensor
        return self.quantizer.dequantize(packed, meta)
|
|
|
|
|
|
|
|
# Maps a config's `backend` value to the SQuat cache class that `generate` instantiates.
SQUAT_BACKEND_CLASSES_MAPPING = dict(quanto=QuantoSQuatCache, HQQ=HQQSQuatCache)
|
|
|
|
|
def generate(model, generation_config=None, backend="quanto", nbits=2, quant_group_size=64, residual_length=32, squat_lambda=0.001, subspace_dim=20, shared_svd=True, **kwargs):
    """Custom generate function that decodes with a SQuat-quantized KV cache.

    Args:
        model (`PreTrainedModel`):
            The decoder-only model to generate from.
        generation_config (`GenerationConfig`, *optional*):
            Forwarded to `model.generate`. It was previously accepted but silently ignored.
        backend (`str`):
            Quantization backend, `"quanto"` or `"HQQ"`.
        nbits, residual_length:
            Forwarded to `SQuatCacheConfig` (and from there to `QuantizedCacheConfig`).
        quant_group_size, squat_lambda, subspace_dim, shared_svd:
            SQuat settings forwarded to `SQuatCacheConfig`.
        **kwargs:
            Remaining `generate` arguments.
            - `custom_generate` is dropped.
            - `use_cache` is ignored, because a cache is always used.
            - A provided `past_key_values` replaces the SQuat cache that would otherwise be created.

    Returns:
        Whatever `model.generate` returns.

    Raises:
        ImportError: if the package for the selected backend is not installed.
        ValueError: if the model is an encoder-decoder model.
    """
    cache_config = SQuatCacheConfig(
        backend=backend,
        nbits=nbits,
        quant_group_size=quant_group_size,
        residual_length=residual_length,
        squat_lambda=squat_lambda,
        subspace_dim=subspace_dim,
        shared_svd=shared_svd,
    )
    cache_class = SQUAT_BACKEND_CLASSES_MAPPING[cache_config.backend]

    if cache_config.backend == "quanto" and not is_optimum_quanto_available():
        raise ImportError(
            "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
            "Please install it via with `pip install optimum-quanto`"
        )
    elif cache_config.backend == "HQQ" and not is_hqq_available():
        raise ImportError(
            "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
            "Please install it via with `pip install hqq`"
        )

    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # Avoid re-entering this custom generate function from `model.generate`.
    kwargs.pop("custom_generate", None)

    # `use_cache=True` is passed explicitly below. A caller-supplied value would otherwise raise
    # "got multiple values for keyword argument 'use_cache'".
    if kwargs.pop("use_cache", True) is False:
        logger.warning("`use_cache=False` is ignored: SQuat generation always uses its quantized KV cache.")

    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = cache_class(cache_config=cache_config)

    generation_outputs = model.generate(
        **kwargs,
        generation_config=generation_config,
        past_key_values=past_key_values,
        use_cache=True,
    )
    return generation_outputs
|
|
|