| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import logging |
| | from typing import Optional, Tuple, Union |
| |
|
| | import torch |
| | from compressed_tensors.modeling import ( |
| | IMPL_ATTR, |
| | KV_CACHE_ATTR, |
| | QuantizedAttentionImpl, |
| | QuantizedKVCache, |
| | ) |
| | from compressed_tensors.quantization import ( |
| | ActivationOrdering, |
| | DynamicType, |
| | QuantizationArgs, |
| | QuantizationMetadata, |
| | QuantizationScheme, |
| | QuantizationStatus, |
| | QuantizationStrategy, |
| | ) |
| | from compressed_tensors.quantization.lifecycle.forward import ( |
| | wrap_module_forward_quantized, |
| | ) |
| | from compressed_tensors.quantization.utils import strategy_cdiv |
| | from compressed_tensors.utils import ( |
| | disable_hf_hook, |
| | get_execution_device, |
| | get_head_dim, |
| | get_num_attn_heads, |
| | get_num_kv_heads, |
| | register_offload_parameter, |
| | ) |
| | from torch.nn import Module, Parameter |
| |
|
| |
|
# Public API of this module, consumed via `from ... import *`.
__all__ = [
    "initialize_module_for_quantization",
    "is_attention_module",
    "initialize_qparams",
    "initialize_attn_qparams",
]

# Module-level logger following the standard `logging.getLogger(__name__)` convention.
_LOGGER = logging.getLogger(__name__)
| |
|
| |
|
def initialize_module_for_quantization(
    module: Module,
    scheme: Optional[QuantizationScheme] = None,
    force_zero_point: bool = True,
):
    """
    Attach the scales, zero points, and observers required by a module's
    target quantization scheme.

    Any previously initialized scales and zero points that no longer apply
    to the scheme are removed first.

    :param module: module to set up for calibration
    :param scheme: scheme to use for quantization. If None, the scheme stored
        on the module under `quantization_scheme` is used; if neither is
        available, the module is skipped
    :param force_zero_point: whether to force initialization of a zero point
        for symmetric quantization
    """
    scheme = scheme or getattr(module, "quantization_scheme", None)
    if scheme is None:
        # no scheme attached and none provided; nothing to initialize
        return

    # drop qparams left over from a previously applied scheme
    QuantizationMetadata.clear_all_qparams(module)

    if is_attention_module(module):
        # attention modules receive q/k/v qparams rather than
        # weight/input/output qparams
        initialize_attn_qparams(module, scheme, force_zero_point)
    else:
        if not isinstance(module, torch.nn.Linear):
            _LOGGER.warning(f"Attempting to quantize module of type {type(module)}")

        if not hasattr(module, "weight"):
            # observed shapes/dtypes are derived from the weight tensor,
            # so a weight-less module cannot be quantized here
            _LOGGER.warning(
                f"module type {type(module)} targeted for quantization but "
                f"has no attribute weight, skipping quantization for {type(module)}"
            )
            return
        weight = module.weight
        assert isinstance(weight, torch.Tensor)

        # (base_name, quantization args, observed shape) for each qparam
        # location: inputs are observed along the weight's last dim, outputs
        # along all but the last dim
        qparam_targets = (
            ("input", scheme.input_activations, weight.shape[-1:]),
            ("weight", scheme.weights, weight.shape),
            ("output", scheme.output_activations, weight.shape[:-1]),
        )
        for base_name, args, observed_shape in qparam_targets:
            if args is not None:
                initialize_qparams(
                    module,
                    base_name,
                    args,
                    observed_shape=observed_shape,
                    observed_dtype=weight.dtype,
                    force_zero_point=force_zero_point,
                )

        # wrap forward pass to apply quantization; temporarily disable the
        # hf hook so wrapping does not interfere with offloading
        with disable_hf_hook(module):
            wrap_module_forward_quantized(module, scheme)

    module.quantization_scheme = scheme
    module.quantization_status = QuantizationStatus.INITIALIZED
| |
|
| |
|
def is_attention_module(module: Module):
    """
    Heuristically detect a self-attention module: its class name contains
    "attention" (case-insensitive) and it exposes at least one of the
    k/v/qkv projection submodules.
    """
    class_name = type(module).__name__.lower()
    if "attention" not in class_name:
        return False
    return any(hasattr(module, attr) for attr in ("k_proj", "v_proj", "qkv_proj"))
| |
|
| |
|
def initialize_qparams(
    module: Module,
    base_name: str,
    quantization_args: QuantizationArgs,
    observed_shape: Tuple[Union[int, None]],
    observed_dtype: torch.dtype,
    force_zero_point: bool = True,
):
    """
    Initialize quantization parameters for a given basename according to the passed
    quantization args. The shape and dtype of the observed weight/activation must also
    be provided.

    Scales will always be initialized. Global scales are initialized depending on args.
    Zero points will be initialized if not symmetric or if `force_zero_point` is True.

    :param module: module to register qparams to
    :param base_name: base name of qparams, for example "input", "weight", "k", "v"
    :param quantization_args: arguments for quantization
    :param observed_shape: last (right-most) known dimensions of the observed weight/act
    :param observed_dtype: dtype of the observed weight/act
    :param force_zero_point: force the zero_point parameter to be initialized
    :raises ValueError: if the strategy is unknown/unsupported, if the observed
        shape has too few dimensions for the strategy, or if a required strategy
        argument (group_size, block_structure) is missing
    """
    strategy = quantization_args.strategy
    dynamic = quantization_args.dynamic
    actorder = quantization_args.actorder
    device = get_execution_device(module)

    # fully dynamic quantization computes qparams on the fly; none are registered
    if dynamic is True:
        return

    # tensor-group quantization additionally tracks an fp32 global scale
    if strategy == QuantizationStrategy.TENSOR_GROUP:
        init_global_scale = Parameter(
            torch.empty(1, dtype=torch.float32, device=device),
            requires_grad=False,
        )
        register_offload_parameter(
            module, f"{base_name}_global_scale", init_global_scale
        )

    # locally-dynamic quantization only needs the global scale registered above
    if dynamic == DynamicType.LOCAL:
        return

    # determine the shape of the scale/zero-point parameters for each strategy
    if strategy == QuantizationStrategy.TENSOR:
        expected_shape = (1,)

    elif strategy == QuantizationStrategy.TOKEN:
        raise ValueError("Cannot perform static token quantization")

    elif strategy == QuantizationStrategy.CHANNEL:
        if len(observed_shape) < 2:
            raise ValueError("Channel quant requires at least 2 observed dimensions")

        expected_shape = (observed_shape[-2], 1)

    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
        # validate explicitly rather than with `assert`, which is stripped
        # under `python -O`
        if quantization_args.group_size is None:
            raise ValueError(f"{strategy} quantization requires group_size")
        if len(observed_shape) < 1:
            raise ValueError("Group quant requires at least 1 observed dimension")

        group_size = quantization_args.group_size
        num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
        expected_shape = (*observed_shape[:-1], num_groups)

        # initialize activation ordering (group index permutation) if requested
        if actorder == ActivationOrdering.GROUP:
            init_g_idx = Parameter(
                torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
                requires_grad=False,
            )
            register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)

    elif strategy == QuantizationStrategy.BLOCK:
        # validate explicitly rather than with `assert` (stripped under `-O`)
        if quantization_args.block_structure is None:
            raise ValueError(f"{strategy} quantization requires block_structure")
        if len(observed_shape) < 2:
            raise ValueError("Block quant requires at least 2 observed dimensions")

        block_structure = quantization_args.block_structure
        num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy)
        num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
        expected_shape = (num_rows, num_cols)

    elif strategy == QuantizationStrategy.ATTN_HEAD:
        # one scale per attention head: (num_heads, 1, 1)
        if len(observed_shape) < 3:
            raise ValueError("Attention quant requires at least 3 observed dimensions")

        expected_shape = (observed_shape[-3], 1, 1)

    else:
        # was `assert False, ...`: under `python -O` the assert would be
        # stripped and `expected_shape` left unbound, producing a confusing
        # NameError downstream instead of a clear error here
        raise ValueError(f"Unknown strategy {strategy}")

    # scales are kept in a float dtype; fall back to fp16 when the observed
    # dtype is not a supported float type
    scale_dtype = observed_dtype
    if scale_dtype not in [
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    ]:
        scale_dtype = torch.float16

    # initialize the scale; values are filled in later during calibration
    init_scale = Parameter(
        torch.empty(expected_shape, dtype=scale_dtype, device=device),
        requires_grad=False,
    )
    register_offload_parameter(module, f"{base_name}_scale", init_scale)

    if force_zero_point or not quantization_args.symmetric:
        init_zero_point = Parameter(
            torch.zeros(
                expected_shape, device=device, dtype=quantization_args.zp_dtype
            ),
            requires_grad=False,
        )
        register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
| |
|
| |
|
def initialize_attn_qparams(
    module: Module, scheme: QuantizationScheme, force_zero_point: bool
):
    """
    Initialize q/k/v quantization parameters on a self-attention module.

    Registers "q" qparams when a QuantizedAttentionImpl is attached and
    "k"/"v" qparams when a QuantizedKVCache is attached; all three reuse
    ``scheme.input_activations`` as their quantization args.

    :param module: attention module to register qparams on
    :param scheme: quantization scheme; must specify input activations only
        (enforced by ``_validate_attention_scheme``)
    :param force_zero_point: force the zero_point parameter to be initialized
    :raises ValueError: if neither impl nor kv_cache attributes are present,
        or if the scheme is not valid for attention
    """

    impl: Optional[QuantizedAttentionImpl] = getattr(module, IMPL_ATTR, None)
    kv_cache: Optional[QuantizedKVCache] = getattr(module, KV_CACHE_ATTR, None)

    if impl is None and kv_cache is None:
        raise ValueError(
            f"Attention module has quantization scheme but no {IMPL_ATTR} "
            f"or {KV_CACHE_ATTR} attributes. Please ensure that these "
            "attributes are initialized using `apply_quantization_config`."
        )

    _validate_attention_scheme(scheme)

    # NOTE(review): this assumes kv_cache is always attached when we reach
    # here; the check above permits an impl-only module (kv_cache is None),
    # in which case `kv_cache.config` would raise AttributeError — confirm
    # whether impl-only configurations can actually occur
    config = kv_cache.config
    num_attn_heads = get_num_attn_heads(config)
    num_kv_heads = get_num_kv_heads(config)
    head_dim = get_head_dim(config)

    # observed shapes are (num_heads, ?, head_dim); the middle dimension is
    # unknown at init time (presumably sequence length), hence None
    q_observed_shape = (num_attn_heads, None, head_dim)
    kv_observed_shape = (num_kv_heads, None, head_dim)
    observed_dtype = next(module.parameters()).dtype

    if impl is not None:
        initialize_qparams(
            module,
            "q",
            scheme.input_activations,
            observed_shape=q_observed_shape,
            observed_dtype=observed_dtype,
            force_zero_point=force_zero_point,
        )

    if kv_cache is not None:
        initialize_qparams(
            module,
            "k",
            scheme.input_activations,
            observed_shape=kv_observed_shape,
            observed_dtype=observed_dtype,
            force_zero_point=force_zero_point,
        )
        initialize_qparams(
            module,
            "v",
            scheme.input_activations,
            observed_shape=kv_observed_shape,
            observed_dtype=observed_dtype,
            force_zero_point=force_zero_point,
        )
| |
|
| |
|
def _validate_attention_scheme(scheme: QuantizationScheme):
    """
    Check that a scheme attached to an attention module is well formed:
    attention quantization supports input activations only.

    :param scheme: quantization scheme targeting an attention module
    :raises ValueError: if weights or output activations are configured, or
        if input activations are missing
    """
    # (failure condition, error message) pairs, checked in order
    checks = (
        (
            scheme.weights is not None,
            "Cannot apply weight quantization to attention. "
            "Instead, target the (q|k|v)_proj submodule layers of attention",
        ),
        (
            scheme.input_activations is None,
            "Cannot apply attention quantization without specifying input activations",
        ),
        (
            scheme.output_activations is not None,
            "Cannot apply output quantization to attention",
        ),
    )
    for failed, message in checks:
        if failed:
            raise ValueError(message)
| |
|