import copy
import importlib.metadata
import json
import os
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from packaging import version
from transformers.utils import is_hqq_available, is_optimum_quanto_available, logging
from transformers.cache_utils import CacheConfig, QuantizedCacheConfig, QuantizedCache


if is_hqq_available():
    from hqq.core.quantize import Quantizer as HQQQuantizer

logger = logging.get_logger(__name__)


@dataclass
class SQuatCacheConfig(QuantizedCacheConfig):
    """
    Configuration class for SQuat cache settings.

    Parameters:
        quant_group_size (`int`, *optional*, defaults to 64):
            Width (along the head dimension) of the key blocks that are quantized one after another by
            `SQuatCache.squat_quantize_key`.
        squat_lambda (`float`, *optional*, defaults to 0.0001):
            Weight of the query-subspace term when building the auxiliary matrices.
        subspace_dim (`int`, *optional*, defaults to 5):
            Maximum number of singular directions of the prefill queries that are kept.
        shared_svd (`bool`, *optional*, defaults to `True`):
            Whether only the first batch element is used to compute the query subspace.
        kwargs:
            Forwarded unchanged to `QuantizedCacheConfig`.
    """

    def __init__(
        self,
        quant_group_size: Optional[int] = 64,
        squat_lambda: Optional[float] = 0.0001,
        subspace_dim: Optional[int] = 5,
        shared_svd: Optional[bool] = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.cache_implementation = "squat"
        self.quant_group_size = quant_group_size
        self.squat_lambda = squat_lambda
        self.subspace_dim = subspace_dim
        self.shared_svd = shared_svd


class SQuatCache(QuantizedCache):
    """
    Backend-agnostic base class of the `SQuat` quantized KV cache.

    This class implements the SQuat bookkeeping (query subspace, auxiliary matrices, block-wise key
    quantization) on top of `self._quantize` / `self._dequantize`, which are provided by the backend
    subclasses `QuantoSQuatCache` and `HQQSQuatCache`. Instantiate one of those subclasses (or go through
    `generate` / `SQUAT_BACKEND_CLASSES_MAPPING`) rather than this class.

    During prefill, `update` reads `cache_kwargs["query_states"]` (required) and
    `cache_kwargs["attention_mask"]` (optional), so the caller must supply the query states of the layer.

    Parameters:
        cache_config (`SQuatCacheConfig`):
            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and
            group size.

    Example:

        ```python
        >>> cache_config = SQuatCacheConfig(backend="quanto", nbits=4)
        >>> past_key_values = QuantoSQuatCache(cache_config=cache_config)
        ```
    """

    def __init__(self, cache_config: CacheConfig) -> None:
        super().__init__(cache_config)
        # NOTE: backend-specific validation (allowed `nbits` / axes, quantizer objects) lives in the
        # backend subclasses. Doing the quanto checks here would wrongly constrain the HQQ backend and
        # would fail with a NameError when optimum-quanto is not installed.

        # Per-layer auxiliary matrices computed once at prefill time.
        self.auxiliary_matrices_A = []
        self.auxiliary_matrices_P = []
        # The fallbacks below only apply when `cache_config` lacks the corresponding attribute.
        self.squat_lambda = getattr(cache_config, "squat_lambda", 0.0005)
        self.squat_q_group_size = getattr(cache_config, "quant_group_size", 64)
        self.squat_subspace_dim = getattr(cache_config, "subspace_dim", 20)
        self.squat_shared_svd = getattr(cache_config, "shared_svd", True)

    @staticmethod
    def _empty_like(reference: torch.Tensor) -> torch.Tensor:
        """Return a 1-D, zero-length placeholder with the dtype and device of `reference`."""
        return torch.zeros(0, dtype=reference.dtype, device=reference.device)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Store new key/value states for `layer_idx` and return the states to attend over.

        Args:
            key_states (`torch.Tensor`): New key states; indexed as 4-D with the sequence on dim `-2`.
            value_states (`torch.Tensor`): New value states, laid out like `key_states`.
            layer_idx (`int`): Index of the layer the states belong to.
            cache_kwargs (`Dict[str, Any]`, *optional*):
                Must contain `"query_states"` the first time a layer is seen (prefill); may contain
                `"attention_mask"`.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: On prefill, the inputs themselves; on decoding, the
            concatenation of dequantized cache, full-precision residual and the new states.

        Raises:
            ValueError: If layers are skipped, or if `"query_states"` is missing at prefill time.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if len(self.key_cache) < layer_idx:
            raise ValueError("SQuatCache does not support model usage where layers are skipped. Use DynamicCache.")

        if len(self.key_cache) == layer_idx:
            return self._squat_prefill(key_states, value_states, layer_idx, cache_kwargs)
        return self._squat_decode(key_states, value_states, layer_idx)

    def _squat_prefill(self, key_states, value_states, layer_idx, cache_kwargs):
        """Handle the first call for a layer: build auxiliary matrices and fill the caches."""
        if len(self.auxiliary_matrices_A) == layer_idx:
            if not cache_kwargs or "query_states" not in cache_kwargs:
                raise ValueError(
                    "SQuatCache needs `cache_kwargs['query_states']` during prefill to build the query subspace."
                )
            Ainv_t, P_inv = self._get_query_subspace(
                key_states, cache_kwargs["query_states"], cache_kwargs.get("attention_mask")
            )
            self.auxiliary_matrices_A.append(Ainv_t)
            self.auxiliary_matrices_P.append(P_inv)
        else:
            # Auxiliary matrices for this layer already exist: reuse them instead of leaving the
            # names unbound.
            Ainv_t = self.auxiliary_matrices_A[layer_idx]
            P_inv = self.auxiliary_matrices_P[layer_idx]

        # Only a multiple of `residual_length` tokens is quantized; the trailing remainder is kept
        # in full precision.
        seq_len = key_states.shape[-2]
        remainder = seq_len % self.residual_length
        if remainder == 0:
            key_states_quant, key_states_full = key_states, None
            value_states_quant, value_states_full = value_states, None
        elif seq_len < self.residual_length:
            key_states_quant, key_states_full = None, key_states
            value_states_quant, value_states_full = None, value_states
        else:
            key_states_quant = key_states[:, :, :-remainder, :].contiguous()
            key_states_full = key_states[:, :, -remainder:, :].contiguous()
            value_states_quant = value_states[:, :, :-remainder, :].contiguous()
            value_states_full = value_states[:, :, -remainder:, :].contiguous()

        if key_states_quant is not None:
            self._quantized_key_cache.append(
                self.squat_quantize_key(key_states_quant, self.squat_q_group_size, Ainv_t, P_inv)
            )
            self._quantized_value_cache.append(self._quantize(value_states_quant, axis=self.axis_value))
        else:
            self._quantized_key_cache.append(self._empty_like(key_states))
            self._quantized_value_cache.append(self._empty_like(key_states))

        if key_states_full is not None:
            self.key_cache.append(key_states_full)
            self.value_cache.append(value_states_full)
        else:
            self.key_cache.append(self._empty_like(key_states))
            self.value_cache.append(self._empty_like(key_states))

        # Prefill attends over the original, full-precision states.
        return key_states, value_states

    def _squat_decode(self, key_states, value_states, layer_idx):
        """Handle a subsequent call for a layer: return all states and grow/quantize the residual."""
        stored_keys = self._quantized_key_cache[layer_idx]
        stored_values = self._quantized_value_cache[layer_idx]
        # A zero-length placeholder means nothing has been quantized for this layer yet.
        if len(stored_keys) == 0:
            dequant_key = self._empty_like(key_states)
        else:
            dequant_key = self._dequantize(stored_keys)
        if len(stored_values) == 0:
            dequant_value = self._empty_like(key_states)
        else:
            dequant_value = self._dequantize(stored_values)

        keys_to_return = torch.cat([dequant_key, self.key_cache[layer_idx], key_states], dim=-2)
        values_to_return = torch.cat([dequant_value, self.value_cache[layer_idx], value_states], dim=-2)

        residual_keys = self.key_cache[layer_idx]
        if residual_keys.dim() == 4 and residual_keys.shape[-2] + 1 >= self.residual_length:
            # The residual is full: quantize it with SQuat and merge it into the quantized cache.
            keys_to_quantize = torch.cat([residual_keys, key_states], dim=-2)
            quantized_key = self.squat_quantize_key(
                keys_to_quantize,
                self.squat_q_group_size,
                self.auxiliary_matrices_A[layer_idx],
                self.auxiliary_matrices_P[layer_idx],
            )
            self._quantized_key_cache[layer_idx] = self._quantize(
                torch.cat([dequant_key, self._dequantize(quantized_key)], dim=2), axis=self.axis_key
            )
            self._quantized_value_cache[layer_idx] = self._quantize(
                values_to_return.contiguous(), axis=self.axis_value
            )
            self.key_cache[layer_idx] = self._empty_like(key_states)
            self.value_cache[layer_idx] = self._empty_like(key_states)
        else:
            self.key_cache[layer_idx] = torch.cat([residual_keys, key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return keys_to_return, values_to_return

    def _get_query_subspace(self, key_states, query_states, attention_mask=None):
        """
        Compute the auxiliary matrices of a layer from its prefill queries.

        For every KV head, the queries are flattened to `[-1, head_dim]`, decomposed with an SVD, and the
        leading `subspace_dim` rows of `diag(S) @ Vh` are kept as the query subspace.

        Args:
            key_states (`torch.Tensor`): Only its shape is read (`shape[1]` KV heads, `shape[2]` tokens).
            query_states (`torch.Tensor`): Indexed as `[batch, heads, seq, head_dim]`.
            attention_mask (`torch.Tensor`, *optional*):
                Indexed as 4-D; positions whose value in the last row equals `0` are treated as valid.
                NOTE(review): this presumes an additive mask with `0` meaning "attend" - confirm with caller.

        Returns:
            `Tuple[List[torch.Tensor], torch.Tensor]`: The list produced by `_generate_At_inv` and the
            inverse of its last entry.
        """
        bsz = query_states.shape[0]
        kv_nh = key_states.shape[1]
        head_dim = query_states.shape[3]
        num_key_value_groups = query_states.shape[1] // key_states.shape[1]
        subspace_dim = min(self.squat_subspace_dim, num_key_value_groups * key_states.shape[2])

        if attention_mask is not None:
            # Drop the extra trailing column when the mask is one column wider than it is tall.
            if attention_mask.shape[2] == attention_mask.shape[3] - 1:
                attention_mask = attention_mask[:, :, :, : attention_mask.shape[2]]
            # The last mask row, [bs, 1, seq_len], tells which positions are visible to the final token.
            last_row_mask = attention_mask[:, :, -1, :]
            valid_tokens = (last_row_mask == 0).squeeze(1)  # [bs, seq_len]

            query_subspace = []
            for b in range(bsz):
                # Keep only the valid positions of this batch element, then regroup per KV head.
                batch_valid_query = query_states[b][:, valid_tokens[b], :]
                valid_query_states_matrix = batch_valid_query.reshape(kv_nh, -1, head_dim)
                _, S, Vh = torch.linalg.svd(valid_query_states_matrix.float(), full_matrices=False)
                S_subspace = torch.diag_embed(S[:, :subspace_dim]).to(valid_query_states_matrix.dtype)
                Vh_subspace = Vh[:, :subspace_dim, :].to(valid_query_states_matrix.dtype)
                query_subspace.append(torch.matmul(S_subspace, Vh_subspace))
                if self.squat_shared_svd:
                    # A shared subspace only uses the first batch element.
                    break
            query_subspace = torch.stack(query_subspace)
        else:
            query_states_matrix = query_states.reshape(bsz, kv_nh, -1, head_dim)
            # The SVD is run in float32; the original author flagged this cast as possibly suboptimal.
            _, S, Vh = torch.linalg.svd(query_states_matrix.float(), full_matrices=False)
            S_subspace = torch.diag_embed(S[:, :, :subspace_dim]).to(query_states_matrix.dtype)
            Vh_subspace = Vh[:, :, :subspace_dim, :].to(query_states_matrix.dtype)
            # dimension: [bs, kv_nh, subspace_dim, head_dim]
            query_subspace = torch.matmul(S_subspace, Vh_subspace)
            if self.squat_shared_svd:
                query_subspace = query_subspace[0:1, ...]

        # Ainv_t is a list of matrices; its last entry is the full [head_dim, head_dim] matrix.
        Ainv_t = self._generate_At_inv(self.squat_q_group_size, query_subspace.float(), lamb=self.squat_lambda)
        P_inv = torch.inverse(Ainv_t[-1])
        return Ainv_t, P_inv

    def _generate_At_inv(self, quant_group_size, my_Qhat, lamb=1, tol=1e-7):
        """
        Build the list of auxiliary matrices used by `squat_quantize_key`.

        Starting from `A_T = I + lamb * Qhat^T Qhat`, each step removes the last block of
        `quant_group_size` rows/columns through a Schur complement.

        Parameters:
            quant_group_size (`int`): Block width `g` along the head dimension.
            my_Qhat (`torch.Tensor`): Query subspace of shape `[bs, kv_nh, subspace_dim, head_dim]`.
            lamb (`float`): Weight of the `Qhat^T Qhat` term.
            tol (`float`): Multiple of the identity added to a block before inverting it.

        Returns:
            `List[torch.Tensor]`: `T = ceil(head_dim / g)` tensors. The last one is `A_T` with shape
            `[bs, kv_nh, head_dim, head_dim]`; entry `t - 1` (for `t < T`) holds the last `g` columns of
            the `t*g x t*g` matrix `A_t`, i.e. shape `[bs, kv_nh, t*g, g]`.
        """
        bs, kv_nh, subspace_dim, head_dim = my_Qhat.shape
        T = (head_dim + quant_group_size - 1) // quant_group_size
        matrices = [None] * T
        device = my_Qhat.device

        identity = torch.eye(head_dim, device=device)
        A_T = identity.expand(bs, kv_nh, head_dim, head_dim) + lamb * torch.matmul(
            my_Qhat.transpose(-1, -2), my_Qhat
        )
        matrices[T - 1] = A_T

        for t in range(T - 1, 0, -1):
            # Recursive computation of A_{t} from A_{t+1}
            current_dim = t * quant_group_size
            block_end = current_dim + quant_group_size

            M_t1 = A_T[:, :, :current_dim, :current_dim]  # Top-left square matrix
            N_t1 = A_T[:, :, current_dim:block_end, :current_dim]  # Bottom-left matrix
            O_t1 = A_T[:, :, current_dim:block_end, current_dim:block_end]  # Bottom-right square matrix

            # Size the identity from the actual block: the last block is narrower than
            # `quant_group_size` when `head_dim` is not a multiple of it.
            I_mat = torch.eye(O_t1.shape[-1], device=device)
            O_t1_inv = torch.inverse(O_t1 + tol * I_mat)
            A_t = M_t1 - torch.matmul(N_t1.transpose(-1, -2), torch.matmul(O_t1_inv, N_t1))
            matrices[t - 1] = A_t[:, :, :, -quant_group_size:]

            # Update A_T for the next iteration
            A_T = A_t

        return matrices

    def squat_quantize_key(self, key_states, quant_group_size, Ainv_t, P_inv):
        """
        Quantize `key_states` block by block along the head dimension.

        After each block is quantized, its quantization error is propagated (through `Ainv_t` and `P_inv`)
        into the not-yet-quantized blocks. The concatenated dequantized blocks are finally quantized once
        more as a whole and returned.

        Args:
            key_states (`torch.Tensor`): Keys indexed as `[batch, heads, seq, head_dim]`. Not modified.
            quant_group_size (`int`): Block width along the head dimension.
            Ainv_t (`List[torch.Tensor]`): Output of `_generate_At_inv`.
            P_inv (`torch.Tensor`): Inverse of the last entry of `Ainv_t`.

        Returns:
            The quantized keys, in whatever format the backend `_quantize` produces.
        """
        hidden_dim = key_states.shape[-1]
        T = (hidden_dim + quant_group_size - 1) // quant_group_size
        key_states_dequant = []
        # Work on a copy: the error feedback below writes in place, and the caller's tensor is also
        # what prefill hands back to the attention layer as full-precision keys.
        group = key_states.clone()

        for i in range(T):
            start = i * quant_group_size
            stop = (i + 1) * quant_group_size
            block = group[:, :, :, start:stop]
            dequantized = self._dequantize(self._quantize(block.contiguous(), axis=self.axis_key))

            if i < T - 1:
                # Quantization error of this block, pushed into the remaining columns.
                d_vec = (dequantized - block).float()
                H_t = Ainv_t[i]
                B_t = P_inv[:, :, stop:, :stop]
                update = torch.matmul(torch.matmul(d_vec, H_t.transpose(-2, -1)), B_t.transpose(-2, -1))
                group[:, :, :, stop:] = group[:, :, :, stop:] + update

            key_states_dequant.append(dequantized)

        key_states_dequant = torch.cat(key_states_dequant, dim=3)
        return self._quantize(key_states_dequant, axis=self.axis_key)


class QuantoSQuatCache(SQuatCache):
    """
    `SQuatCache` that uses `optimum-quanto` to quantize and dequantize. Supports `nbits` 2 and 4 only.
    """

    def __init__(self, cache_config: CacheConfig) -> None:
        super().__init__(cache_config)

        if not is_optimum_quanto_available():
            raise ImportError(
                "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
                "Please install it via with `pip install optimum-quanto`"
            )
        optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto"))
        if optimum_quanto_version <= version.parse("0.2.5"):
            # The check rejects 0.2.5 itself, so the message says "greater than".
            raise ImportError(
                f"You need optimum-quanto package version to be greater than 0.2.5 to use `QuantoSQuatCache`. Detected version {optimum_quanto_version}."
            )
        from optimum.quanto import MaxOptimizer, qint2, qint4

        if self.nbits not in [2, 4]:
            raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")

        if self.axis_key not in [0, -1]:
            raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")

        if self.axis_value not in [0, -1]:
            raise ValueError(
                f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
            )

        self.qtype = qint4 if self.nbits == 4 else qint2
        self.optimizer = MaxOptimizer()  # hardcode as it's the only one for per-channel quantization

    def _quantize(self, tensor, axis):
        """Quantize `tensor` along `axis` with the scale/zero-point chosen by `self.optimizer`."""
        # `__init__` already guarantees optimum-quanto is installed; import lazily to keep it optional.
        from optimum.quanto import quantize_weight

        scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
        qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
        return qtensor

    def _dequantize(self, qtensor):
        """Return the dequantized tensor of a value produced by `_quantize`."""
        return qtensor.dequantize()


class HQQSQuatCache(SQuatCache):
    """
    `SQuatCache` that uses `HQQ` to quantize and dequantize. Supports `nbits` 1, 2, 3, 4 and 8.
    """

    def __init__(self, cache_config: CacheConfig) -> None:
        super().__init__(cache_config)

        if not is_hqq_available():
            raise ImportError(
                "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                "Please install it via with `pip install hqq`"
            )

        if self.nbits not in [1, 2, 3, 4, 8]:
            raise ValueError(
                f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
            )

        if self.axis_key not in [0, 1]:
            raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")

        if self.axis_value not in [0, 1]:
            raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")

        self.quantizer = HQQQuantizer

    def _quantize(self, tensor, axis):
        """Quantize `tensor` along `axis`; returns the `(qtensor, meta)` pair expected by `_dequantize`."""
        qtensor, meta = self.quantizer.quantize(
            tensor,
            axis=axis,
            device=self.device,
            compute_dtype=self.compute_dtype,
            nbits=self.nbits,
            group_size=self.q_group_size,
        )
        meta["compute_dtype"] = self.compute_dtype
        self.quantizer.cuda(qtensor, meta=meta, device=self.device)  # Move to device and cast to dtype
        meta["scale"] = meta["scale"].to(qtensor.device)
        meta["zero"] = meta["zero"].to(qtensor.device)
        return qtensor, meta

    def _dequantize(self, qtensor):
        """Dequantize a `(qtensor, meta)` pair produced by `_quantize`."""
        quant_tensor, meta = qtensor
        tensor = self.quantizer.dequantize(quant_tensor, meta)
        return tensor


SQUAT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoSQuatCache, "HQQ": HQQSQuatCache}


def generate(
    model,
    generation_config=None,
    backend="quanto",
    nbits=2,
    quant_group_size=64,
    residual_length=32,
    squat_lambda=0.001,
    subspace_dim=20,
    shared_svd=True,
    **kwargs,
):
    """Custom generate function that runs `model.generate` with a SQuat quantized KV cache.

    Args:
        model (`PreTrainedModel`):
            The model to generate from. Must be decoder-only.
        generation_config (*optional*):
            Forwarded to `model.generate` unchanged.
        backend (`str`, defaults to `"quanto"`):
            Key of `SQUAT_BACKEND_CLASSES_MAPPING` (`"quanto"` or `"HQQ"`).
        nbits (`int`, defaults to 2), residual_length (`int`, defaults to 32):
            Forwarded to `SQuatCacheConfig` (and from there to `QuantizedCacheConfig`).
        quant_group_size (`int`), squat_lambda (`float`), subspace_dim (`int`), shared_svd (`bool`):
            SQuat settings, see `SQuatCacheConfig`.
        kwargs:
            Remaining arguments for `model.generate`. A `past_key_values` entry, if given and not `None`,
            is used instead of creating a new cache.

    Returns:
        Whatever `model.generate` returns.

    Raises:
        ImportError: If the package required by `backend` is not installed.
        ValueError: If `model` is an encoder-decoder model.
    """
    # 1. Validate the setup
    # 1.a. Build the cache config and make sure the selected backend is installed.
    cache_config = SQuatCacheConfig(
        backend=backend,
        nbits=nbits,
        quant_group_size=quant_group_size,
        residual_length=residual_length,
        squat_lambda=squat_lambda,
        subspace_dim=subspace_dim,
        shared_svd=shared_svd,
    )
    cache_class = SQUAT_BACKEND_CLASSES_MAPPING[cache_config.backend]

    if cache_config.backend == "quanto" and not is_optimum_quanto_available():
        raise ImportError(
            "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
            "Please install it via with `pip install optimum-quanto`"
        )
    elif cache_config.backend == "HQQ" and not is_hqq_available():
        raise ImportError(
            "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
            "Please install it via with `pip install hqq`"
        )

    # 1.b. The model must be decoder-only
    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
    # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
    kwargs.pop("custom_generate", None)

    # 2. Generate with the SQuat cache
    # 2.a. prepare the cache, if it was not passed.
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = cache_class(cache_config=cache_config)

    # 2.b. generate with the cache. `generation_config` is a named parameter of this function, so it
    # is not part of `kwargs` and has to be forwarded explicitly.
    generation_outputs = model.generate(
        **kwargs, generation_config=generation_config, past_key_values=past_key_values, use_cache=True
    )
    return generation_outputs