# NOTE: stray "Spaces / Running" page-header text (web-extraction artifact) removed.
"""
Role-aware ternary quantization utilities for small language models.

This module contains two layers of functionality:

1. A practical mixed-precision recipe that keeps fragile merge projections
   at full precision and ternarizes the more robust input projections.
2. A paper-oriented planner that turns that recipe into an explicit budgeted
   allocator with per-module sensitivity scores and per-module sparse residual
   budgets.
"""
from __future__ import annotations

import math
from dataclasses import asdict, dataclass, field
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from ternary_quant.pipeline import (
    ActivationCapture,
    QuantizationConfig,
    _get_decoder_layers,
    _get_weight,
    _is_linear_layer,
    _run_model_forward,
    _set_weight,
    _should_quantize,
)
| ROLE_ORDER = ( | |
| "attention_inputs", | |
| "attention_output", | |
| "mlp_inputs", | |
| "mlp_output", | |
| ) | |
| class ModuleIOCapture: | |
| """Hook-based capture for flattened module inputs and outputs.""" | |
| def __init__(self): | |
| self.inputs = [] | |
| self.outputs = [] | |
| self.hook = None | |
| def register(self, module: nn.Module) -> None: | |
| self.hook = module.register_forward_hook(self._hook_fn) | |
| def _hook_fn(self, module, input, output) -> None: | |
| inp = input[0].detach() | |
| if inp.dim() == 3: | |
| inp = inp.reshape(-1, inp.shape[-1]) | |
| out = output[0] if isinstance(output, tuple) else output | |
| out = out.detach() | |
| if out.dim() == 3: | |
| out = out.reshape(-1, out.shape[-1]) | |
| self.inputs.append(inp.cpu()) | |
| self.outputs.append(out.cpu()) | |
| def get_inputs(self) -> Optional[torch.Tensor]: | |
| if not self.inputs: | |
| return None | |
| return torch.cat(self.inputs, dim=0) | |
| def get_outputs(self) -> Optional[torch.Tensor]: | |
| if not self.outputs: | |
| return None | |
| return torch.cat(self.outputs, dim=0) | |
| def remove(self) -> None: | |
| if self.hook is not None: | |
| self.hook.remove() | |
| self.hook = None | |
| self.inputs = [] | |
| self.outputs = [] | |
| class GroupwiseTernaryParameter: | |
| """Grouped asymmetric ternary representation: alpha_g * T + mu_g + sparse.""" | |
| ternary_codes: torch.Tensor | |
| group_alpha: torch.Tensor | |
| group_mu: torch.Tensor | |
| group_size: int | |
| sparse_indices: Optional[torch.Tensor] | |
| sparse_residual: Optional[torch.Tensor] | |
| lr_U: Optional[torch.Tensor] | |
| lr_V: Optional[torch.Tensor] | |
| original_shape: tuple[int, int] | |
| original_dtype: torch.dtype | |
| def dequantize(self) -> torch.Tensor: | |
| out_features, in_features = self.original_shape | |
| alpha = self.group_alpha.float().repeat_interleave(self.group_size, dim=1)[ | |
| :, :in_features | |
| ] | |
| mu = self.group_mu.float().repeat_interleave(self.group_size, dim=1)[ | |
| :, :in_features | |
| ] | |
| weight = alpha * self.ternary_codes.float() + mu | |
| if self.sparse_indices is not None and self.sparse_residual is not None: | |
| flat = weight.reshape(-1) | |
| flat[self.sparse_indices.long()] += self.sparse_residual.float() | |
| if self.lr_U is not None and self.lr_V is not None: | |
| weight = weight + self.lr_U.float() @ self.lr_V.float() | |
| return weight | |
| def num_params(self) -> int: | |
| return self.original_shape[0] * self.original_shape[1] | |
| def sparse_nnz(self) -> int: | |
| if self.sparse_indices is None: | |
| return 0 | |
| return int(self.sparse_indices.numel()) | |
| def effective_bits(self) -> float: | |
| out_features, in_features = self.original_shape | |
| n_groups = self.group_alpha.shape[1] | |
| code_bits = 2 * out_features * in_features | |
| group_param_bits = 16 * 2 * out_features * n_groups | |
| sparse_bits = self.sparse_nnz * (32 + 16) | |
| low_rank_bits = 0 | |
| if self.lr_U is not None and self.lr_V is not None: | |
| rank = self.lr_U.shape[1] | |
| low_rank_bits = 16 * rank * (out_features + in_features) | |
| return (code_bits + group_param_bits + sparse_bits) / ( | |
| out_features * in_features | |
| ) + low_rank_bits / (out_features * in_features) | |
| def compute_groupwise_error( | |
| weight: torch.Tensor, | |
| param: GroupwiseTernaryParameter, | |
| ) -> dict: | |
| """Compute reconstruction metrics for grouped ternary weights.""" | |
| weight_f = weight.float().cpu() | |
| dequant = param.dequantize().float().cpu() | |
| diff = weight_f - dequant | |
| mse = diff.square().mean().item() | |
| rmse = mse**0.5 | |
| rms_weight = weight_f.norm().item() / (weight_f.numel() ** 0.5) + 1e-8 | |
| rel_error = rmse / rms_weight | |
| sparsity = (param.ternary_codes == 0).float().mean().item() | |
| return { | |
| "mse": mse, | |
| "rmse": rmse, | |
| "relative_error": rel_error, | |
| "max_error": diff.abs().max().item(), | |
| "sparsity": sparsity, | |
| "effective_bits": param.effective_bits, | |
| "sparse_nnz": param.sparse_nnz, | |
| } | |
| class GroupwiseAsymmetricTernaryQuantizer: | |
| """Per-group asymmetric ternary fitting with optional sparse residual.""" | |
| def __init__( | |
| self, | |
| group_size: int = 32, | |
| n_iter: int = 10, | |
| use_activation_aware: bool = True, | |
| salient_fraction: float = 0.0, | |
| low_rank_rank: int = 0, | |
| low_rank_fit_mode: str = "weight_svd", | |
| low_rank_ridge: float = 1e-4, | |
| low_rank_max_samples: int = 4096, | |
| importance_threshold_scale: float = 0.0, | |
| ): | |
| self.group_size = group_size | |
| self.n_iter = n_iter | |
| self.use_activation_aware = use_activation_aware | |
| self.salient_fraction = max(0.0, min(salient_fraction, 1.0)) | |
| self.low_rank_rank = max(0, int(low_rank_rank)) | |
| self.low_rank_fit_mode = low_rank_fit_mode | |
| self.low_rank_ridge = max(0.0, float(low_rank_ridge)) | |
| self.low_rank_max_samples = max(0, int(low_rank_max_samples)) | |
| self.importance_threshold_scale = float(importance_threshold_scale) | |
| def quantize( | |
| self, | |
| weight: torch.Tensor, | |
| activations: Optional[torch.Tensor] = None, | |
| outputs: Optional[torch.Tensor] = None, | |
| bias: Optional[torch.Tensor] = None, | |
| ) -> GroupwiseTernaryParameter: | |
| weight_f = weight.float() | |
| work_device = weight_f.device | |
| out_features, in_features = weight_f.shape | |
| group_size = min(self.group_size, in_features) | |
| n_groups = (in_features + group_size - 1) // group_size | |
| activations_f = None | |
| if activations is not None: | |
| activations_f = activations.float().to(work_device) | |
| outputs_f = None | |
| if outputs is not None: | |
| outputs_f = outputs.float().to(work_device) | |
| act_weights = None | |
| act_importance = None | |
| if activations_f is not None and self.use_activation_aware: | |
| act_weights = (activations_f ** 2).mean(dim=0) | |
| act_weights = act_weights / (act_weights.mean() + 1e-8) | |
| if activations_f is not None and self.importance_threshold_scale > 0: | |
| # AWQ-inspired: mean |activation| per input channel → importance scores | |
| act_importance = activations_f.abs().mean(dim=0) | |
| act_importance = act_importance / (act_importance.mean() + 1e-8) | |
| ternary_codes = torch.zeros_like(weight_f, dtype=torch.int8) | |
| group_alpha = torch.zeros( | |
| out_features, | |
| n_groups, | |
| dtype=torch.float32, | |
| device=work_device, | |
| ) | |
| group_mu = torch.zeros( | |
| out_features, | |
| n_groups, | |
| dtype=torch.float32, | |
| device=work_device, | |
| ) | |
| for group_idx in range(n_groups): | |
| start = group_idx * group_size | |
| end = min(start + group_size, in_features) | |
| weight_group = weight_f[:, start:end] | |
| group_weights = ( | |
| act_weights[start:end] if act_weights is not None else None | |
| ) | |
| group_importance = ( | |
| act_importance[start:end] if act_importance is not None else None | |
| ) | |
| alpha, mu, ternary = self._fit_group( | |
| weight_group, group_weights, group_importance, self.importance_threshold_scale | |
| ) | |
| ternary_codes[:, start:end] = ternary.to(torch.int8) | |
| group_alpha[:, group_idx] = alpha.squeeze(1) | |
| group_mu[:, group_idx] = mu.squeeze(1) | |
| sparse_indices = None | |
| sparse_residual = None | |
| lr_U = None | |
| lr_V = None | |
| if self.salient_fraction > 0.0: | |
| base = self._dequantize_from_parts( | |
| ternary_codes, | |
| group_alpha, | |
| group_mu, | |
| group_size, | |
| in_features, | |
| ) | |
| sparse_indices, sparse_residual = self._select_salient_residual( | |
| weight_f, | |
| base, | |
| act_weights, | |
| ) | |
| else: | |
| base = self._dequantize_from_parts( | |
| ternary_codes, | |
| group_alpha, | |
| group_mu, | |
| group_size, | |
| in_features, | |
| ) | |
| if sparse_indices is not None and sparse_residual is not None: | |
| flat = base.reshape(-1) | |
| flat[sparse_indices.long()] += sparse_residual.float() | |
| if self.low_rank_rank > 0: | |
| if ( | |
| self.low_rank_fit_mode == "activation_regression" | |
| and activations_f is not None | |
| and outputs_f is not None | |
| ): | |
| lr_U, lr_V = self._fit_output_aware_low_rank_residual( | |
| activations=activations_f, | |
| outputs=outputs_f, | |
| base_weight=base, | |
| bias=bias, | |
| ) | |
| if lr_U is None or lr_V is None: | |
| lr_U, lr_V = self._fit_low_rank_residual(weight_f - base) | |
| return GroupwiseTernaryParameter( | |
| ternary_codes=ternary_codes.cpu(), | |
| group_alpha=group_alpha.to(torch.float16).cpu(), | |
| group_mu=group_mu.to(torch.float16).cpu(), | |
| group_size=group_size, | |
| sparse_indices=sparse_indices.cpu() if sparse_indices is not None else None, | |
| sparse_residual=sparse_residual.cpu() if sparse_residual is not None else None, | |
| lr_U=lr_U.cpu() if lr_U is not None else None, | |
| lr_V=lr_V.cpu() if lr_V is not None else None, | |
| original_shape=tuple(weight.shape), | |
| original_dtype=weight.dtype, | |
| ) | |
| def _fit_group( | |
| self, | |
| weight_group: torch.Tensor, | |
| act_weights: Optional[torch.Tensor], | |
| act_importance: Optional[torch.Tensor] = None, | |
| importance_threshold_scale: float = 0.0, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| if act_weights is not None: | |
| weights = act_weights.unsqueeze(0) | |
| mu = (weight_group * weights).sum(dim=1, keepdim=True) / ( | |
| act_weights.sum() + 1e-8 | |
| ) | |
| alpha = ((weight_group - mu).abs() * weights).sum(dim=1, keepdim=True) / ( | |
| act_weights.sum() + 1e-8 | |
| ) | |
| else: | |
| weights = torch.ones(1, weight_group.shape[1], device=weight_group.device) | |
| mu = weight_group.mean(dim=1, keepdim=True) | |
| alpha = (weight_group - mu).abs().mean(dim=1, keepdim=True) | |
| # Clamp both ends: min to avoid div-by-zero, max to avoid FP16 overflow | |
| # (inf in plane-0 scales causes inf-inf=NaN in residuals for planes 1+) | |
| alpha = alpha.clamp(min=1e-8, max=65504.0) | |
| ternary = self._round_to_ternary(weight_group, alpha, mu, act_importance, importance_threshold_scale) | |
| for _ in range(self.n_iter): | |
| alpha, mu = self._solve_alpha_mu(weight_group, ternary, weights) | |
| ternary = self._round_to_ternary(weight_group, alpha, mu, act_importance, importance_threshold_scale) | |
| return alpha, mu, ternary | |
| def _round_to_ternary( | |
| weight_group: torch.Tensor, | |
| alpha: torch.Tensor, | |
| mu: torch.Tensor, | |
| act_importance: Optional[torch.Tensor] = None, | |
| importance_threshold_scale: float = 0.0, | |
| ) -> torch.Tensor: | |
| normalized = (weight_group - mu) / alpha | |
| if act_importance is not None and importance_threshold_scale > 0: | |
| # AWQ-inspired: high-importance input channels get lower threshold, | |
| # making them less likely to be zeroed out. Preserves signal where activations | |
| # are large — exactly the insight from AWQ applied to ternary codes. | |
| # thresh shape: [1, group_size], broadcast over output rows. | |
| thresh = 0.5 / ( | |
| 1.0 + act_importance.unsqueeze(0) * importance_threshold_scale | |
| ).clamp(max=4.0) # minimum threshold = 0.5/4.0 = 0.125 | |
| else: | |
| thresh = 0.5 | |
| ternary = torch.zeros_like(normalized) | |
| ternary[normalized > thresh] = 1.0 | |
| ternary[normalized < -thresh] = -1.0 | |
| return ternary | |
| def _solve_alpha_mu( | |
| weight_group: torch.Tensor, | |
| ternary: torch.Tensor, | |
| weights: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| out_features = weight_group.shape[0] | |
| weighted_t = weights * ternary | |
| weighted_t2 = weighted_t * ternary | |
| a11 = weighted_t2.sum(dim=1) | |
| a12 = weighted_t.sum(dim=1) | |
| a22 = weights.sum(dim=1).expand(out_features) | |
| b1 = (weighted_t * weight_group).sum(dim=1) | |
| b2 = (weights * weight_group).sum(dim=1) | |
| det = (a11 * a22 - a12.square()).clamp(min=1e-10) | |
| alpha = ((a22 * b1 - a12 * b2) / det).clamp(min=1e-8, max=65504.0).unsqueeze(1) | |
| mu = ((a11 * b2 - a12 * b1) / det).clamp(min=-65504.0, max=65504.0).unsqueeze(1) | |
| return alpha, mu | |
| def _dequantize_from_parts( | |
| self, | |
| ternary_codes: torch.Tensor, | |
| group_alpha: torch.Tensor, | |
| group_mu: torch.Tensor, | |
| group_size: int, | |
| in_features: int, | |
| ) -> torch.Tensor: | |
| alpha = group_alpha.repeat_interleave(group_size, dim=1)[:, :in_features] | |
| mu = group_mu.repeat_interleave(group_size, dim=1)[:, :in_features] | |
| return alpha * ternary_codes.float() + mu | |
| def _select_salient_residual( | |
| self, | |
| weight: torch.Tensor, | |
| base: torch.Tensor, | |
| act_weights: Optional[torch.Tensor], | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| residual = weight - base | |
| if act_weights is not None: | |
| sensitivity = residual.abs() * act_weights.unsqueeze(0) | |
| else: | |
| sensitivity = residual.abs() | |
| k = max(1, int(self.salient_fraction * sensitivity.numel())) | |
| indices = torch.topk(sensitivity.reshape(-1), k).indices.to(torch.int32) | |
| residual_values = residual.reshape(-1)[indices.long()].to(torch.float16) | |
| return indices, residual_values | |
| def _fit_low_rank_residual( | |
| self, | |
| residual: torch.Tensor, | |
| ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: | |
| return self._factorize_low_rank(residual) | |
| def _fit_output_aware_low_rank_residual( | |
| self, | |
| activations: torch.Tensor, | |
| outputs: torch.Tensor, | |
| base_weight: torch.Tensor, | |
| bias: Optional[torch.Tensor], | |
| ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: | |
| if activations is None or outputs is None: | |
| return None, None | |
| x = activations.float() | |
| y = outputs.float() | |
| if x.dim() != 2 or y.dim() != 2 or x.shape[0] != y.shape[0]: | |
| return None, None | |
| if self.low_rank_max_samples > 0 and x.shape[0] > self.low_rank_max_samples: | |
| indices = torch.linspace( | |
| 0, | |
| x.shape[0] - 1, | |
| steps=self.low_rank_max_samples, | |
| dtype=torch.float32, | |
| ).round().long().unique() | |
| x = x.index_select(0, indices) | |
| y = y.index_select(0, indices) | |
| base_bias = bias.float() if bias is not None else None | |
| base_outputs = F.linear(x, base_weight.float(), base_bias) | |
| residual_outputs = y - base_outputs | |
| if residual_outputs.abs().max().item() < 1e-8: | |
| return None, None | |
| xtx = x.T @ x | |
| rhs = x.T @ residual_outputs | |
| ridge_scale = xtx.diagonal().mean().item() if xtx.numel() else 0.0 | |
| ridge = self.low_rank_ridge * max(ridge_scale, 1e-8) | |
| system = xtx + ridge * torch.eye(xtx.shape[0], dtype=xtx.dtype) | |
| try: | |
| delta_t = torch.linalg.solve(system, rhs) | |
| except RuntimeError: | |
| delta_t = torch.linalg.pinv(system) @ rhs | |
| delta = delta_t.T.contiguous() | |
| return self._factorize_low_rank(delta) | |
| def _factorize_low_rank( | |
| self, | |
| residual: torch.Tensor, | |
| ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: | |
| out_features, in_features = residual.shape | |
| rank = min(self.low_rank_rank, out_features, in_features) | |
| if rank <= 0: | |
| return None, None | |
| try: | |
| if out_features >= in_features: | |
| residual_t = residual.T.contiguous() | |
| if min(residual_t.shape) > rank + 4: | |
| q = min(rank + 8, min(residual_t.shape)) | |
| U, S, V = torch.svd_lowrank(residual_t, q=q, niter=2) | |
| U = U[:, :rank] | |
| S = S[:rank] | |
| V = V[:, :rank] | |
| lr_U = (V * S.unsqueeze(0)).to(torch.float16) | |
| lr_V = U.T.to(torch.float16) | |
| return lr_U, lr_V | |
| U, S, Vh = torch.linalg.svd(residual_t, full_matrices=False) | |
| lr_U = (Vh[:rank, :].T * S[:rank].unsqueeze(0)).to(torch.float16) | |
| lr_V = U[:, :rank].T.to(torch.float16) | |
| return lr_U, lr_V | |
| if min(residual.shape) > rank + 4: | |
| q = min(rank + 8, min(residual.shape)) | |
| U, S, V = torch.svd_lowrank(residual, q=q, niter=2) | |
| U = U[:, :rank] | |
| S = S[:rank] | |
| V = V[:, :rank] | |
| lr_U = (U * S.unsqueeze(0)).to(torch.float16) | |
| lr_V = V.T.to(torch.float16) | |
| return lr_U, lr_V | |
| U, S, Vh = torch.linalg.svd(residual, full_matrices=False) | |
| except RuntimeError: | |
| return None, None | |
| lr_U = (U[:, :rank] * S[:rank].unsqueeze(0)).to(torch.float16) | |
| lr_V = Vh[:rank, :].to(torch.float16) | |
| return lr_U, lr_V | |
| class TrainableLowRankLinear(nn.Module): | |
| """Frozen quantized base plus trainable low-rank residual.""" | |
| def __init__( | |
| self, | |
| quant_param: GroupwiseTernaryParameter, | |
| bias: Optional[torch.Tensor] = None, | |
| ): | |
| super().__init__() | |
| if quant_param.lr_U is None or quant_param.lr_V is None: | |
| raise ValueError("Trainable low-rank wrapper requires initialized lr_U/lr_V.") | |
| base_weight = quant_param.dequantize().float() - ( | |
| quant_param.lr_U.float() @ quant_param.lr_V.float() | |
| ) | |
| self.register_buffer("base_weight", base_weight) | |
| self.U = nn.Parameter(quant_param.lr_U.float().clone()) | |
| self.V = nn.Parameter(quant_param.lr_V.float().clone()) | |
| if bias is not None: | |
| self.register_buffer("bias", bias.float().clone()) | |
| else: | |
| self.bias = None | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| weight = self.base_weight.to(x.device, dtype=x.dtype) | |
| delta = (self.U @ self.V).to(x.device, dtype=x.dtype) | |
| bias = self.bias.to(x.device, dtype=x.dtype) if self.bias is not None else None | |
| return F.linear(x, weight + delta, bias) | |
| class SmallModelQuantizationConfig: | |
| """Configuration for role-aware small-model ternarization.""" | |
| group_size: int = 32 | |
| n_iter: int = 10 | |
| salient_fraction: float = 0.01 | |
| min_salient_fraction: float = 0.0025 | |
| max_salient_fraction: float = 0.01 | |
| adaptive_salient: bool = False | |
| low_rank_rank: int = 0 | |
| adaptive_low_rank: bool = False | |
| low_rank_chunk_rank: int = 16 | |
| low_rank_target_average_bits: Optional[float] = None | |
| low_rank_fit_mode: str = "weight_svd" | |
| low_rank_ridge: float = 1e-4 | |
| low_rank_max_samples: int = 4096 | |
| n_boundary_layers: int = 2 | |
| calibration_batch_size: int = 4 | |
| quantize_attention_output: bool = False | |
| quantize_mlp_output: bool = False | |
| target_average_bits: Optional[float] = None | |
| role_cost_weights: dict[str, float] = field( | |
| default_factory=lambda: { | |
| "attention_inputs": 1.0, | |
| "mlp_inputs": 1.0, | |
| "attention_output": 1.05, | |
| "mlp_output": 1.05, | |
| } | |
| ) | |
| base_config: QuantizationConfig = field(default_factory=QuantizationConfig) | |
| importance_threshold_scale: float = 0.0 | |
| class RoleAwareModulePolicy: | |
| """Planner decision for a single module.""" | |
| name: str | |
| layer_idx: int | |
| role: str | |
| out_features: int | |
| in_features: int | |
| num_params: int | |
| action: str | |
| reason: str | |
| boundary_layer: bool | |
| role_priority: int | |
| activation_rms: float | |
| relative_error: float | |
| sensitivity_score: float | |
| bits_if_quantized: float | |
| salient_fraction: float | |
| low_rank_rank: int | |
| class RoleAwareQuantizationPlan: | |
| """Explicit plan for the role-aware ternary path.""" | |
| method_name: str | |
| target_average_bits: Optional[float] | |
| predicted_average_bits: float | |
| predicted_quantized_average_bits: float | |
| predicted_quantized_fraction: float | |
| total_model_params: int | |
| boundary_layers: list[int] | |
| policies: dict[str, RoleAwareModulePolicy] | |
| class SmallModelQuantizationResult: | |
| """Result of applying a role-aware plan in-place.""" | |
| quantized_params: dict[str, GroupwiseTernaryParameter] | |
| stats: dict[str, dict] | |
| quantized_module_names: list[str] | |
| protected_module_names: list[str] | |
| skipped_layer_indices: list[int] | |
| plan: RoleAwareQuantizationPlan | |
| calibration_tune: Optional[dict] = None | |
| def quantize_small_model_inplace( | |
| model: nn.Module, | |
| calibration_data: torch.Tensor, | |
| config: Optional[SmallModelQuantizationConfig] = None, | |
| plan: Optional[RoleAwareQuantizationPlan] = None, | |
| ) -> SmallModelQuantizationResult: | |
| """Quantize a model in-place with a role-aware small-model plan.""" | |
| if config is None: | |
| config = SmallModelQuantizationConfig() | |
| if plan is None: | |
| plan = build_role_aware_plan(model, calibration_data, config) | |
| decoder_layers, layer_path = _get_decoder_layers(model) | |
| quantized_params: dict[str, GroupwiseTernaryParameter] = {} | |
| stats: dict[str, dict] = {} | |
| quantized_module_names: list[str] = [] | |
| protected_module_names: list[str] = [] | |
| for layer_idx, layer in enumerate(decoder_layers): | |
| layer_prefix = f"{layer_path}.{layer_idx}" | |
| groups, protected = _build_dependency_groups_for_plan( | |
| layer, | |
| layer_prefix, | |
| layer_idx, | |
| plan, | |
| config, | |
| ) | |
| protected_module_names.extend(protected) | |
| for group in groups: | |
| need_outputs = ( | |
| config.low_rank_fit_mode == "activation_regression" | |
| and any(plan.policies[name].low_rank_rank > 0 for name in group) | |
| ) | |
| if need_outputs: | |
| io_captures = _capture_module_io( | |
| model, | |
| calibration_data, | |
| group, | |
| config.calibration_batch_size, | |
| ) | |
| else: | |
| captures = _capture_activations( | |
| model, | |
| calibration_data, | |
| group, | |
| config.calibration_batch_size, | |
| ) | |
| for name, module in group.items(): | |
| policy = plan.policies[name] | |
| quantizer = GroupwiseAsymmetricTernaryQuantizer( | |
| group_size=config.group_size, | |
| n_iter=config.n_iter, | |
| use_activation_aware=config.base_config.use_activation_aware, | |
| salient_fraction=policy.salient_fraction, | |
| low_rank_rank=policy.low_rank_rank, | |
| low_rank_fit_mode=config.low_rank_fit_mode, | |
| low_rank_ridge=config.low_rank_ridge, | |
| low_rank_max_samples=config.low_rank_max_samples, | |
| importance_threshold_scale=config.importance_threshold_scale, | |
| ) | |
| if need_outputs: | |
| activations = io_captures[name]["inputs"] | |
| outputs = io_captures[name]["outputs"] | |
| else: | |
| activations = captures[name] | |
| outputs = None | |
| weight = _get_weight(module) | |
| bias = getattr(module, "bias", None) | |
| param = quantizer.quantize( | |
| weight.detach().float().cpu(), | |
| activations=activations.float().cpu() | |
| if activations is not None | |
| else None, | |
| outputs=outputs.float().cpu() if outputs is not None else None, | |
| bias=bias.detach().float().cpu() if bias is not None else None, | |
| ) | |
| quantized_params[name] = param | |
| stats[name] = compute_groupwise_error(weight, param) | |
| quantized_module_names.append(name) | |
| dequant = param.dequantize().to(weight.device, dtype=weight.dtype) | |
| _set_weight(module, dequant) | |
| skipped_layer_indices = [ | |
| idx | |
| for idx in range(len(decoder_layers)) | |
| if idx < config.n_boundary_layers | |
| or idx >= len(decoder_layers) - config.n_boundary_layers | |
| ] | |
| for name, policy in plan.policies.items(): | |
| if policy.action != "groupwise_ternary": | |
| protected_module_names.append(name) | |
| return SmallModelQuantizationResult( | |
| quantized_params=quantized_params, | |
| stats=stats, | |
| quantized_module_names=quantized_module_names, | |
| protected_module_names=sorted(set(protected_module_names)), | |
| skipped_layer_indices=skipped_layer_indices, | |
| plan=plan, | |
| ) | |
| def build_role_aware_plan( | |
| model: nn.Module, | |
| calibration_data: torch.Tensor, | |
| config: Optional[SmallModelQuantizationConfig] = None, | |
| ) -> RoleAwareQuantizationPlan: | |
| """Build an explicit role-aware allocation plan before quantization.""" | |
| if config is None: | |
| config = SmallModelQuantizationConfig() | |
| decoder_layers, layer_path = _get_decoder_layers(model) | |
| total_model_params = sum(p.numel() for p in model.parameters()) | |
| boundary_layers = [ | |
| idx | |
| for idx in range(len(decoder_layers)) | |
| if idx < config.n_boundary_layers | |
| or idx >= len(decoder_layers) - config.n_boundary_layers | |
| ] | |
| estimation_fraction = ( | |
| config.min_salient_fraction | |
| if config.adaptive_salient | |
| else config.salient_fraction | |
| ) | |
| policies: dict[str, RoleAwareModulePolicy] = {} | |
| candidates: list[RoleAwareModulePolicy] = [] | |
| residual_spectra: dict[str, torch.Tensor] = {} | |
| predicted_bits = 16.0 * total_model_params | |
| for layer_idx, layer in enumerate(decoder_layers): | |
| layer_prefix = f"{layer_path}.{layer_idx}" | |
| modules = _collect_role_modules(layer, layer_prefix, config.base_config) | |
| if not modules: | |
| continue | |
| if config.low_rank_fit_mode == "activation_regression" and config.low_rank_rank > 0: | |
| io_captures = _capture_module_io( | |
| model, | |
| calibration_data, | |
| modules, | |
| config.calibration_batch_size, | |
| ) | |
| captures = {name: payload["inputs"] for name, payload in io_captures.items()} | |
| outputs = {name: payload["outputs"] for name, payload in io_captures.items()} | |
| else: | |
| captures = _capture_activations( | |
| model, | |
| calibration_data, | |
| modules, | |
| config.calibration_batch_size, | |
| ) | |
| outputs = {} | |
| for name, module in modules.items(): | |
| role = _classify_module(name) | |
| if role is None: | |
| continue | |
| weight = _get_weight(module).detach().float().cpu() | |
| activations = captures[name] | |
| activation_rms = _activation_rms(activations) | |
| role_priority = ROLE_ORDER.index(role) | |
| boundary = layer_idx in boundary_layers | |
| if boundary: | |
| policy = RoleAwareModulePolicy( | |
| name=name, | |
| layer_idx=layer_idx, | |
| role=role, | |
| out_features=weight.shape[0], | |
| in_features=weight.shape[1], | |
| num_params=weight.numel(), | |
| action="fp16", | |
| reason="boundary_layer", | |
| boundary_layer=True, | |
| role_priority=role_priority, | |
| activation_rms=activation_rms, | |
| relative_error=0.0, | |
| sensitivity_score=float("inf"), | |
| bits_if_quantized=16.0, | |
| salient_fraction=0.0, | |
| low_rank_rank=0, | |
| ) | |
| policies[name] = policy | |
| continue | |
| estimate_low_rank = 0 if config.adaptive_low_rank else config.low_rank_rank | |
| quantizer = GroupwiseAsymmetricTernaryQuantizer( | |
| group_size=config.group_size, | |
| n_iter=config.n_iter, | |
| use_activation_aware=config.base_config.use_activation_aware, | |
| salient_fraction=estimation_fraction, | |
| low_rank_rank=estimate_low_rank, | |
| low_rank_fit_mode=config.low_rank_fit_mode, | |
| low_rank_ridge=config.low_rank_ridge, | |
| low_rank_max_samples=config.low_rank_max_samples, | |
| importance_threshold_scale=config.importance_threshold_scale, | |
| ) | |
| param = quantizer.quantize( | |
| weight, | |
| activations=activations.float().cpu() | |
| if activations is not None | |
| else None, | |
| outputs=outputs.get(name).float().cpu() | |
| if outputs.get(name) is not None | |
| else None, | |
| bias=module.bias.detach().float().cpu() | |
| if getattr(module, "bias", None) is not None | |
| else None, | |
| ) | |
| stats = compute_groupwise_error(weight, param) | |
| sensitivity = ( | |
| stats["relative_error"] | |
| * max(activation_rms, 1e-6) | |
| * config.role_cost_weights.get(role, 1.0) | |
| ) | |
| if config.adaptive_low_rank and config.low_rank_rank > 0: | |
| residual = weight - param.dequantize() | |
| residual_spectra[name] = _estimate_residual_spectrum( | |
| residual, | |
| max_rank=config.low_rank_rank, | |
| ) | |
| policy = RoleAwareModulePolicy( | |
| name=name, | |
| layer_idx=layer_idx, | |
| role=role, | |
| out_features=weight.shape[0], | |
| in_features=weight.shape[1], | |
| num_params=weight.numel(), | |
| action="fp16", | |
| reason="unallocated", | |
| boundary_layer=False, | |
| role_priority=role_priority, | |
| activation_rms=activation_rms, | |
| relative_error=stats["relative_error"], | |
| sensitivity_score=sensitivity, | |
| bits_if_quantized=param.effective_bits, | |
| salient_fraction=estimation_fraction, | |
| low_rank_rank=0 if config.adaptive_low_rank else config.low_rank_rank, | |
| ) | |
| policies[name] = policy | |
| candidates.append(policy) | |
| if config.target_average_bits is None: | |
| _assign_practical_actions(candidates, config) | |
| else: | |
| predicted_bits = _assign_budgeted_actions( | |
| candidates, | |
| config, | |
| total_model_params, | |
| ) | |
| if config.target_average_bits is None: | |
| predicted_bits = _predict_total_bits(candidates, total_model_params) | |
| _assign_salient_fractions(candidates, config) | |
| if config.adaptive_low_rank and config.low_rank_rank > 0: | |
| _assign_adaptive_low_rank( | |
| candidates, | |
| config, | |
| total_model_params, | |
| residual_spectra, | |
| ) | |
| predicted_bits = _predict_total_bits(candidates, total_model_params) | |
| quantized = [policy for policy in candidates if policy.action == "groupwise_ternary"] | |
| quantized_params = sum(policy.num_params for policy in quantized) | |
| quantized_weighted_bits = 0.0 | |
| if quantized_params > 0: | |
| quantized_weighted_bits = sum( | |
| policy.num_params * policy.bits_if_quantized for policy in quantized | |
| ) / quantized_params | |
| return RoleAwareQuantizationPlan( | |
| method_name="RAST-small", | |
| target_average_bits=config.target_average_bits, | |
| predicted_average_bits=predicted_bits / total_model_params, | |
| predicted_quantized_average_bits=quantized_weighted_bits, | |
| predicted_quantized_fraction=quantized_params / total_model_params | |
| if total_model_params | |
| else 0.0, | |
| total_model_params=total_model_params, | |
| boundary_layers=boundary_layers, | |
| policies=policies, | |
| ) | |
def build_all_non_boundary_plan(
    model: nn.Module,
    calibration_data: torch.Tensor,
    config: Optional[SmallModelQuantizationConfig] = None,
) -> RoleAwareQuantizationPlan:
    """Build an aggressive baseline that ternarizes every non-boundary role module."""
    cfg = SmallModelQuantizationConfig() if config is None else config
    plan = build_role_aware_plan(model, calibration_data, cfg)
    # Every module outside the protected boundary layers is forced to ternary.
    selected = [p for p in plan.policies.values() if not p.boundary_layer]
    for p in selected:
        p.action = "groupwise_ternary"
        p.reason = "all_non_boundary_baseline"
        p.salient_fraction = cfg.salient_fraction
        p.bits_if_quantized = _estimate_effective_bits(
            p,
            cfg.group_size,
            p.salient_fraction,
        )
    total_bits = _predict_total_bits(selected, plan.total_model_params)
    n_quantized = sum(p.num_params for p in selected)
    # Parameter-weighted average bits over the quantized subset only.
    weighted_bits = (
        sum(p.num_params * p.bits_if_quantized for p in selected) / n_quantized
        if n_quantized > 0
        else 0.0
    )
    plan.method_name = "PB-style sparse ternary"
    plan.predicted_average_bits = total_bits / plan.total_model_params
    plan.predicted_quantized_average_bits = weighted_bits
    plan.predicted_quantized_fraction = (
        n_quantized / plan.total_model_params if plan.total_model_params else 0.0
    )
    return plan
def build_sensitivity_only_plan(
    model: nn.Module,
    calibration_data: torch.Tensor,
    config: Optional[SmallModelQuantizationConfig] = None,
) -> RoleAwareQuantizationPlan:
    """Build a matched-bit baseline that ignores role structure during allocation.

    Starts from the role-aware plan but re-allocates the bit budget purely by
    per-module sensitivity: modules are ranked by sensitivity per bit saved and
    greedily ternarized until the predicted average bits fall under
    ``config.target_average_bits``. Role priorities are not used as a ranking
    signal (only layer index / name break ties), which makes this a controlled
    ablation against the role-aware allocator.
    """
    if config is None:
        # Default target chosen for the matched-bit comparison setting.
        config = SmallModelQuantizationConfig(target_average_bits=10.5)
    if config.target_average_bits is None:
        raise ValueError("Sensitivity-only planning requires target_average_bits.")
    plan = build_role_aware_plan(model, calibration_data, config)
    candidates = []
    # Budget expressed in total bits across the whole model.
    target_bits = float(config.target_average_bits) * plan.total_model_params
    for policy in plan.policies.values():
        if policy.boundary_layer:
            # Boundary layers stay protected even in this baseline.
            continue
        # Temporarily mark the policy as quantized so that
        # _estimate_effective_bits prices it under hypothetical ternarization,
        # then reset to fp16; the greedy pass below decides the final action.
        policy.action = "groupwise_ternary"
        policy.salient_fraction = (
            config.min_salient_fraction if config.adaptive_salient else config.salient_fraction
        )
        policy.bits_if_quantized = _estimate_effective_bits(
            policy,
            config.group_size,
            policy.salient_fraction,
        )
        policy.action = "fp16"
        policy.reason = "sensitivity_only_reset"
        candidates.append(policy)
    # Start from an all-fp16 accounting and subtract savings as modules are
    # selected for quantization.
    predicted_bits = 16.0 * plan.total_model_params
    ranked = sorted(
        candidates,
        key=lambda policy: (
            # Primary key: sensitivity per bit saved (cheapest damage first).
            policy.sensitivity_score / max(16.0 - policy.bits_if_quantized, 1e-6),
            policy.sensitivity_score,
            policy.layer_idx,
            policy.name,
        ),
    )
    for policy in ranked:
        if predicted_bits <= target_bits:
            # Budget already met; remaining (more sensitive) modules stay fp16.
            policy.reason = "sensitivity_budget_not_needed"
            continue
        policy.action = "groupwise_ternary"
        policy.reason = "sensitivity_budget_allocation"
        predicted_bits -= (16.0 - policy.bits_if_quantized) * policy.num_params
    # Salient fractions may change per module; re-price the final selection.
    _assign_salient_fractions(candidates, config)
    predicted_bits = _predict_total_bits(candidates, plan.total_model_params)
    quantized = [policy for policy in candidates if policy.action == "groupwise_ternary"]
    quantized_params = sum(policy.num_params for policy in quantized)
    quantized_weighted_bits = 0.0
    if quantized_params > 0:
        # Parameter-weighted average bits over the quantized subset.
        quantized_weighted_bits = sum(
            policy.num_params * policy.bits_if_quantized for policy in quantized
        ) / quantized_params
    plan.method_name = "Sensitivity-only ternary budget"
    plan.predicted_average_bits = predicted_bits / plan.total_model_params
    plan.predicted_quantized_average_bits = quantized_weighted_bits
    plan.predicted_quantized_fraction = (
        quantized_params / plan.total_model_params if plan.total_model_params else 0.0
    )
    return plan
def summarize_small_model_quantization(
    result: SmallModelQuantizationResult,
    model: nn.Module,
) -> dict:
    """Compute summary statistics for the mixed ternary path.

    Args:
        result: Output of the mixed-precision quantization pass; only its
            ``plan``, ``quantized_params``, ``stats`` and module-name lists
            are read.
        model: Unused; retained for backward compatibility with existing
            callers.

    Returns:
        JSON-friendly dict with parameter counts, per-module-averaged
        sparsity / relative error / effective bits, and the model-wide
        effective bits (fp16 baseline with quantized modules swapped in).
    """
    total_model_params = result.plan.total_model_params
    quantized_params = sum(param.num_params for param in result.quantized_params.values())
    avg_sparsity = 0.0
    avg_rel_error = 0.0
    avg_bits = 0.0
    total_sparse_nnz = 0
    if result.stats:
        # Single materialization of the stats; unweighted per-module means.
        stats_values = list(result.stats.values())
        n_stats = len(stats_values)
        avg_sparsity = sum(s["sparsity"] for s in stats_values) / n_stats
        avg_rel_error = sum(s["relative_error"] for s in stats_values) / n_stats
        avg_bits = sum(s["effective_bits"] for s in stats_values) / n_stats
        total_sparse_nnz = sum(int(s.get("sparse_nnz", 0)) for s in stats_values)
    # Start from a fully fp16 model and replace each quantized module's cost
    # with its effective (post-quantization) bits.
    full_model_bits = 16.0 * total_model_params
    for param in result.quantized_params.values():
        full_model_bits -= 16.0 * param.num_params
        full_model_bits += param.effective_bits * param.num_params
    summary = {
        "method_name": result.plan.method_name,
        "quantized_params": quantized_params,
        "total_model_params": total_model_params,
        "quantized_fraction": quantized_params / total_model_params
        if total_model_params
        else 0.0,
        "n_quantized_modules": len(result.quantized_module_names),
        "n_protected_modules": len(result.protected_module_names),
        "skipped_layer_indices": result.skipped_layer_indices,
        "avg_sparsity": avg_sparsity,
        "avg_relative_error": avg_rel_error,
        "avg_effective_bits": avg_bits,
        "full_model_effective_bits": full_model_bits / total_model_params
        if total_model_params
        else 0.0,
        "total_sparse_nnz": total_sparse_nnz,
        "target_average_bits": result.plan.target_average_bits,
        "predicted_average_bits": result.plan.predicted_average_bits,
    }
    # Only surface the tuning stats when residual calibration actually ran.
    if result.calibration_tune is not None:
        summary["calibration_tune"] = result.calibration_tune
    return summary
def plan_to_dict(plan: RoleAwareQuantizationPlan) -> dict:
    """Convert a role-aware plan to a JSON-serializable dict."""
    payload = {
        "method_name": plan.method_name,
        "target_average_bits": plan.target_average_bits,
        "predicted_average_bits": plan.predicted_average_bits,
        "predicted_quantized_average_bits": plan.predicted_quantized_average_bits,
        "predicted_quantized_fraction": plan.predicted_quantized_fraction,
        "total_model_params": plan.total_model_params,
        "boundary_layers": plan.boundary_layers,
    }
    # Policies serialized in deterministic (name-sorted) order.
    payload["policies"] = {
        name: asdict(policy) for name, policy in sorted(plan.policies.items())
    }
    return payload
def config_to_dict(config: SmallModelQuantizationConfig) -> dict:
    """Convert config to a JSON-friendly dict.

    ``dataclasses.asdict`` already recurses into nested dataclass fields, so
    the ``base_config`` dataclass is converted to a plain dict in the same
    call; no separate re-conversion of ``config.base_config`` is needed.
    """
    return asdict(config)
def tune_low_rank_residuals_inplace(
    model: nn.Module,
    result: SmallModelQuantizationResult,
    calibration_data: torch.Tensor,
    n_steps: int = 0,
    lr: float = 5e-5,
    batch_size: int = 2,
    max_seq_len: Optional[int] = None,
    max_grad_norm: float = 1.0,
    eval_interval: int = 5,
    val_fraction: float = 0.25,
    behavior_sequences: Optional[list[torch.Tensor]] = None,
    behavior_weight: float = 0.0,
    calibration_hidden_states: Optional[torch.Tensor] = None,
    behavior_hidden_states: Optional[list[torch.Tensor]] = None,
    calibration_logit_targets: Optional[dict[str, torch.Tensor]] = None,
    behavior_logit_targets: Optional[list[Optional[dict[str, torch.Tensor]]]] = None,
    distill_weight: float = 0.0,
    behavior_hidden_weight: float = 0.0,
    logit_distill_weight: float = 0.0,
    behavior_logit_weight: float = 0.0,
    entropy_distill_weight: float = 0.0,
    behavior_entropy_weight: float = 0.0,
    logit_distill_temperature: float = 1.0,
    seed: int = 42,
) -> dict:
    """Calibrate low-rank residuals with LM loss plus optional teacher-guided distillation.

    Every entry of ``result.quantized_params`` that carries low-rank factors
    (``lr_U``/``lr_V``) is swapped for a ``TrainableLowRankLinear`` wrapper,
    and only those U/V factors are optimized with AdamW; the rest of the model
    is left untouched. The training objective on ``calibration_data`` is the
    LM loss, optionally combined with:

    - ``behavior_weight``: blends in the LM loss on teacher
      ``behavior_sequences`` (one sequence cycled per step);
    - ``distill_weight`` / ``behavior_hidden_weight``: MSE to precomputed
      last-hidden-state targets;
    - ``logit_distill_weight`` / ``behavior_logit_weight``: temperature-scaled
      KL to teacher top-k logits;
    - ``entropy_distill_weight`` / ``behavior_entropy_weight``: one-sided
      penalty when student next-token entropy drops below the teacher's.

    A validation split (``val_fraction`` of the calibration rows, only when at
    least 4 rows are available) is evaluated every ``eval_interval`` steps;
    the U/V snapshot with the best weighted monitor is restored at the end and
    written back into ``result.quantized_params`` as fp16 tensors. Returns a
    stats dict (also stored on ``result.calibration_tune``).
    """
    # Fast exit: nothing requested, record an empty run.
    if n_steps <= 0:
        stats = {
            "steps": 0,
            "n_wrapped_modules": 0,
            "n_trainable_params": 0,
        }
        result.calibration_tune = stats
        return stats
    # ---- Wrap every quantized module that carries low-rank factors. ----
    wrapped: dict[str, TrainableLowRankLinear] = {}
    trainable_params: list[nn.Parameter] = []
    module_device = None
    for name, quant_param in result.quantized_params.items():
        if quant_param.lr_U is None or quant_param.lr_V is None:
            # No low-rank residual attached to this module; skip.
            continue
        parent, target_name = _resolve_parent_module(model, name)
        original = getattr(parent, target_name)
        # Preserve the original bias (fp32 on CPU) so the wrapper reproduces it.
        bias = (
            original.bias.detach().float().cpu()
            if getattr(original, "bias", None) is not None
            else None
        )
        if module_device is None:
            # All wrappers (and later batches) follow the first module's device.
            module_device = _get_weight(original).device
        wrapper = TrainableLowRankLinear(quant_param, bias=bias).to(module_device)
        setattr(parent, target_name, wrapper)
        wrapped[name] = wrapper
        trainable_params.extend([wrapper.U, wrapper.V])
    if not trainable_params:
        # Nothing was wrapped; record an empty run and bail out.
        stats = {
            "steps": 0,
            "n_wrapped_modules": 0,
            "n_trainable_params": 0,
        }
        result.calibration_tune = stats
        return stats
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=0.0)
    losses = []
    data = calibration_data
    if max_seq_len is not None and data.shape[1] > max_seq_len:
        data = data[:, :max_seq_len]
    # ---- Sanitize the loss weights / temperature. ----
    # behavior_weight is a convex blend coefficient, so it is clamped to [0, 1];
    # the auxiliary weights are only clamped to be non-negative.
    behavior_weight = max(0.0, min(float(behavior_weight), 1.0))
    distill_weight = max(0.0, float(distill_weight))
    behavior_hidden_weight = max(0.0, float(behavior_hidden_weight))
    logit_distill_weight = max(0.0, float(logit_distill_weight))
    behavior_logit_weight = max(0.0, float(behavior_logit_weight))
    entropy_distill_weight = max(0.0, float(entropy_distill_weight))
    behavior_entropy_weight = max(0.0, float(behavior_entropy_weight))
    logit_distill_temperature = max(1e-4, float(logit_distill_temperature))
    # ---- Normalize the teacher (behavior) targets onto CPU copies. ----
    # The three lists are kept parallel: for each kept sequence there is one
    # hidden entry and one logit entry (possibly None).
    teacher_sequences = []
    teacher_hidden_sequences = []
    teacher_logit_sequences = []
    if behavior_sequences:
        for idx, seq in enumerate(behavior_sequences):
            current = seq.detach().cpu().clone()
            if current.dim() == 1:
                # Promote a bare token sequence to a batch of one.
                current = current.unsqueeze(0)
            if max_seq_len is not None and current.shape[1] > max_seq_len:
                current = current[:, :max_seq_len]
            if current.numel() > 0:
                teacher_sequences.append(current)
                # NOTE(review): hidden/logit appends are kept under the
                # numel() guard so the three teacher lists stay aligned; the
                # later per-list modulo indexing relies on equal lengths.
                if behavior_hidden_states and idx < len(behavior_hidden_states):
                    hidden = behavior_hidden_states[idx].detach().cpu().clone()
                    if hidden.dim() == 2:
                        hidden = hidden.unsqueeze(0)
                    if max_seq_len is not None and hidden.shape[1] > max_seq_len:
                        hidden = hidden[:, :max_seq_len]
                    teacher_hidden_sequences.append(hidden)
                else:
                    teacher_hidden_sequences.append(None)
                if behavior_logit_targets and idx < len(behavior_logit_targets):
                    logit_target = {
                        "indices": behavior_logit_targets[idx]["indices"].detach().cpu().clone(),
                        "logits": behavior_logit_targets[idx]["logits"].detach().cpu().clone(),
                        "entropy": behavior_logit_targets[idx]
                        .get("entropy")
                        .detach()
                        .cpu()
                        .clone()
                        if behavior_logit_targets[idx].get("entropy") is not None
                        else None,
                    }
                    # Logit targets cover next-token positions, hence len - 1.
                    if max_seq_len is not None and logit_target["indices"].shape[1] > max_seq_len - 1:
                        limit = max(max_seq_len - 1, 0)
                        logit_target["indices"] = logit_target["indices"][:, :limit]
                        logit_target["logits"] = logit_target["logits"][:, :limit]
                        if logit_target["entropy"] is not None:
                            logit_target["entropy"] = logit_target["entropy"][:, :limit]
                    teacher_logit_sequences.append(logit_target)
                else:
                    teacher_logit_sequences.append(None)
    # ---- Normalize the calibration distillation targets. ----
    hidden_targets = None
    if calibration_hidden_states is not None:
        hidden_targets = calibration_hidden_states.detach().cpu().clone()
        if hidden_targets.dim() == 2:
            hidden_targets = hidden_targets.unsqueeze(0)
        if max_seq_len is not None and hidden_targets.shape[1] > max_seq_len:
            hidden_targets = hidden_targets[:, :max_seq_len]
    logit_targets = None
    if calibration_logit_targets is not None:
        logit_targets = {
            "indices": calibration_logit_targets["indices"].detach().cpu().clone(),
            "logits": calibration_logit_targets["logits"].detach().cpu().clone(),
            "entropy": calibration_logit_targets.get("entropy").detach().cpu().clone()
            if calibration_logit_targets.get("entropy") is not None
            else None,
        }
        # Same next-token convention as the behavior logit targets.
        if max_seq_len is not None and logit_targets["indices"].shape[1] > max_seq_len - 1:
            limit = max(max_seq_len - 1, 0)
            logit_targets["indices"] = logit_targets["indices"][:, :limit]
            logit_targets["logits"] = logit_targets["logits"][:, :limit]
            if logit_targets["entropy"] is not None:
                logit_targets["entropy"] = logit_targets["entropy"][:, :limit]
    # ---- Train/validation split (validation only with >= 4 rows). ----
    n_val = int(data.shape[0] * val_fraction)
    if data.shape[0] >= 4:
        n_val = max(1, n_val)
    else:
        n_val = 0
    if n_val > 0 and n_val < data.shape[0]:
        train_data = data[:-n_val]
        val_data = data[-n_val:]
        # Hidden/logit targets are split with the same row boundary.
        train_hidden = hidden_targets[:-n_val] if hidden_targets is not None else None
        val_hidden = hidden_targets[-n_val:] if hidden_targets is not None else None
        train_logit = (
            {
                "indices": logit_targets["indices"][:-n_val],
                "logits": logit_targets["logits"][:-n_val],
                "entropy": logit_targets["entropy"][:-n_val]
                if logit_targets["entropy"] is not None
                else None,
            }
            if logit_targets is not None
            else None
        )
        val_logit = (
            {
                "indices": logit_targets["indices"][-n_val:],
                "logits": logit_targets["logits"][-n_val:],
                "entropy": logit_targets["entropy"][-n_val:]
                if logit_targets["entropy"] is not None
                else None,
            }
            if logit_targets is not None
            else None
        )
    else:
        train_data = data
        val_data = None
        train_hidden = hidden_targets
        val_hidden = None
        train_logit = logit_targets
        val_logit = None
    # Disable the KV cache during training; restored at the end.
    original_use_cache = getattr(model.config, "use_cache", None)
    if original_use_cache is not None:
        model.config.use_cache = False
    torch.manual_seed(int(seed))
    # ---- Best-checkpoint bookkeeping (by weighted validation monitor). ----
    best_state = None
    best_metric = None
    best_val_loss = None
    best_behavior_loss = None
    best_hidden_loss = None
    best_behavior_hidden_loss = None
    best_logit_loss = None
    best_behavior_logit_loss = None
    best_entropy_loss = None
    best_behavior_entropy_loss = None
    behavior_losses = []
    hidden_losses = []
    behavior_hidden_losses = []
    logit_losses = []
    behavior_logit_losses = []
    entropy_losses = []
    behavior_entropy_losses = []
    teacher_index = 0
    def _slice_logit_targets(
        targets: Optional[dict[str, torch.Tensor]],
        indices: torch.Tensor,
    ) -> Optional[dict[str, torch.Tensor]]:
        """Select the rows of a CPU logit-target dict matching a batch index."""
        if targets is None:
            return None
        idx_cpu = indices.detach().cpu()
        return {
            "indices": targets["indices"].index_select(0, idx_cpu),
            "logits": targets["logits"].index_select(0, idx_cpu),
            "entropy": targets["entropy"].index_select(0, idx_cpu)
            if targets.get("entropy") is not None
            else None,
        }
    def _topk_logit_kl(
        student_logits: torch.Tensor,
        targets: Optional[dict[str, torch.Tensor]],
    ) -> Optional[torch.Tensor]:
        """Temperature-scaled KL between student and teacher over the teacher's
        stored top-k vocabulary slots; returns None when no targets apply."""
        if targets is None:
            return None
        target_indices = targets["indices"]
        target_logits = targets["logits"]
        if target_indices.numel() == 0 or target_logits.numel() == 0:
            return None
        # Align on next-token positions shared by student and teacher.
        seq_len = min(student_logits.shape[1] - 1, target_indices.shape[1])
        if seq_len <= 0:
            return None
        gather_indices = target_indices[:, :seq_len].to(student_logits.device).long()
        teacher_topk_logits = (
            target_logits[:, :seq_len].to(student_logits.device).float()
        )
        student_next_logits = student_logits[:, :seq_len, :].float()
        # Gather the student's logits at the teacher's top-k vocab indices.
        student_topk_logits = torch.gather(student_next_logits, -1, gather_indices)
        student_log_probs = F.log_softmax(
            student_topk_logits / logit_distill_temperature,
            dim=-1,
        )
        teacher_probs = F.softmax(
            teacher_topk_logits / logit_distill_temperature,
            dim=-1,
        )
        # T^2 factor keeps gradient magnitudes comparable across temperatures.
        return (
            F.kl_div(
                student_log_probs,
                teacher_probs,
                reduction="none",
            ).sum(dim=-1).mean()
            * (logit_distill_temperature**2)
        )
    def _entropy_floor_loss(
        student_logits: torch.Tensor,
        targets: Optional[dict[str, torch.Tensor]],
    ) -> Optional[torch.Tensor]:
        """One-sided penalty (squared hinge) when the student's normalized
        next-token entropy drops below the teacher's stored entropy."""
        if targets is None or targets.get("entropy") is None:
            return None
        target_entropy = targets["entropy"]
        if target_entropy.numel() == 0:
            return None
        seq_len = min(student_logits.shape[1] - 1, target_entropy.shape[1])
        if seq_len <= 0:
            return None
        student_next_logits = student_logits[:, :seq_len, :].float()
        student_log_probs = F.log_softmax(student_next_logits, dim=-1)
        student_probs = student_log_probs.exp()
        # Normalize entropy by log(vocab size) so it lives in [0, 1].
        normalizer = math.log(max(student_next_logits.shape[-1], 2))
        student_entropy = -(student_probs * student_log_probs).sum(dim=-1) / normalizer
        teacher_entropy = target_entropy[:, :seq_len].to(student_logits.device).float()
        # Only penalize entropy *below* the teacher's (relu keeps the one side).
        return F.relu(teacher_entropy - student_entropy).square().mean()
    def evaluate_loss(split: torch.Tensor) -> float:
        """Mean LM loss over a split (NaN when the split is empty/None)."""
        if split is None or split.numel() == 0:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for start in range(0, split.shape[0], batch_size):
                batch = split[start : start + batch_size]
                values.append(float(model(batch, labels=batch).loss.item()))
        model.train()
        return sum(values) / max(len(values), 1)
    def evaluate_logit_loss(
        split: Optional[torch.Tensor],
        targets: Optional[dict[str, torch.Tensor]],
    ) -> float:
        """Mean top-k logit KL over a split (NaN when unavailable)."""
        if split is None or targets is None or split.numel() == 0:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for start in range(0, split.shape[0], batch_size):
                batch = split[start : start + batch_size]
                batch_targets = {
                    "indices": targets["indices"][start : start + batch.shape[0]],
                    "logits": targets["logits"][start : start + batch.shape[0]],
                }
                logits = model(batch).logits
                loss = _topk_logit_kl(logits, batch_targets)
                if loss is not None:
                    values.append(float(loss.item()))
        model.train()
        return sum(values) / max(len(values), 1) if values else float("nan")
    def evaluate_entropy_loss(
        split: Optional[torch.Tensor],
        targets: Optional[dict[str, torch.Tensor]],
    ) -> float:
        """Mean entropy-floor loss over a split (NaN when unavailable)."""
        if split is None or targets is None or targets.get("entropy") is None or split.numel() == 0:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for start in range(0, split.shape[0], batch_size):
                batch = split[start : start + batch_size]
                batch_targets = {
                    "entropy": targets["entropy"][start : start + batch.shape[0]],
                }
                logits = model(batch).logits
                loss = _entropy_floor_loss(logits, batch_targets)
                if loss is not None:
                    values.append(float(loss.item()))
        model.train()
        return sum(values) / max(len(values), 1) if values else float("nan")
    def evaluate_behavior_loss(sequences: list[torch.Tensor]) -> float:
        """Mean LM loss over the teacher sequences (NaN when there are none)."""
        if not sequences:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for seq in sequences:
                batch = seq.to(module_device)
                values.append(float(model(batch, labels=batch).loss.item()))
        model.train()
        return sum(values) / max(len(values), 1)
    def evaluate_hidden_loss(
        split: Optional[torch.Tensor],
        targets: Optional[torch.Tensor],
    ) -> float:
        """Mean MSE between the student's last hidden state and the targets."""
        if split is None or targets is None or split.numel() == 0:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for start in range(0, split.shape[0], batch_size):
                batch = split[start : start + batch_size]
                batch_targets = targets[start : start + batch.shape[0]].to(module_device)
                outputs = model(batch, output_hidden_states=True)
                student = outputs.hidden_states[-1].float()
                values.append(float(F.mse_loss(student, batch_targets.float()).item()))
        model.train()
        return sum(values) / max(len(values), 1)
    def evaluate_behavior_hidden_loss(
        sequences: list[torch.Tensor],
        targets: list[Optional[torch.Tensor]],
    ) -> float:
        """Mean hidden-state MSE over teacher sequences that have targets."""
        pairs = [
            (seq, target)
            for seq, target in zip(sequences, targets)
            if target is not None
        ]
        if not pairs:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for seq, target in pairs:
                batch = seq.to(module_device)
                outputs = model(batch, output_hidden_states=True)
                student = outputs.hidden_states[-1].float()
                values.append(float(F.mse_loss(student, target.to(module_device).float()).item()))
        model.train()
        return sum(values) / max(len(values), 1)
    def evaluate_behavior_logit_loss(
        sequences: list[torch.Tensor],
        targets: list[Optional[dict[str, torch.Tensor]]],
    ) -> float:
        """Mean top-k logit KL over teacher sequences that have targets."""
        pairs = [
            (seq, target)
            for seq, target in zip(sequences, targets)
            if target is not None
        ]
        if not pairs:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for seq, target in pairs:
                logits = model(seq.to(module_device)).logits
                loss = _topk_logit_kl(logits, target)
                if loss is not None:
                    values.append(float(loss.item()))
        model.train()
        return sum(values) / max(len(values), 1) if values else float("nan")
    def evaluate_behavior_entropy_loss(
        sequences: list[torch.Tensor],
        targets: list[Optional[dict[str, torch.Tensor]]],
    ) -> float:
        """Mean entropy-floor loss over teacher sequences that carry entropy."""
        pairs = [
            (seq, target)
            for seq, target in zip(sequences, targets)
            if target is not None and target.get("entropy") is not None
        ]
        if not pairs:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for seq, target in pairs:
                logits = model(seq.to(module_device)).logits
                loss = _entropy_floor_loss(logits, target)
                if loss is not None:
                    values.append(float(loss.item()))
        model.train()
        return sum(values) / max(len(values), 1) if values else float("nan")
    # ---- Training loop. ----
    model.train()
    # Hidden states are only requested from the calibration forward when a
    # hidden-state distillation term might consume them.
    need_hidden_outputs = (
        distill_weight > 0.0
        or (teacher_sequences and behavior_hidden_weight > 0.0)
    )
    for step in range(n_steps):
        # Random row sample per step (seeded above for reproducibility).
        idx = torch.randint(0, train_data.shape[0], (batch_size,), device=train_data.device)
        batch = train_data.index_select(0, idx)
        optimizer.zero_grad(set_to_none=True)
        outputs = model(
            batch,
            labels=batch,
            output_hidden_states=need_hidden_outputs,
        )
        lm_loss = outputs.loss
        loss = lm_loss
        hidden_loss = None
        logit_loss = None
        entropy_loss = None
        if train_hidden is not None and distill_weight > 0.0:
            target_hidden = train_hidden.index_select(0, idx.cpu()).to(module_device)
            hidden_loss = F.mse_loss(
                outputs.hidden_states[-1].float(),
                target_hidden.float(),
            )
        if train_logit is not None and logit_distill_weight > 0.0:
            batch_targets = _slice_logit_targets(train_logit, idx)
            logit_loss = _topk_logit_kl(outputs.logits, batch_targets)
        if train_logit is not None and entropy_distill_weight > 0.0:
            batch_targets = _slice_logit_targets(train_logit, idx)
            entropy_loss = _entropy_floor_loss(outputs.logits, batch_targets)
        behavior_loss = None
        behavior_hidden_loss = None
        behavior_logit_loss = None
        behavior_entropy_loss = None
        if teacher_sequences and (
            behavior_weight > 0.0
            or behavior_hidden_weight > 0.0
            or behavior_logit_weight > 0.0
            or behavior_entropy_weight > 0.0
        ):
            # Cycle through the teacher sequences round-robin, one per step.
            behavior_batch = teacher_sequences[teacher_index % len(teacher_sequences)].to(
                module_device
            )
            behavior_hidden_target = teacher_hidden_sequences[
                teacher_index % len(teacher_hidden_sequences)
            ]
            behavior_logit_target = teacher_logit_sequences[
                teacher_index % len(teacher_logit_sequences)
            ]
            teacher_index += 1
            behavior_outputs = model(
                behavior_batch,
                labels=behavior_batch,
                output_hidden_states=behavior_hidden_target is not None
                and behavior_hidden_weight > 0.0,
            )
            behavior_loss = behavior_outputs.loss
            if behavior_hidden_target is not None and behavior_hidden_weight > 0.0:
                behavior_hidden_loss = F.mse_loss(
                    behavior_outputs.hidden_states[-1].float(),
                    behavior_hidden_target.to(module_device).float(),
                )
            if behavior_logit_target is not None and behavior_logit_weight > 0.0:
                behavior_logit_loss = _topk_logit_kl(
                    behavior_outputs.logits,
                    behavior_logit_target,
                )
            if behavior_logit_target is not None and behavior_entropy_weight > 0.0:
                behavior_entropy_loss = _entropy_floor_loss(
                    behavior_outputs.logits,
                    behavior_logit_target,
                )
            # behavior_weight is a convex blend against the calibration LM
            # loss; at 1.0 (and no hidden term) the LM loss is dropped.
            if behavior_weight >= 1.0 and hidden_loss is None:
                loss = behavior_loss
            else:
                loss = (1.0 - behavior_weight) * lm_loss + behavior_weight * behavior_loss
        # Auxiliary distillation terms are additive on top of the blend.
        if hidden_loss is not None:
            loss = loss + distill_weight * hidden_loss
        if logit_loss is not None:
            loss = loss + logit_distill_weight * logit_loss
        if entropy_loss is not None:
            loss = loss + entropy_distill_weight * entropy_loss
        if behavior_hidden_loss is not None:
            loss = loss + behavior_hidden_weight * behavior_hidden_loss
        if behavior_logit_loss is not None:
            loss = loss + behavior_logit_weight * behavior_logit_loss
        if behavior_entropy_loss is not None:
            loss = loss + behavior_entropy_weight * behavior_entropy_loss
        loss.backward()
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
        optimizer.step()
        # Track each loss component for the final stats report.
        losses.append(float(loss.item()))
        if behavior_loss is not None:
            behavior_losses.append(float(behavior_loss.item()))
        if hidden_loss is not None:
            hidden_losses.append(float(hidden_loss.item()))
        if behavior_hidden_loss is not None:
            behavior_hidden_losses.append(float(behavior_hidden_loss.item()))
        if logit_loss is not None:
            logit_losses.append(float(logit_loss.item()))
        if behavior_logit_loss is not None:
            behavior_logit_losses.append(float(behavior_logit_loss.item()))
        if entropy_loss is not None:
            entropy_losses.append(float(entropy_loss.item()))
        if behavior_entropy_loss is not None:
            behavior_entropy_losses.append(float(behavior_entropy_loss.item()))
        # ---- Periodic validation + best-snapshot tracking. ----
        if val_data is not None and ((step + 1) % max(eval_interval, 1) == 0 or step == 0):
            val_loss = evaluate_loss(val_data)
            behavior_val_loss = evaluate_behavior_loss(teacher_sequences)
            hidden_val_loss = evaluate_hidden_loss(val_data, val_hidden)
            behavior_hidden_val_loss = evaluate_behavior_hidden_loss(
                teacher_sequences,
                teacher_hidden_sequences,
            )
            logit_val_loss = evaluate_logit_loss(val_data, val_logit)
            entropy_val_loss = evaluate_entropy_loss(val_data, val_logit)
            behavior_logit_val_loss = evaluate_behavior_logit_loss(
                teacher_sequences,
                teacher_logit_sequences,
            )
            behavior_entropy_val_loss = evaluate_behavior_entropy_loss(
                teacher_sequences,
                teacher_logit_sequences,
            )
            # Build the monitor as the same weighted combination used for
            # training; NaN components are skipped (each term either replaces
            # a NaN monitor or adds onto a finite one).
            monitor = val_loss
            if teacher_sequences and not math.isnan(behavior_val_loss):
                if math.isnan(monitor):
                    monitor = behavior_val_loss
                else:
                    monitor = (1.0 - behavior_weight) * monitor + behavior_weight * behavior_val_loss
            if not math.isnan(hidden_val_loss):
                if math.isnan(monitor):
                    monitor = distill_weight * hidden_val_loss
                else:
                    monitor = monitor + distill_weight * hidden_val_loss
            if not math.isnan(behavior_hidden_val_loss):
                if math.isnan(monitor):
                    monitor = behavior_hidden_weight * behavior_hidden_val_loss
                else:
                    monitor = monitor + behavior_hidden_weight * behavior_hidden_val_loss
            if not math.isnan(logit_val_loss):
                if math.isnan(monitor):
                    monitor = logit_distill_weight * logit_val_loss
                else:
                    monitor = monitor + logit_distill_weight * logit_val_loss
            if not math.isnan(entropy_val_loss):
                if math.isnan(monitor):
                    monitor = entropy_distill_weight * entropy_val_loss
                else:
                    monitor = monitor + entropy_distill_weight * entropy_val_loss
            if not math.isnan(behavior_logit_val_loss):
                if math.isnan(monitor):
                    monitor = behavior_logit_weight * behavior_logit_val_loss
                else:
                    monitor = monitor + behavior_logit_weight * behavior_logit_val_loss
            if not math.isnan(behavior_entropy_val_loss):
                if math.isnan(monitor):
                    monitor = behavior_entropy_weight * behavior_entropy_val_loss
                else:
                    monitor = monitor + behavior_entropy_weight * behavior_entropy_val_loss
            if best_metric is None or monitor < best_metric:
                # New best: record each available component and snapshot U/V
                # (CPU clones, restored after the loop).
                best_metric = monitor
                best_val_loss = val_loss
                if teacher_sequences and not math.isnan(behavior_val_loss):
                    best_behavior_loss = behavior_val_loss
                if not math.isnan(hidden_val_loss):
                    best_hidden_loss = hidden_val_loss
                if not math.isnan(behavior_hidden_val_loss):
                    best_behavior_hidden_loss = behavior_hidden_val_loss
                if not math.isnan(logit_val_loss):
                    best_logit_loss = logit_val_loss
                if not math.isnan(behavior_logit_val_loss):
                    best_behavior_logit_loss = behavior_logit_val_loss
                if not math.isnan(entropy_val_loss):
                    best_entropy_loss = entropy_val_loss
                if not math.isnan(behavior_entropy_val_loss):
                    best_behavior_entropy_loss = behavior_entropy_val_loss
                best_state = {
                    name: (
                        wrapper.U.detach().cpu().clone(),
                        wrapper.V.detach().cpu().clone(),
                    )
                    for name, wrapper in wrapped.items()
                }
    # ---- Restore eval mode, cache setting, and the best snapshot. ----
    model.eval()
    if original_use_cache is not None:
        model.config.use_cache = original_use_cache
    if best_state is not None:
        for name, wrapper in wrapped.items():
            best_u, best_v = best_state[name]
            wrapper.U.data.copy_(best_u.to(wrapper.U.device))
            wrapper.V.data.copy_(best_v.to(wrapper.V.device))
    # Persist the tuned factors back into the result as fp16 CPU tensors.
    for name, wrapper in wrapped.items():
        result.quantized_params[name].lr_U = wrapper.U.detach().cpu().to(torch.float16)
        result.quantized_params[name].lr_V = wrapper.V.detach().cpu().to(torch.float16)
    stats = {
        "steps": n_steps,
        "learning_rate": lr,
        "batch_size": batch_size,
        "max_seq_len": max_seq_len,
        "n_wrapped_modules": len(wrapped),
        "n_trainable_params": int(sum(p.numel() for p in trainable_params)),
        "initial_loss": losses[0],
        "final_loss": losses[-1],
        "best_loss": min(losses),
        "best_val_loss": best_val_loss,
        "best_monitor": best_metric,
        "behavior_weight": behavior_weight,
        "distill_weight": distill_weight,
        "behavior_hidden_weight": behavior_hidden_weight,
        "logit_distill_weight": logit_distill_weight,
        "behavior_logit_weight": behavior_logit_weight,
        "entropy_distill_weight": entropy_distill_weight,
        "behavior_entropy_weight": behavior_entropy_weight,
        "logit_distill_temperature": logit_distill_temperature,
        "n_behavior_sequences": len(teacher_sequences),
        "initial_behavior_loss": behavior_losses[0] if behavior_losses else None,
        "final_behavior_loss": behavior_losses[-1] if behavior_losses else None,
        "best_behavior_loss": best_behavior_loss,
        "initial_hidden_loss": hidden_losses[0] if hidden_losses else None,
        "final_hidden_loss": hidden_losses[-1] if hidden_losses else None,
        "best_hidden_loss": best_hidden_loss,
        "initial_logit_loss": logit_losses[0] if logit_losses else None,
        "final_logit_loss": logit_losses[-1] if logit_losses else None,
        "best_logit_loss": best_logit_loss,
        "initial_entropy_loss": entropy_losses[0] if entropy_losses else None,
        "final_entropy_loss": entropy_losses[-1] if entropy_losses else None,
        "best_entropy_loss": best_entropy_loss,
        "initial_behavior_hidden_loss": behavior_hidden_losses[0]
        if behavior_hidden_losses
        else None,
        "final_behavior_hidden_loss": behavior_hidden_losses[-1]
        if behavior_hidden_losses
        else None,
        "best_behavior_hidden_loss": best_behavior_hidden_loss,
        "initial_behavior_logit_loss": behavior_logit_losses[0]
        if behavior_logit_losses
        else None,
        "final_behavior_logit_loss": behavior_logit_losses[-1]
        if behavior_logit_losses
        else None,
        "best_behavior_logit_loss": best_behavior_logit_loss,
        "initial_behavior_entropy_loss": behavior_entropy_losses[0]
        if behavior_entropy_losses
        else None,
        "final_behavior_entropy_loss": behavior_entropy_losses[-1]
        if behavior_entropy_losses
        else None,
        "best_behavior_entropy_loss": best_behavior_entropy_loss,
    }
    result.calibration_tune = stats
    return stats
| def _assign_practical_actions( | |
| candidates: list[RoleAwareModulePolicy], | |
| config: SmallModelQuantizationConfig, | |
| ) -> None: | |
| for policy in candidates: | |
| quantize = policy.role in {"attention_inputs", "mlp_inputs"} | |
| if policy.role == "attention_output" and config.quantize_attention_output: | |
| quantize = True | |
| if policy.role == "mlp_output" and config.quantize_mlp_output: | |
| quantize = True | |
| if quantize: | |
| policy.action = "groupwise_ternary" | |
| policy.reason = "practical_role_policy" | |
| else: | |
| policy.action = "fp16" | |
| policy.reason = "protected_role_policy" | |
| def _assign_budgeted_actions( | |
| candidates: list[RoleAwareModulePolicy], | |
| config: SmallModelQuantizationConfig, | |
| total_model_params: int, | |
| ) -> float: | |
| predicted_bits = 16.0 * total_model_params | |
| target_bits = float(config.target_average_bits) * total_model_params | |
| ranked = sorted( | |
| candidates, | |
| key=lambda policy: ( | |
| policy.sensitivity_score | |
| / max(16.0 - policy.bits_if_quantized, 1e-6), | |
| policy.sensitivity_score, | |
| policy.role_priority, | |
| ), | |
| ) | |
| for policy in ranked: | |
| if predicted_bits <= target_bits: | |
| policy.action = "fp16" | |
| policy.reason = "budget_not_needed" | |
| continue | |
| policy.action = "groupwise_ternary" | |
| policy.reason = "budget_allocation" | |
| predicted_bits -= (16.0 - policy.bits_if_quantized) * policy.num_params | |
| for policy in ranked: | |
| if policy.action == "fp16" and policy.reason == "unallocated": | |
| policy.reason = "protected_by_budget" | |
| return predicted_bits | |
| def _assign_salient_fractions( | |
| candidates: list[RoleAwareModulePolicy], | |
| config: SmallModelQuantizationConfig, | |
| ) -> None: | |
| selected = [policy for policy in candidates if policy.action == "groupwise_ternary"] | |
| if not selected: | |
| return | |
| if not config.adaptive_salient: | |
| for policy in selected: | |
| policy.salient_fraction = config.salient_fraction | |
| policy.bits_if_quantized = _estimate_effective_bits( | |
| policy, | |
| config.group_size, | |
| policy.salient_fraction, | |
| ) | |
| return | |
| scores = [policy.sensitivity_score for policy in selected] | |
| min_score = min(scores) | |
| max_score = max(scores) | |
| denom = max(max_score - min_score, 1e-8) | |
| for policy in selected: | |
| normalized = (policy.sensitivity_score - min_score) / denom | |
| fraction = config.min_salient_fraction + normalized * ( | |
| config.max_salient_fraction - config.min_salient_fraction | |
| ) | |
| if policy.role in {"attention_output", "mlp_output"}: | |
| fraction = max(fraction, config.max_salient_fraction) | |
| policy.salient_fraction = fraction | |
| policy.bits_if_quantized = _estimate_effective_bits( | |
| policy, | |
| config.group_size, | |
| fraction, | |
| ) | |
| def _estimate_residual_spectrum( | |
| residual: torch.Tensor, | |
| max_rank: int, | |
| ) -> torch.Tensor: | |
| if max_rank <= 0: | |
| return torch.zeros(0, dtype=torch.float32) | |
| limit = min(max_rank, residual.shape[0], residual.shape[1]) | |
| if limit <= 0: | |
| return torch.zeros(0, dtype=torch.float32) | |
| singular_values = torch.linalg.svdvals(residual.float().cpu()) | |
| return singular_values[:limit].detach().cpu() | |
def _assign_adaptive_low_rank(
    candidates: list[RoleAwareModulePolicy],
    config: SmallModelQuantizationConfig,
    total_model_params: int,
    residual_spectra: dict[str, torch.Tensor],
) -> None:
    """Greedily distribute low-rank residual ranks across quantized modules.

    Starting from rank 0 everywhere, repeatedly grants a rank increment of up
    to ``chunk_rank`` to whichever module offers the largest activation-weighted
    residual-energy gain per storage bit, until the bit budget implied by the
    configured target average bits is exhausted or no module can improve.
    Mutates ``low_rank_rank`` and ``bits_if_quantized`` on the selected
    policies; modules not marked ``groupwise_ternary`` are untouched.
    """
    selected = [policy for policy in candidates if policy.action == "groupwise_ternary"]
    if not selected:
        return
    # Reset so the baseline bit estimate reflects the no-residual configuration.
    for policy in selected:
        policy.low_rank_rank = 0
        policy.bits_if_quantized = _estimate_effective_bits(
            policy,
            config.group_size,
            policy.salient_fraction,
        )
    base_bits = _predict_total_bits(candidates, total_model_params)
    # Cost of giving every selected module the full uniform rank (fp16 U and V
    # factors); used as the default budget when no explicit target is set.
    uniform_extra_bits = sum(
        16.0 * config.low_rank_rank * (policy.out_features + policy.in_features)
        for policy in selected
    )
    if config.low_rank_target_average_bits is not None:
        target_bits = float(config.low_rank_target_average_bits) * total_model_params
    elif config.target_average_bits is not None:
        target_bits = float(config.target_average_bits) * total_model_params
    else:
        target_bits = base_bits + uniform_extra_bits
    remaining_bits = max(0.0, target_bits - base_bits)
    chunk_rank = max(1, min(config.low_rank_chunk_rank, config.low_rank_rank))
    current_rank = {policy.name: 0 for policy in selected}
    # Per-module rank ceiling: the configured cap, clipped by matrix dimensions.
    max_rank = {
        policy.name: min(config.low_rank_rank, policy.out_features, policy.in_features)
        for policy in selected
    }
    # Greedy knapsack-style loop: each pass buys one rank chunk for the module
    # with the best score (weighted spectral gain per bit of storage cost).
    while remaining_bits > 0:
        best_policy = None
        best_rank_step = 0
        best_score = 0.0
        best_cost = 0.0
        for policy in selected:
            rank_now = current_rank[policy.name]
            rank_limit = max_rank[policy.name]
            if rank_now >= rank_limit:
                continue
            step = min(chunk_rank, rank_limit - rank_now)
            # Storage for `step` additional fp16 rows/columns of U and V.
            cost = 16.0 * step * (policy.out_features + policy.in_features)
            if cost > remaining_bits + 1e-6:
                continue
            spectrum = residual_spectra.get(policy.name)
            if spectrum is None or rank_now >= spectrum.numel():
                continue
            # Residual energy captured by the next `step` singular values.
            gain = float((spectrum[rank_now : rank_now + step] ** 2).sum().item())
            if gain <= 0.0:
                continue
            # Modules with larger typical activations matter more downstream.
            weighted_gain = gain * max(policy.activation_rms, 1e-6)
            score = weighted_gain / max(cost, 1e-6)
            if score > best_score:
                best_score = score
                best_policy = policy
                best_rank_step = step
                best_cost = cost
        if best_policy is None or best_rank_step <= 0:
            break
        current_rank[best_policy.name] += best_rank_step
        best_policy.low_rank_rank = current_rank[best_policy.name]
        best_policy.bits_if_quantized = _estimate_effective_bits(
            best_policy,
            config.group_size,
            best_policy.salient_fraction,
        )
        remaining_bits -= best_cost
| def _estimate_effective_bits( | |
| policy: RoleAwareModulePolicy, | |
| group_size: int, | |
| salient_fraction: float, | |
| ) -> float: | |
| if policy.action != "groupwise_ternary": | |
| return 16.0 | |
| num_params = policy.num_params | |
| out_features = policy.out_features | |
| in_features = policy.in_features | |
| n_groups = max(1, math.ceil(in_features / max(group_size, 1))) | |
| code_bits = 2 * num_params | |
| group_param_bits = 16 * 2 * out_features * n_groups | |
| sparse_bits = int(max(1, salient_fraction * num_params)) * (32 + 16) | |
| low_rank_bits = 16 * policy.low_rank_rank * (out_features + in_features) | |
| return (code_bits + group_param_bits + sparse_bits + low_rank_bits) / max( | |
| num_params, 1 | |
| ) | |
| def _predict_total_bits( | |
| candidates: list[RoleAwareModulePolicy], | |
| total_model_params: int, | |
| ) -> float: | |
| predicted_bits = 16.0 * total_model_params | |
| for policy in candidates: | |
| if policy.action == "groupwise_ternary": | |
| predicted_bits -= 16.0 * policy.num_params | |
| predicted_bits += policy.bits_if_quantized * policy.num_params | |
| return predicted_bits | |
def _collect_role_modules(
    layer: nn.Module,
    layer_prefix: str,
    base_config: QuantizationConfig,
) -> dict[str, nn.Module]:
    """Map fully-qualified names to this layer's role-tagged linear modules.

    A submodule is kept only when it is a linear layer, passes the base
    config's quantization filter, and its name maps to a known
    attention/MLP role.
    """
    selected: dict[str, nn.Module] = {}
    for local_name, submodule in layer.named_modules():
        qualified = f"{layer_prefix}.{local_name}" if local_name else layer_prefix
        if not _is_linear_layer(submodule):
            continue
        if not _should_quantize(qualified, base_config):
            continue
        if _classify_module(qualified) is None:
            continue
        selected[qualified] = submodule
    return selected
def _capture_activations(
    model: nn.Module,
    calibration_data: torch.Tensor,
    modules: dict[str, nn.Module],
    batch_size: int,
) -> dict[str, Optional[torch.Tensor]]:
    """Run the calibration data through ``model`` once and return, per named
    module, the activations recorded by an ``ActivationCapture`` hook.

    Entries are ``None`` when the corresponding capture saw no data.  All
    hooks are removed before returning.
    """
    hooks: dict[str, ActivationCapture] = {}
    for name, module in modules.items():
        recorder = ActivationCapture()
        recorder.register(module)
        hooks[name] = recorder
    _run_model_forward(model, calibration_data, batch_size=batch_size)
    collected: dict[str, Optional[torch.Tensor]] = {}
    for name, recorder in hooks.items():
        collected[name] = recorder.get_activations()
        recorder.remove()
    return collected
def _capture_module_io(
    model: nn.Module,
    calibration_data: torch.Tensor,
    modules: dict[str, nn.Module],
    batch_size: int,
) -> dict[str, dict[str, Optional[torch.Tensor]]]:
    """Run the calibration data through ``model`` once and collect each named
    module's flattened inputs and outputs via ``ModuleIOCapture`` hooks.

    Returns ``{name: {"inputs": tensor-or-None, "outputs": tensor-or-None}}``;
    all hooks are removed before returning.
    """
    hooks: dict[str, ModuleIOCapture] = {}
    for name, module in modules.items():
        recorder = ModuleIOCapture()
        recorder.register(module)
        hooks[name] = recorder
    _run_model_forward(model, calibration_data, batch_size=batch_size)
    collected: dict[str, dict[str, Optional[torch.Tensor]]] = {}
    for name, recorder in hooks.items():
        collected[name] = {
            "inputs": recorder.get_inputs(),
            "outputs": recorder.get_outputs(),
        }
        recorder.remove()
    return collected
| def _resolve_parent_module(model: nn.Module, name: str) -> tuple[nn.Module, str]: | |
| parent = model | |
| parts = name.split(".") | |
| for part in parts[:-1]: | |
| parent = getattr(parent, part) | |
| return parent, parts[-1] | |
| def _activation_rms(activations: Optional[torch.Tensor]) -> float: | |
| if activations is None or activations.numel() == 0: | |
| return 1.0 | |
| return activations.float().square().mean().sqrt().item() | |
def _build_dependency_groups_for_plan(
    layer: nn.Module,
    layer_prefix: str,
    layer_idx: int,
    plan: RoleAwareQuantizationPlan,
    config: SmallModelQuantizationConfig,
) -> tuple[list[dict[str, nn.Module]], list[str]]:
    """Split one layer's linears into role-ordered quantization groups.

    Returns ``(groups, protected)``: ``groups`` holds one name->module dict
    per non-empty role, ordered by ``ROLE_ORDER`` and restricted to modules
    the plan marks ``groupwise_ternary``; ``protected`` lists role-tagged
    modules the plan leaves at any other (or missing) action.

    ``layer_idx`` is accepted for interface compatibility and unused here.
    """
    by_role: dict[str, dict[str, nn.Module]] = {role: {} for role in ROLE_ORDER}
    protected: list[str] = []
    for local_name, submodule in layer.named_modules():
        qualified = f"{layer_prefix}.{local_name}" if local_name else layer_prefix
        if not _is_linear_layer(submodule):
            continue
        if not _should_quantize(qualified, config.base_config):
            continue
        role = _classify_module(qualified)
        if role is None:
            continue
        policy = plan.policies.get(qualified)
        if policy is not None and policy.action == "groupwise_ternary":
            by_role[role][qualified] = submodule
        else:
            protected.append(qualified)
    groups = [by_role[role] for role in ROLE_ORDER if by_role[role]]
    return groups, protected
| def _classify_module(full_name: str) -> Optional[str]: | |
| target_name = full_name.split(".")[-1] | |
| if target_name in {"q_proj", "k_proj", "v_proj", "query_key_value", "c_attn"}: | |
| return "attention_inputs" | |
| if target_name in {"gate_proj", "up_proj", "w1", "w3", "dense_h_to_4h", "c_fc"}: | |
| return "mlp_inputs" | |
| if target_name in {"o_proj", "dense"}: | |
| return "attention_output" | |
| if target_name in {"down_proj", "w2", "dense_4h_to_h"}: | |
| return "mlp_output" | |
| if target_name == "c_proj": | |
| if ".attn." in full_name or ".self_attn." in full_name: | |
| return "attention_output" | |
| if ".mlp." in full_name: | |
| return "mlp_output" | |
| return None | |