# ternary-quant-demo / ternary_quant/quantizer_small.py
# (Hugging Face page header captured together with the file source.)
# Author: AsadIsmail — commit 0b1047e (verified):
#   "Update Qwen2-VL: now text+vision backbone quantized (fixed qkv + NaN bug)"
"""
Role-aware ternary quantization utilities for small language models.
This module contains two layers of functionality:
1. A practical mixed-precision recipe that keeps fragile merge projections
at full precision and ternarizes the more robust input projections.
2. A paper-oriented planner that turns that recipe into an explicit budgeted
allocator with per-module sensitivity scores and per-module sparse residual
budgets.
"""
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from typing import Optional
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from ternary_quant.pipeline import (
ActivationCapture,
QuantizationConfig,
_get_decoder_layers,
_get_weight,
_is_linear_layer,
_run_model_forward,
_set_weight,
_should_quantize,
)
# Canonical ordering of module roles inside a decoder layer. A role's index in
# this tuple is used as its planner priority (`role_priority` on policies).
ROLE_ORDER = (
    "attention_inputs",
    "attention_output",
    "mlp_inputs",
    "mlp_output",
)
class ModuleIOCapture:
    """Forward-hook recorder for a module's flattened inputs and outputs.

    Captured tensors are detached, flattened from (batch, seq, dim) to
    (batch*seq, dim) when 3-D, and accumulated on CPU.
    """

    def __init__(self):
        self.inputs = []
        self.outputs = []
        self.hook = None

    def register(self, module: nn.Module) -> None:
        """Attach the capture hook to *module*."""
        self.hook = module.register_forward_hook(self._hook_fn)

    def _hook_fn(self, module, input, output) -> None:
        # `output` may be a tuple (e.g. attention modules); keep element 0.
        captured_in = input[0].detach()
        captured_out = (output[0] if isinstance(output, tuple) else output).detach()
        if captured_in.dim() == 3:
            captured_in = captured_in.reshape(-1, captured_in.shape[-1])
        if captured_out.dim() == 3:
            captured_out = captured_out.reshape(-1, captured_out.shape[-1])
        self.inputs.append(captured_in.cpu())
        self.outputs.append(captured_out.cpu())

    def get_inputs(self) -> Optional[torch.Tensor]:
        """Concatenated captured inputs, or None when nothing was recorded."""
        return torch.cat(self.inputs, dim=0) if self.inputs else None

    def get_outputs(self) -> Optional[torch.Tensor]:
        """Concatenated captured outputs, or None when nothing was recorded."""
        return torch.cat(self.outputs, dim=0) if self.outputs else None

    def remove(self) -> None:
        """Detach the hook (if any) and drop every captured tensor."""
        if self.hook is not None:
            self.hook.remove()
            self.hook = None
        self.inputs = []
        self.outputs = []
@dataclass
class GroupwiseTernaryParameter:
    """Grouped asymmetric ternary representation: alpha_g * T + mu_g + sparse."""

    ternary_codes: torch.Tensor
    group_alpha: torch.Tensor
    group_mu: torch.Tensor
    group_size: int
    sparse_indices: Optional[torch.Tensor]
    sparse_residual: Optional[torch.Tensor]
    lr_U: Optional[torch.Tensor]
    lr_V: Optional[torch.Tensor]
    original_shape: tuple[int, int]
    original_dtype: torch.dtype

    def dequantize(self) -> torch.Tensor:
        """Reconstruct the dense float32 weight from all stored components."""
        _, in_features = self.original_shape

        def _expand(groupwise: torch.Tensor) -> torch.Tensor:
            # Broadcast one scalar per group across its columns, trimming the
            # final (possibly partial) group to the true input width.
            wide = groupwise.float().repeat_interleave(self.group_size, dim=1)
            return wide[:, :in_features]

        weight = _expand(self.group_alpha) * self.ternary_codes.float()
        weight = weight + _expand(self.group_mu)
        if self.sparse_indices is not None and self.sparse_residual is not None:
            # Scatter the stored FP16 residuals back into a flat view.
            weight.reshape(-1)[self.sparse_indices.long()] += self.sparse_residual.float()
        if self.lr_U is not None and self.lr_V is not None:
            weight = weight + self.lr_U.float() @ self.lr_V.float()
        return weight

    @property
    def num_params(self) -> int:
        # Element count of the original dense matrix.
        rows, cols = self.original_shape
        return rows * cols

    @property
    def sparse_nnz(self) -> int:
        # Number of stored sparse-residual entries (0 when disabled).
        return 0 if self.sparse_indices is None else int(self.sparse_indices.numel())

    @property
    def effective_bits(self) -> float:
        """Average storage bits per original weight element."""
        rows, cols = self.original_shape
        n_groups = self.group_alpha.shape[1]
        code_bits = 2 * rows * cols                 # 2-bit ternary codes
        group_param_bits = 16 * 2 * rows * n_groups  # FP16 alpha + mu per group
        sparse_bits = self.sparse_nnz * (32 + 16)    # int32 index + FP16 value
        low_rank_bits = 0
        if self.lr_U is not None and self.lr_V is not None:
            rank = self.lr_U.shape[1]
            low_rank_bits = 16 * rank * (rows + cols)
        denom = rows * cols
        return (code_bits + group_param_bits + sparse_bits) / denom + low_rank_bits / denom
def compute_groupwise_error(
    weight: torch.Tensor,
    param: GroupwiseTernaryParameter,
) -> dict:
    """Compute reconstruction metrics for grouped ternary weights."""
    reference = weight.float().cpu()
    reconstruction = param.dequantize().float().cpu()
    error = reference - reconstruction
    mse = error.square().mean().item()
    rmse = mse**0.5
    # RMS of the original weight, epsilon-guarded for the relative error.
    rms_weight = reference.norm().item() / (reference.numel() ** 0.5) + 1e-8
    return {
        "mse": mse,
        "rmse": rmse,
        "relative_error": rmse / rms_weight,
        "max_error": error.abs().max().item(),
        "sparsity": (param.ternary_codes == 0).float().mean().item(),
        "effective_bits": param.effective_bits,
        "sparse_nnz": param.sparse_nnz,
    }
class GroupwiseAsymmetricTernaryQuantizer:
    """Per-group asymmetric ternary fitting with optional sparse residual.

    Each row of the weight matrix is split into contiguous groups of
    ``group_size`` input channels; each group stores a scale ``alpha`` and an
    offset ``mu`` so the group reconstructs as ``alpha * T + mu`` with ternary
    codes ``T in {-1, 0, +1}``.  Optional refinements:

    * activation-aware weighted fitting (second-moment channel weights),
    * an AWQ-inspired importance-scaled rounding threshold,
    * a top-k sparse FP16 residual on the most salient weights,
    * a low-rank FP16 residual fitted either to the weight error (SVD) or to
      calibration outputs (ridge regression).
    """

    def __init__(
        self,
        group_size: int = 32,
        n_iter: int = 10,
        use_activation_aware: bool = True,
        salient_fraction: float = 0.0,
        low_rank_rank: int = 0,
        low_rank_fit_mode: str = "weight_svd",
        low_rank_ridge: float = 1e-4,
        low_rank_max_samples: int = 4096,
        importance_threshold_scale: float = 0.0,
    ):
        self.group_size = group_size
        self.n_iter = n_iter
        self.use_activation_aware = use_activation_aware
        # Clamp user-provided knobs into their valid ranges.
        self.salient_fraction = max(0.0, min(salient_fraction, 1.0))
        self.low_rank_rank = max(0, int(low_rank_rank))
        self.low_rank_fit_mode = low_rank_fit_mode
        self.low_rank_ridge = max(0.0, float(low_rank_ridge))
        self.low_rank_max_samples = max(0, int(low_rank_max_samples))
        self.importance_threshold_scale = float(importance_threshold_scale)

    def quantize(
        self,
        weight: torch.Tensor,
        activations: Optional[torch.Tensor] = None,
        outputs: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> GroupwiseTernaryParameter:
        """Quantize one (out_features, in_features) weight matrix.

        Args:
            weight: Dense 2-D weight to quantize.
            activations: Optional calibration inputs, one row per sample,
                used for activation-aware fitting / importance scores.
            outputs: Optional calibration outputs of the original module,
                required for the "activation_regression" low-rank fit.
            bias: Optional bias, only used by the regression low-rank fit.

        Returns:
            A GroupwiseTernaryParameter holding codes, group scales/offsets
            and any sparse / low-rank residuals (all moved to CPU).
        """
        weight_f = weight.float()
        work_device = weight_f.device
        out_features, in_features = weight_f.shape
        group_size = min(self.group_size, in_features)
        n_groups = (in_features + group_size - 1) // group_size
        activations_f = None
        if activations is not None:
            activations_f = activations.float().to(work_device)
        outputs_f = None
        if outputs is not None:
            outputs_f = outputs.float().to(work_device)
        act_weights = None
        act_importance = None
        if activations_f is not None and self.use_activation_aware:
            # Second-moment per-channel weights, normalized to mean 1.
            act_weights = (activations_f ** 2).mean(dim=0)
            act_weights = act_weights / (act_weights.mean() + 1e-8)
        if activations_f is not None and self.importance_threshold_scale > 0:
            # AWQ-inspired: mean |activation| per input channel → importance scores
            act_importance = activations_f.abs().mean(dim=0)
            act_importance = act_importance / (act_importance.mean() + 1e-8)
        ternary_codes = torch.zeros_like(weight_f, dtype=torch.int8)
        group_alpha = torch.zeros(
            out_features,
            n_groups,
            dtype=torch.float32,
            device=work_device,
        )
        group_mu = torch.zeros(
            out_features,
            n_groups,
            dtype=torch.float32,
            device=work_device,
        )
        # Fit each column group independently.
        for group_idx in range(n_groups):
            start = group_idx * group_size
            end = min(start + group_size, in_features)
            weight_group = weight_f[:, start:end]
            group_weights = (
                act_weights[start:end] if act_weights is not None else None
            )
            group_importance = (
                act_importance[start:end] if act_importance is not None else None
            )
            alpha, mu, ternary = self._fit_group(
                weight_group, group_weights, group_importance, self.importance_threshold_scale
            )
            ternary_codes[:, start:end] = ternary.to(torch.int8)
            group_alpha[:, group_idx] = alpha.squeeze(1)
            group_mu[:, group_idx] = mu.squeeze(1)
        sparse_indices = None
        sparse_residual = None
        lr_U = None
        lr_V = None
        # Reconstruct the pure ternary approximation once; both the sparse
        # residual selection and the low-rank fit are relative to it.
        # (Previously this identical computation was duplicated in both
        # branches of the salient_fraction check.)
        base = self._dequantize_from_parts(
            ternary_codes,
            group_alpha,
            group_mu,
            group_size,
            in_features,
        )
        if self.salient_fraction > 0.0:
            sparse_indices, sparse_residual = self._select_salient_residual(
                weight_f,
                base,
                act_weights,
            )
        if sparse_indices is not None and sparse_residual is not None:
            # Fold the sparse correction into `base` so the low-rank fit
            # only has to explain what remains.
            flat = base.reshape(-1)
            flat[sparse_indices.long()] += sparse_residual.float()
        if self.low_rank_rank > 0:
            if (
                self.low_rank_fit_mode == "activation_regression"
                and activations_f is not None
                and outputs_f is not None
            ):
                lr_U, lr_V = self._fit_output_aware_low_rank_residual(
                    activations=activations_f,
                    outputs=outputs_f,
                    base_weight=base,
                    bias=bias,
                )
            if lr_U is None or lr_V is None:
                # Fallback (or default mode): SVD of the weight-space error.
                lr_U, lr_V = self._fit_low_rank_residual(weight_f - base)
        return GroupwiseTernaryParameter(
            ternary_codes=ternary_codes.cpu(),
            group_alpha=group_alpha.to(torch.float16).cpu(),
            group_mu=group_mu.to(torch.float16).cpu(),
            group_size=group_size,
            sparse_indices=sparse_indices.cpu() if sparse_indices is not None else None,
            sparse_residual=sparse_residual.cpu() if sparse_residual is not None else None,
            lr_U=lr_U.cpu() if lr_U is not None else None,
            lr_V=lr_V.cpu() if lr_V is not None else None,
            original_shape=tuple(weight.shape),
            original_dtype=weight.dtype,
        )

    def _fit_group(
        self,
        weight_group: torch.Tensor,
        act_weights: Optional[torch.Tensor],
        act_importance: Optional[torch.Tensor] = None,
        importance_threshold_scale: float = 0.0,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Alternate closed-form (alpha, mu) solves with ternary rounding for one group."""
        if act_weights is not None:
            # Weighted initialization: importance-weighted mean and MAD.
            weights = act_weights.unsqueeze(0)
            mu = (weight_group * weights).sum(dim=1, keepdim=True) / (
                act_weights.sum() + 1e-8
            )
            alpha = ((weight_group - mu).abs() * weights).sum(dim=1, keepdim=True) / (
                act_weights.sum() + 1e-8
            )
        else:
            weights = torch.ones(1, weight_group.shape[1], device=weight_group.device)
            mu = weight_group.mean(dim=1, keepdim=True)
            alpha = (weight_group - mu).abs().mean(dim=1, keepdim=True)
        # Clamp both ends: min to avoid div-by-zero, max to avoid FP16 overflow
        # (inf in plane-0 scales causes inf-inf=NaN in residuals for planes 1+)
        alpha = alpha.clamp(min=1e-8, max=65504.0)
        ternary = self._round_to_ternary(
            weight_group, alpha, mu, act_importance, importance_threshold_scale
        )
        for _ in range(self.n_iter):
            alpha, mu = self._solve_alpha_mu(weight_group, ternary, weights)
            ternary = self._round_to_ternary(
                weight_group, alpha, mu, act_importance, importance_threshold_scale
            )
        return alpha, mu, ternary

    @staticmethod
    def _round_to_ternary(
        weight_group: torch.Tensor,
        alpha: torch.Tensor,
        mu: torch.Tensor,
        act_importance: Optional[torch.Tensor] = None,
        importance_threshold_scale: float = 0.0,
    ) -> torch.Tensor:
        """Round (w - mu) / alpha to {-1, 0, +1} with an optional per-channel threshold."""
        normalized = (weight_group - mu) / alpha
        if act_importance is not None and importance_threshold_scale > 0:
            # AWQ-inspired: high-importance input channels get lower threshold,
            # making them less likely to be zeroed out. Preserves signal where activations
            # are large — exactly the insight from AWQ applied to ternary codes.
            # thresh shape: [1, group_size], broadcast over output rows.
            thresh = 0.5 / (
                1.0 + act_importance.unsqueeze(0) * importance_threshold_scale
            ).clamp(max=4.0)  # minimum threshold = 0.5/4.0 = 0.125
        else:
            thresh = 0.5
        ternary = torch.zeros_like(normalized)
        ternary[normalized > thresh] = 1.0
        ternary[normalized < -thresh] = -1.0
        return ternary

    @staticmethod
    def _solve_alpha_mu(
        weight_group: torch.Tensor,
        ternary: torch.Tensor,
        weights: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Weighted least-squares solve of w ≈ alpha*t + mu per row (2x2 normal equations)."""
        out_features = weight_group.shape[0]
        weighted_t = weights * ternary
        weighted_t2 = weighted_t * ternary
        a11 = weighted_t2.sum(dim=1)
        a12 = weighted_t.sum(dim=1)
        a22 = weights.sum(dim=1).expand(out_features)
        b1 = (weighted_t * weight_group).sum(dim=1)
        b2 = (weights * weight_group).sum(dim=1)
        # Guard the determinant so all-zero ternary rows do not divide by 0.
        det = (a11 * a22 - a12.square()).clamp(min=1e-10)
        alpha = ((a22 * b1 - a12 * b2) / det).clamp(min=1e-8, max=65504.0).unsqueeze(1)
        mu = ((a11 * b2 - a12 * b1) / det).clamp(min=-65504.0, max=65504.0).unsqueeze(1)
        return alpha, mu

    def _dequantize_from_parts(
        self,
        ternary_codes: torch.Tensor,
        group_alpha: torch.Tensor,
        group_mu: torch.Tensor,
        group_size: int,
        in_features: int,
    ) -> torch.Tensor:
        """Expand grouped (alpha, mu) and reconstruct alpha * T + mu densely."""
        alpha = group_alpha.repeat_interleave(group_size, dim=1)[:, :in_features]
        mu = group_mu.repeat_interleave(group_size, dim=1)[:, :in_features]
        return alpha * ternary_codes.float() + mu

    def _select_salient_residual(
        self,
        weight: torch.Tensor,
        base: torch.Tensor,
        act_weights: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pick the top-k most damaging residual entries to keep as FP16 values."""
        residual = weight - base
        if act_weights is not None:
            # Weight the residual by channel importance before ranking.
            sensitivity = residual.abs() * act_weights.unsqueeze(0)
        else:
            sensitivity = residual.abs()
        k = max(1, int(self.salient_fraction * sensitivity.numel()))
        indices = torch.topk(sensitivity.reshape(-1), k).indices.to(torch.int32)
        residual_values = residual.reshape(-1)[indices.long()].to(torch.float16)
        return indices, residual_values

    def _fit_low_rank_residual(
        self,
        residual: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """SVD-based factorization of the weight-space residual."""
        return self._factorize_low_rank(residual)

    def _fit_output_aware_low_rank_residual(
        self,
        activations: torch.Tensor,
        outputs: torch.Tensor,
        base_weight: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Fit a low-rank delta by ridge regression onto calibration outputs.

        Solves ``min ||X @ delta^T - (Y - X @ W_base^T - b)||`` with a ridge
        term scaled to the Gram matrix's magnitude, then factorizes delta.
        Returns (None, None) when the inputs are unusable or the residual is
        already negligible.
        """
        if activations is None or outputs is None:
            return None, None
        x = activations.float()
        y = outputs.float()
        if x.dim() != 2 or y.dim() != 2 or x.shape[0] != y.shape[0]:
            return None, None
        if self.low_rank_max_samples > 0 and x.shape[0] > self.low_rank_max_samples:
            # Evenly strided subsample keeps the regression tractable.
            indices = torch.linspace(
                0,
                x.shape[0] - 1,
                steps=self.low_rank_max_samples,
                dtype=torch.float32,
            ).round().long().unique()
            x = x.index_select(0, indices)
            y = y.index_select(0, indices)
        base_bias = bias.float() if bias is not None else None
        base_outputs = F.linear(x, base_weight.float(), base_bias)
        residual_outputs = y - base_outputs
        if residual_outputs.abs().max().item() < 1e-8:
            return None, None
        xtx = x.T @ x
        rhs = x.T @ residual_outputs
        ridge_scale = xtx.diagonal().mean().item() if xtx.numel() else 0.0
        ridge = self.low_rank_ridge * max(ridge_scale, 1e-8)
        # FIX: build the identity on the same device as the Gram matrix.
        # torch.eye defaults to the CPU device, which crashed this solve with
        # a cross-device error whenever activations lived on an accelerator.
        system = xtx + ridge * torch.eye(
            xtx.shape[0], dtype=xtx.dtype, device=xtx.device
        )
        try:
            delta_t = torch.linalg.solve(system, rhs)
        except RuntimeError:
            # Singular / ill-conditioned system: pseudo-inverse fallback.
            delta_t = torch.linalg.pinv(system) @ rhs
        delta = delta_t.T.contiguous()
        return self._factorize_low_rank(delta)

    def _factorize_low_rank(
        self,
        residual: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Truncated SVD factorization ``residual ≈ lr_U @ lr_V`` in FP16.

        Tall matrices are transposed so the decomposition runs on the smaller
        orientation; randomized ``svd_lowrank`` is used when the requested
        rank sits comfortably below the smaller dimension. Returns
        (None, None) when the rank is zero or the decomposition fails.
        """
        out_features, in_features = residual.shape
        rank = min(self.low_rank_rank, out_features, in_features)
        if rank <= 0:
            return None, None
        try:
            if out_features >= in_features:
                # Tall matrix: factor the transpose, then swap factors back.
                residual_t = residual.T.contiguous()
                if min(residual_t.shape) > rank + 4:
                    q = min(rank + 8, min(residual_t.shape))
                    U, S, V = torch.svd_lowrank(residual_t, q=q, niter=2)
                    U = U[:, :rank]
                    S = S[:rank]
                    V = V[:, :rank]
                    lr_U = (V * S.unsqueeze(0)).to(torch.float16)
                    lr_V = U.T.to(torch.float16)
                    return lr_U, lr_V
                U, S, Vh = torch.linalg.svd(residual_t, full_matrices=False)
                lr_U = (Vh[:rank, :].T * S[:rank].unsqueeze(0)).to(torch.float16)
                lr_V = U[:, :rank].T.to(torch.float16)
                return lr_U, lr_V
            if min(residual.shape) > rank + 4:
                q = min(rank + 8, min(residual.shape))
                U, S, V = torch.svd_lowrank(residual, q=q, niter=2)
                U = U[:, :rank]
                S = S[:rank]
                V = V[:, :rank]
                lr_U = (U * S.unsqueeze(0)).to(torch.float16)
                lr_V = V.T.to(torch.float16)
                return lr_U, lr_V
            U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
        except RuntimeError:
            return None, None
        lr_U = (U[:, :rank] * S[:rank].unsqueeze(0)).to(torch.float16)
        lr_V = Vh[:rank, :].to(torch.float16)
        return lr_U, lr_V
class TrainableLowRankLinear(nn.Module):
    """Frozen quantized base plus trainable low-rank residual."""

    def __init__(
        self,
        quant_param: GroupwiseTernaryParameter,
        bias: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        if quant_param.lr_U is None or quant_param.lr_V is None:
            raise ValueError("Trainable low-rank wrapper requires initialized lr_U/lr_V.")
        # Subtract the stored low-rank term so the trainable U @ V starts
        # exactly where the quantizer left off.
        low_rank = quant_param.lr_U.float() @ quant_param.lr_V.float()
        self.register_buffer("base_weight", quant_param.dequantize().float() - low_rank)
        self.U = nn.Parameter(quant_param.lr_U.float().clone())
        self.V = nn.Parameter(quant_param.lr_V.float().clone())
        if bias is None:
            self.bias = None
        else:
            self.register_buffer("bias", bias.float().clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the frozen base plus the trainable residual: x (W + UV)^T + b."""
        weight = self.base_weight.to(x.device, dtype=x.dtype)
        delta = (self.U @ self.V).to(x.device, dtype=x.dtype)
        if self.bias is None:
            return F.linear(x, weight + delta, None)
        return F.linear(x, weight + delta, self.bias.to(x.device, dtype=x.dtype))
@dataclass
class SmallModelQuantizationConfig:
    """Configuration for role-aware small-model ternarization."""

    # --- Groupwise ternary fitting ---
    group_size: int = 32  # input channels per (alpha, mu) group
    n_iter: int = 10  # alternating refinement iterations per group
    # --- Sparse FP16 residual ---
    salient_fraction: float = 0.01  # fixed fraction of weights kept as sparse residuals
    min_salient_fraction: float = 0.0025  # lower bound when adaptive_salient is on
    max_salient_fraction: float = 0.01  # upper bound when adaptive_salient is on
    adaptive_salient: bool = False  # allocate salient fractions per-module instead of fixed
    # --- Low-rank residual ---
    low_rank_rank: int = 0  # rank of the optional low-rank residual (0 disables)
    adaptive_low_rank: bool = False  # distribute ranks per-module from residual spectra
    # presumably the rank step used by _assign_adaptive_low_rank (not visible here) — confirm
    low_rank_chunk_rank: int = 16
    low_rank_target_average_bits: Optional[float] = None  # bit budget for adaptive low-rank
    low_rank_fit_mode: str = "weight_svd"  # "weight_svd" or "activation_regression"
    low_rank_ridge: float = 1e-4  # ridge strength for the activation regression fit
    low_rank_max_samples: int = 4096  # calibration-row cap for the regression fit
    # --- Planner / calibration ---
    n_boundary_layers: int = 2  # first/last decoder layers kept at FP16
    calibration_batch_size: int = 4
    quantize_attention_output: bool = False
    quantize_mlp_output: bool = False
    target_average_bits: Optional[float] = None  # whole-model bit budget (None = practical recipe)
    # Relative cost multipliers per role used in sensitivity scoring; the merge
    # projections (outputs) are weighted slightly higher than the input projections.
    role_cost_weights: dict[str, float] = field(
        default_factory=lambda: {
            "attention_inputs": 1.0,
            "mlp_inputs": 1.0,
            "attention_output": 1.05,
            "mlp_output": 1.05,
        }
    )
    base_config: QuantizationConfig = field(default_factory=QuantizationConfig)
    # AWQ-inspired importance scaling of the ternary rounding threshold (0 disables).
    importance_threshold_scale: float = 0.0
@dataclass
class RoleAwareModulePolicy:
    """Planner decision for a single module."""

    name: str  # fully qualified module name within the model
    layer_idx: int  # decoder layer index the module belongs to
    role: str  # one of ROLE_ORDER
    out_features: int
    in_features: int
    num_params: int  # out_features * in_features
    action: str  # "groupwise_ternary" or "fp16"
    reason: str  # human-readable explanation of the chosen action
    boundary_layer: bool  # True when the layer is protected as a boundary layer
    role_priority: int  # index of `role` in ROLE_ORDER
    activation_rms: float  # calibration activation RMS observed at this module
    relative_error: float  # trial-quantization reconstruction error (0.0 for boundary modules)
    sensitivity_score: float  # relative_error * activation RMS * role cost weight
    bits_if_quantized: float  # predicted effective bits should this module be ternarized
    salient_fraction: float  # sparse-residual fraction to use when quantizing
    low_rank_rank: int  # low-rank residual rank to use when quantizing
@dataclass
class RoleAwareQuantizationPlan:
    """Explicit plan for the role-aware ternary path."""

    method_name: str  # label of the planning strategy that produced the plan
    target_average_bits: Optional[float]  # requested whole-model budget (None = unbudgeted)
    predicted_average_bits: float  # predicted bits per parameter over the whole model
    predicted_quantized_average_bits: float  # param-weighted bits over quantized modules only
    predicted_quantized_fraction: float  # fraction of model params scheduled for ternarization
    total_model_params: int
    boundary_layers: list[int]  # layer indices protected at FP16
    policies: dict[str, RoleAwareModulePolicy]  # per-module decisions keyed by module name
@dataclass
class SmallModelQuantizationResult:
    """Result of applying a role-aware plan in-place."""

    quantized_params: dict[str, GroupwiseTernaryParameter]  # per-module quantized tensors
    stats: dict[str, dict]  # per-module metrics (see compute_groupwise_error)
    quantized_module_names: list[str]
    protected_module_names: list[str]  # modules kept at full precision
    skipped_layer_indices: list[int]  # boundary layers excluded from quantization
    plan: RoleAwareQuantizationPlan
    calibration_tune: Optional[dict] = None  # stats from the optional low-rank tuning pass
def quantize_small_model_inplace(
    model: nn.Module,
    calibration_data: torch.Tensor,
    config: Optional[SmallModelQuantizationConfig] = None,
    plan: Optional[RoleAwareQuantizationPlan] = None,
) -> SmallModelQuantizationResult:
    """Quantize a model in-place with a role-aware small-model plan.

    Modules whose plan action is "groupwise_ternary" have their weights
    replaced by the dequantized reconstruction (the model remains an
    ordinary float module); every other module is left untouched and
    reported as protected.

    Args:
        model: Model to modify in-place.
        calibration_data: Token batch used to capture per-module activations.
        config: Quantization configuration; defaults are used when None.
        plan: Pre-built allocation plan; built from `config` when None.

    Returns:
        SmallModelQuantizationResult with per-module params, metrics and the plan.
    """
    if config is None:
        config = SmallModelQuantizationConfig()
    if plan is None:
        plan = build_role_aware_plan(model, calibration_data, config)
    decoder_layers, layer_path = _get_decoder_layers(model)
    quantized_params: dict[str, GroupwiseTernaryParameter] = {}
    stats: dict[str, dict] = {}
    quantized_module_names: list[str] = []
    protected_module_names: list[str] = []
    for layer_idx, layer in enumerate(decoder_layers):
        layer_prefix = f"{layer_path}.{layer_idx}"
        # Modules are processed in dependency groups so shared activations
        # are captured together (helper defined elsewhere in the package).
        groups, protected = _build_dependency_groups_for_plan(
            layer,
            layer_prefix,
            layer_idx,
            plan,
            config,
        )
        protected_module_names.extend(protected)
        for group in groups:
            # Output captures are only needed by the regression-based
            # low-rank fit; plain input captures are cheaper otherwise.
            need_outputs = (
                config.low_rank_fit_mode == "activation_regression"
                and any(plan.policies[name].low_rank_rank > 0 for name in group)
            )
            if need_outputs:
                io_captures = _capture_module_io(
                    model,
                    calibration_data,
                    group,
                    config.calibration_batch_size,
                )
            else:
                captures = _capture_activations(
                    model,
                    calibration_data,
                    group,
                    config.calibration_batch_size,
                )
            for name, module in group.items():
                policy = plan.policies[name]
                # Per-module quantizer configured from the policy's budget.
                quantizer = GroupwiseAsymmetricTernaryQuantizer(
                    group_size=config.group_size,
                    n_iter=config.n_iter,
                    use_activation_aware=config.base_config.use_activation_aware,
                    salient_fraction=policy.salient_fraction,
                    low_rank_rank=policy.low_rank_rank,
                    low_rank_fit_mode=config.low_rank_fit_mode,
                    low_rank_ridge=config.low_rank_ridge,
                    low_rank_max_samples=config.low_rank_max_samples,
                    importance_threshold_scale=config.importance_threshold_scale,
                )
                if need_outputs:
                    activations = io_captures[name]["inputs"]
                    outputs = io_captures[name]["outputs"]
                else:
                    activations = captures[name]
                    outputs = None
                weight = _get_weight(module)
                bias = getattr(module, "bias", None)
                param = quantizer.quantize(
                    weight.detach().float().cpu(),
                    activations=activations.float().cpu()
                    if activations is not None
                    else None,
                    outputs=outputs.float().cpu() if outputs is not None else None,
                    bias=bias.detach().float().cpu() if bias is not None else None,
                )
                quantized_params[name] = param
                stats[name] = compute_groupwise_error(weight, param)
                quantized_module_names.append(name)
                # Write the dequantized reconstruction back into the module.
                dequant = param.dequantize().to(weight.device, dtype=weight.dtype)
                _set_weight(module, dequant)
    # NOTE(review): skipped layers are recomputed from `config` here; if a
    # caller supplies a `plan` built with a different n_boundary_layers this
    # can disagree with plan.boundary_layers — confirm this is intended.
    skipped_layer_indices = [
        idx
        for idx in range(len(decoder_layers))
        if idx < config.n_boundary_layers
        or idx >= len(decoder_layers) - config.n_boundary_layers
    ]
    for name, policy in plan.policies.items():
        if policy.action != "groupwise_ternary":
            protected_module_names.append(name)
    return SmallModelQuantizationResult(
        quantized_params=quantized_params,
        stats=stats,
        quantized_module_names=quantized_module_names,
        protected_module_names=sorted(set(protected_module_names)),
        skipped_layer_indices=skipped_layer_indices,
        plan=plan,
    )
def build_role_aware_plan(
    model: nn.Module,
    calibration_data: torch.Tensor,
    config: Optional[SmallModelQuantizationConfig] = None,
) -> RoleAwareQuantizationPlan:
    """Build an explicit role-aware allocation plan before quantization.

    For every role module in every non-boundary decoder layer this runs a
    trial quantization against calibration activations, derives a
    sensitivity score, and then assigns actions either via the practical
    recipe (no bit budget) or the budgeted allocator (when
    config.target_average_bits is set).

    Args:
        model: Model to inspect (weights are read, not modified).
        calibration_data: Token batch used to capture activations.
        config: Planning configuration; defaults are used when None.

    Returns:
        A RoleAwareQuantizationPlan with one policy per role module.
    """
    if config is None:
        config = SmallModelQuantizationConfig()
    decoder_layers, layer_path = _get_decoder_layers(model)
    total_model_params = sum(p.numel() for p in model.parameters())
    # First/last n_boundary_layers decoder layers stay at FP16.
    boundary_layers = [
        idx
        for idx in range(len(decoder_layers))
        if idx < config.n_boundary_layers
        or idx >= len(decoder_layers) - config.n_boundary_layers
    ]
    # The trial fit uses the smallest salient fraction when adaptive, so the
    # sensitivity estimate reflects the cheapest configuration.
    estimation_fraction = (
        config.min_salient_fraction
        if config.adaptive_salient
        else config.salient_fraction
    )
    policies: dict[str, RoleAwareModulePolicy] = {}
    candidates: list[RoleAwareModulePolicy] = []
    residual_spectra: dict[str, torch.Tensor] = {}
    predicted_bits = 16.0 * total_model_params
    for layer_idx, layer in enumerate(decoder_layers):
        layer_prefix = f"{layer_path}.{layer_idx}"
        modules = _collect_role_modules(layer, layer_prefix, config.base_config)
        if not modules:
            continue
        # Capture inputs (plus outputs when the regression low-rank fit needs them).
        if config.low_rank_fit_mode == "activation_regression" and config.low_rank_rank > 0:
            io_captures = _capture_module_io(
                model,
                calibration_data,
                modules,
                config.calibration_batch_size,
            )
            captures = {name: payload["inputs"] for name, payload in io_captures.items()}
            outputs = {name: payload["outputs"] for name, payload in io_captures.items()}
        else:
            captures = _capture_activations(
                model,
                calibration_data,
                modules,
                config.calibration_batch_size,
            )
            outputs = {}
        for name, module in modules.items():
            role = _classify_module(name)
            if role is None:
                continue
            weight = _get_weight(module).detach().float().cpu()
            activations = captures[name]
            activation_rms = _activation_rms(activations)
            role_priority = ROLE_ORDER.index(role)
            boundary = layer_idx in boundary_layers
            if boundary:
                # Boundary modules are never trial-quantized; record a
                # protective FP16 policy with infinite sensitivity.
                policy = RoleAwareModulePolicy(
                    name=name,
                    layer_idx=layer_idx,
                    role=role,
                    out_features=weight.shape[0],
                    in_features=weight.shape[1],
                    num_params=weight.numel(),
                    action="fp16",
                    reason="boundary_layer",
                    boundary_layer=True,
                    role_priority=role_priority,
                    activation_rms=activation_rms,
                    relative_error=0.0,
                    sensitivity_score=float("inf"),
                    bits_if_quantized=16.0,
                    salient_fraction=0.0,
                    low_rank_rank=0,
                )
                policies[name] = policy
                continue
            # With adaptive low-rank the trial fit skips the low-rank term;
            # ranks are distributed afterwards from the residual spectra.
            estimate_low_rank = 0 if config.adaptive_low_rank else config.low_rank_rank
            quantizer = GroupwiseAsymmetricTernaryQuantizer(
                group_size=config.group_size,
                n_iter=config.n_iter,
                use_activation_aware=config.base_config.use_activation_aware,
                salient_fraction=estimation_fraction,
                low_rank_rank=estimate_low_rank,
                low_rank_fit_mode=config.low_rank_fit_mode,
                low_rank_ridge=config.low_rank_ridge,
                low_rank_max_samples=config.low_rank_max_samples,
                importance_threshold_scale=config.importance_threshold_scale,
            )
            param = quantizer.quantize(
                weight,
                activations=activations.float().cpu()
                if activations is not None
                else None,
                outputs=outputs.get(name).float().cpu()
                if outputs.get(name) is not None
                else None,
                bias=module.bias.detach().float().cpu()
                if getattr(module, "bias", None) is not None
                else None,
            )
            stats = compute_groupwise_error(weight, param)
            # Sensitivity = reconstruction error, scaled by how loud the
            # module's activations are and how fragile its role is.
            sensitivity = (
                stats["relative_error"]
                * max(activation_rms, 1e-6)
                * config.role_cost_weights.get(role, 1.0)
            )
            if config.adaptive_low_rank and config.low_rank_rank > 0:
                residual = weight - param.dequantize()
                residual_spectra[name] = _estimate_residual_spectrum(
                    residual,
                    max_rank=config.low_rank_rank,
                )
            policy = RoleAwareModulePolicy(
                name=name,
                layer_idx=layer_idx,
                role=role,
                out_features=weight.shape[0],
                in_features=weight.shape[1],
                num_params=weight.numel(),
                action="fp16",
                reason="unallocated",
                boundary_layer=False,
                role_priority=role_priority,
                activation_rms=activation_rms,
                relative_error=stats["relative_error"],
                sensitivity_score=sensitivity,
                bits_if_quantized=param.effective_bits,
                salient_fraction=estimation_fraction,
                low_rank_rank=0 if config.adaptive_low_rank else config.low_rank_rank,
            )
            policies[name] = policy
            candidates.append(policy)
    # Turn scores into actions: practical recipe without a budget, explicit
    # allocator when a whole-model bit target is requested.
    if config.target_average_bits is None:
        _assign_practical_actions(candidates, config)
    else:
        predicted_bits = _assign_budgeted_actions(
            candidates,
            config,
            total_model_params,
        )
    if config.target_average_bits is None:
        predicted_bits = _predict_total_bits(candidates, total_model_params)
    _assign_salient_fractions(candidates, config)
    if config.adaptive_low_rank and config.low_rank_rank > 0:
        _assign_adaptive_low_rank(
            candidates,
            config,
            total_model_params,
            residual_spectra,
        )
    # Re-predict: the salient / low-rank assignments changed per-module bits.
    predicted_bits = _predict_total_bits(candidates, total_model_params)
    quantized = [policy for policy in candidates if policy.action == "groupwise_ternary"]
    quantized_params = sum(policy.num_params for policy in quantized)
    quantized_weighted_bits = 0.0
    if quantized_params > 0:
        quantized_weighted_bits = sum(
            policy.num_params * policy.bits_if_quantized for policy in quantized
        ) / quantized_params
    return RoleAwareQuantizationPlan(
        method_name="RAST-small",
        target_average_bits=config.target_average_bits,
        predicted_average_bits=predicted_bits / total_model_params,
        predicted_quantized_average_bits=quantized_weighted_bits,
        predicted_quantized_fraction=quantized_params / total_model_params
        if total_model_params
        else 0.0,
        total_model_params=total_model_params,
        boundary_layers=boundary_layers,
        policies=policies,
    )
def build_all_non_boundary_plan(
    model: nn.Module,
    calibration_data: torch.Tensor,
    config: Optional[SmallModelQuantizationConfig] = None,
) -> RoleAwareQuantizationPlan:
    """Build an aggressive baseline that ternarizes every non-boundary role module."""
    if config is None:
        config = SmallModelQuantizationConfig()
    # Reuse the role-aware planner, then override every non-boundary policy.
    plan = build_role_aware_plan(model, calibration_data, config)
    selected = [p for p in plan.policies.values() if not p.boundary_layer]
    for policy in selected:
        policy.action = "groupwise_ternary"
        policy.reason = "all_non_boundary_baseline"
        policy.salient_fraction = config.salient_fraction
        policy.bits_if_quantized = _estimate_effective_bits(
            policy,
            config.group_size,
            policy.salient_fraction,
        )
    total = plan.total_model_params
    predicted_bits = _predict_total_bits(selected, total)
    quantized_params = sum(p.num_params for p in selected)
    if quantized_params > 0:
        weighted_sum = sum(p.num_params * p.bits_if_quantized for p in selected)
        quantized_weighted_bits = weighted_sum / quantized_params
    else:
        quantized_weighted_bits = 0.0
    # Rewrite the plan header to describe this baseline.
    plan.method_name = "PB-style sparse ternary"
    plan.predicted_average_bits = predicted_bits / total
    plan.predicted_quantized_average_bits = quantized_weighted_bits
    plan.predicted_quantized_fraction = quantized_params / total if total else 0.0
    return plan
def build_sensitivity_only_plan(
    model: nn.Module,
    calibration_data: torch.Tensor,
    config: Optional[SmallModelQuantizationConfig] = None,
) -> RoleAwareQuantizationPlan:
    """Build a matched-bit baseline that ignores role structure during allocation.

    Starts from the role-aware plan (to reuse its per-module sensitivity
    scores), resets every non-boundary module to FP16, then greedily
    ternarizes modules until the bit budget is met — ranking purely by
    sensitivity per bit saved, never by role.

    Raises:
        ValueError: when config.target_average_bits is None.
    """
    if config is None:
        config = SmallModelQuantizationConfig(target_average_bits=10.5)
    if config.target_average_bits is None:
        raise ValueError("Sensitivity-only planning requires target_average_bits.")
    plan = build_role_aware_plan(model, calibration_data, config)
    candidates = []
    target_bits = float(config.target_average_bits) * plan.total_model_params
    for policy in plan.policies.values():
        if policy.boundary_layer:
            continue
        # Temporarily mark as quantized so the bits estimate reflects the
        # salient fraction this baseline would actually use...
        policy.action = "groupwise_ternary"
        policy.salient_fraction = (
            config.min_salient_fraction if config.adaptive_salient else config.salient_fraction
        )
        policy.bits_if_quantized = _estimate_effective_bits(
            policy,
            config.group_size,
            policy.salient_fraction,
        )
        # ...then reset; the greedy pass below re-selects from scratch.
        policy.action = "fp16"
        policy.reason = "sensitivity_only_reset"
        candidates.append(policy)
    predicted_bits = 16.0 * plan.total_model_params
    # Ascending order: the cheapest sensitivity per bit saved goes first.
    ranked = sorted(
        candidates,
        key=lambda policy: (
            policy.sensitivity_score / max(16.0 - policy.bits_if_quantized, 1e-6),
            policy.sensitivity_score,
            policy.layer_idx,
            policy.name,
        ),
    )
    for policy in ranked:
        if predicted_bits <= target_bits:
            policy.reason = "sensitivity_budget_not_needed"
            continue
        policy.action = "groupwise_ternary"
        policy.reason = "sensitivity_budget_allocation"
        predicted_bits -= (16.0 - policy.bits_if_quantized) * policy.num_params
    _assign_salient_fractions(candidates, config)
    predicted_bits = _predict_total_bits(candidates, plan.total_model_params)
    quantized = [policy for policy in candidates if policy.action == "groupwise_ternary"]
    quantized_params = sum(policy.num_params for policy in quantized)
    quantized_weighted_bits = 0.0
    if quantized_params > 0:
        quantized_weighted_bits = sum(
            policy.num_params * policy.bits_if_quantized for policy in quantized
        ) / quantized_params
    plan.method_name = "Sensitivity-only ternary budget"
    plan.predicted_average_bits = predicted_bits / plan.total_model_params
    plan.predicted_quantized_average_bits = quantized_weighted_bits
    plan.predicted_quantized_fraction = (
        quantized_params / plan.total_model_params if plan.total_model_params else 0.0
    )
    return plan
def summarize_small_model_quantization(
    result: SmallModelQuantizationResult,
    model: nn.Module,
) -> dict:
    """Compute summary statistics for the mixed ternary path."""
    total_model_params = result.plan.total_model_params
    quantized_params = sum(p.num_params for p in result.quantized_params.values())

    n_stats = len(result.stats)
    avg_sparsity = avg_rel_error = avg_bits = 0.0
    total_sparse_nnz = 0
    if n_stats:
        stat_values = list(result.stats.values())
        avg_sparsity = sum(s["sparsity"] for s in stat_values) / n_stats
        avg_rel_error = sum(s["relative_error"] for s in stat_values) / n_stats
        avg_bits = sum(s["effective_bits"] for s in stat_values) / n_stats
        total_sparse_nnz = sum(int(s.get("sparse_nnz", 0)) for s in stat_values)

    # Replace each quantized tensor's FP16 cost by its effective cost.
    full_model_bits = 16.0 * total_model_params
    for param in result.quantized_params.values():
        full_model_bits -= 16.0 * param.num_params
        full_model_bits += param.effective_bits * param.num_params

    summary = {
        "method_name": result.plan.method_name,
        "quantized_params": quantized_params,
        "total_model_params": total_model_params,
        "quantized_fraction": (
            quantized_params / total_model_params if total_model_params else 0.0
        ),
        "n_quantized_modules": len(result.quantized_module_names),
        "n_protected_modules": len(result.protected_module_names),
        "skipped_layer_indices": result.skipped_layer_indices,
        "avg_sparsity": avg_sparsity,
        "avg_relative_error": avg_rel_error,
        "avg_effective_bits": avg_bits,
        "full_model_effective_bits": (
            full_model_bits / total_model_params if total_model_params else 0.0
        ),
        "total_sparse_nnz": total_sparse_nnz,
        "target_average_bits": result.plan.target_average_bits,
        "predicted_average_bits": result.plan.predicted_average_bits,
    }
    if result.calibration_tune is not None:
        summary["calibration_tune"] = result.calibration_tune
    return summary
def plan_to_dict(plan: RoleAwareQuantizationPlan) -> dict:
    """Convert a role-aware plan to a JSON-serializable dict."""
    # Policies are dataclasses; serialize them sorted by module name.
    serialized_policies = {
        name: asdict(policy) for name, policy in sorted(plan.policies.items())
    }
    return {
        "method_name": plan.method_name,
        "target_average_bits": plan.target_average_bits,
        "predicted_average_bits": plan.predicted_average_bits,
        "predicted_quantized_average_bits": plan.predicted_quantized_average_bits,
        "predicted_quantized_fraction": plan.predicted_quantized_fraction,
        "total_model_params": plan.total_model_params,
        "boundary_layers": plan.boundary_layers,
        "policies": serialized_policies,
    }
def config_to_dict(config: SmallModelQuantizationConfig) -> dict:
    """Convert config to a JSON-friendly dict.

    ``asdict`` already recurses into nested dataclasses; ``base_config`` is
    re-serialized explicitly from the live object for good measure.
    """
    serialized = asdict(config)
    has_base = "base_config" in serialized
    if has_base:
        serialized["base_config"] = asdict(config.base_config)
    return serialized
def tune_low_rank_residuals_inplace(
    model: nn.Module,
    result: SmallModelQuantizationResult,
    calibration_data: torch.Tensor,
    n_steps: int = 0,
    lr: float = 5e-5,
    batch_size: int = 2,
    max_seq_len: Optional[int] = None,
    max_grad_norm: float = 1.0,
    eval_interval: int = 5,
    val_fraction: float = 0.25,
    behavior_sequences: Optional[list[torch.Tensor]] = None,
    behavior_weight: float = 0.0,
    calibration_hidden_states: Optional[torch.Tensor] = None,
    behavior_hidden_states: Optional[list[torch.Tensor]] = None,
    calibration_logit_targets: Optional[dict[str, torch.Tensor]] = None,
    behavior_logit_targets: Optional[list[Optional[dict[str, torch.Tensor]]]] = None,
    distill_weight: float = 0.0,
    behavior_hidden_weight: float = 0.0,
    logit_distill_weight: float = 0.0,
    behavior_logit_weight: float = 0.0,
    entropy_distill_weight: float = 0.0,
    behavior_entropy_weight: float = 0.0,
    logit_distill_temperature: float = 1.0,
    seed: int = 42,
) -> dict:
    """Calibrate low-rank residuals with LM loss plus optional teacher-guided distillation.

    Replaces each quantized module that carries low-rank factors (``lr_U`` /
    ``lr_V``) with a ``TrainableLowRankLinear`` wrapper, trains ONLY those
    factors for ``n_steps`` AdamW steps on ``calibration_data``, and writes
    the best factors (selected by a weighted validation monitor) back into
    ``result.quantized_params`` as float16. The wrappers stay installed in
    ``model`` after the call — the model is mutated in place.

    The training objective is the LM loss, optionally mixed with:
      * ``behavior_weight``: LM loss on teacher-generated ``behavior_sequences``
        (convex mix, weight clamped to [0, 1]);
      * ``distill_weight`` / ``behavior_hidden_weight``: MSE to teacher
        last-layer hidden states;
      * ``logit_distill_weight`` / ``behavior_logit_weight``: temperature-scaled
        KL to teacher top-k logits;
      * ``entropy_distill_weight`` / ``behavior_entropy_weight``: one-sided
        penalty when student entropy falls below the teacher's.

    NOTE(review): assumes ``model`` is a Hugging Face-style causal LM whose
    forward returns an object with ``.loss``, ``.logits`` and (on request)
    ``.hidden_states``, and that it has a ``config.use_cache`` flag — confirm
    against callers.

    Returns:
        A stats dict (also stored on ``result.calibration_tune``) with loss
        trajectories and best validation metrics.
    """
    # No-op path: record empty stats and leave the model untouched.
    if n_steps <= 0:
        stats = {
            "steps": 0,
            "n_wrapped_modules": 0,
            "n_trainable_params": 0,
        }
        result.calibration_tune = stats
        return stats
    # ---- Wrap every quantized module that has low-rank factors ----
    wrapped: dict[str, TrainableLowRankLinear] = {}
    trainable_params: list[nn.Parameter] = []
    module_device = None
    for name, quant_param in result.quantized_params.items():
        if quant_param.lr_U is None or quant_param.lr_V is None:
            continue
        parent, target_name = _resolve_parent_module(model, name)
        original = getattr(parent, target_name)
        bias = (
            original.bias.detach().float().cpu()
            if getattr(original, "bias", None) is not None
            else None
        )
        # Remember the device of the first wrapped module; all wrappers and
        # batches are moved to it.
        if module_device is None:
            module_device = _get_weight(original).device
        wrapper = TrainableLowRankLinear(quant_param, bias=bias).to(module_device)
        setattr(parent, target_name, wrapper)
        wrapped[name] = wrapper
        trainable_params.extend([wrapper.U, wrapper.V])
    # Nothing to tune (no module carried low-rank factors).
    if not trainable_params:
        stats = {
            "steps": 0,
            "n_wrapped_modules": 0,
            "n_trainable_params": 0,
        }
        result.calibration_tune = stats
        return stats
    # Only the low-rank factors are optimized; base weights stay frozen.
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=0.0)
    losses = []
    data = calibration_data
    if max_seq_len is not None and data.shape[1] > max_seq_len:
        data = data[:, :max_seq_len]
    # ---- Sanitize loss weights ----
    # behavior_weight is a convex mixing coefficient; the rest are additive.
    behavior_weight = max(0.0, min(float(behavior_weight), 1.0))
    distill_weight = max(0.0, float(distill_weight))
    behavior_hidden_weight = max(0.0, float(behavior_hidden_weight))
    logit_distill_weight = max(0.0, float(logit_distill_weight))
    behavior_logit_weight = max(0.0, float(behavior_logit_weight))
    entropy_distill_weight = max(0.0, float(entropy_distill_weight))
    behavior_entropy_weight = max(0.0, float(behavior_entropy_weight))
    logit_distill_temperature = max(1e-4, float(logit_distill_temperature))
    # ---- Normalize teacher-behavior targets to CPU copies, truncated to
    # max_seq_len; parallel lists are padded with None so indices stay aligned.
    teacher_sequences = []
    teacher_hidden_sequences = []
    teacher_logit_sequences = []
    if behavior_sequences:
        for idx, seq in enumerate(behavior_sequences):
            current = seq.detach().cpu().clone()
            if current.dim() == 1:
                current = current.unsqueeze(0)
            if max_seq_len is not None and current.shape[1] > max_seq_len:
                current = current[:, :max_seq_len]
            if current.numel() > 0:
                teacher_sequences.append(current)
                if behavior_hidden_states and idx < len(behavior_hidden_states):
                    hidden = behavior_hidden_states[idx].detach().cpu().clone()
                    if hidden.dim() == 2:
                        hidden = hidden.unsqueeze(0)
                    if max_seq_len is not None and hidden.shape[1] > max_seq_len:
                        hidden = hidden[:, :max_seq_len]
                    teacher_hidden_sequences.append(hidden)
                else:
                    teacher_hidden_sequences.append(None)
                if behavior_logit_targets and idx < len(behavior_logit_targets):
                    logit_target = {
                        "indices": behavior_logit_targets[idx]["indices"].detach().cpu().clone(),
                        "logits": behavior_logit_targets[idx]["logits"].detach().cpu().clone(),
                        "entropy": behavior_logit_targets[idx]
                        .get("entropy")
                        .detach()
                        .cpu()
                        .clone()
                        if behavior_logit_targets[idx].get("entropy") is not None
                        else None,
                    }
                    # Logit targets describe next-token positions, hence the
                    # max_seq_len - 1 limit.
                    if max_seq_len is not None and logit_target["indices"].shape[1] > max_seq_len - 1:
                        limit = max(max_seq_len - 1, 0)
                        logit_target["indices"] = logit_target["indices"][:, :limit]
                        logit_target["logits"] = logit_target["logits"][:, :limit]
                        if logit_target["entropy"] is not None:
                            logit_target["entropy"] = logit_target["entropy"][:, :limit]
                    teacher_logit_sequences.append(logit_target)
                else:
                    teacher_logit_sequences.append(None)
    # ---- Normalize calibration-set distillation targets ----
    hidden_targets = None
    if calibration_hidden_states is not None:
        hidden_targets = calibration_hidden_states.detach().cpu().clone()
        if hidden_targets.dim() == 2:
            hidden_targets = hidden_targets.unsqueeze(0)
        if max_seq_len is not None and hidden_targets.shape[1] > max_seq_len:
            hidden_targets = hidden_targets[:, :max_seq_len]
    logit_targets = None
    if calibration_logit_targets is not None:
        logit_targets = {
            "indices": calibration_logit_targets["indices"].detach().cpu().clone(),
            "logits": calibration_logit_targets["logits"].detach().cpu().clone(),
            "entropy": calibration_logit_targets.get("entropy").detach().cpu().clone()
            if calibration_logit_targets.get("entropy") is not None
            else None,
        }
        if max_seq_len is not None and logit_targets["indices"].shape[1] > max_seq_len - 1:
            limit = max(max_seq_len - 1, 0)
            logit_targets["indices"] = logit_targets["indices"][:, :limit]
            logit_targets["logits"] = logit_targets["logits"][:, :limit]
            if logit_targets["entropy"] is not None:
                logit_targets["entropy"] = logit_targets["entropy"][:, :limit]
    # ---- Train/validation split (tail of the calibration set is held out;
    # no validation at all for fewer than 4 sequences) ----
    n_val = int(data.shape[0] * val_fraction)
    if data.shape[0] >= 4:
        n_val = max(1, n_val)
    else:
        n_val = 0
    if n_val > 0 and n_val < data.shape[0]:
        train_data = data[:-n_val]
        val_data = data[-n_val:]
        train_hidden = hidden_targets[:-n_val] if hidden_targets is not None else None
        val_hidden = hidden_targets[-n_val:] if hidden_targets is not None else None
        train_logit = (
            {
                "indices": logit_targets["indices"][:-n_val],
                "logits": logit_targets["logits"][:-n_val],
                "entropy": logit_targets["entropy"][:-n_val]
                if logit_targets["entropy"] is not None
                else None,
            }
            if logit_targets is not None
            else None
        )
        val_logit = (
            {
                "indices": logit_targets["indices"][-n_val:],
                "logits": logit_targets["logits"][-n_val:],
                "entropy": logit_targets["entropy"][-n_val:]
                if logit_targets["entropy"] is not None
                else None,
            }
            if logit_targets is not None
            else None
        )
    else:
        train_data = data
        val_data = None
        train_hidden = hidden_targets
        val_hidden = None
        train_logit = logit_targets
        val_logit = None
    # KV caching must be off while training with labels; restored at the end.
    original_use_cache = getattr(model.config, "use_cache", None)
    if original_use_cache is not None:
        model.config.use_cache = False
    torch.manual_seed(int(seed))
    # ---- Best-checkpoint bookkeeping (CPU copies of U/V at the best monitor) ----
    best_state = None
    best_metric = None
    best_val_loss = None
    best_behavior_loss = None
    best_hidden_loss = None
    best_behavior_hidden_loss = None
    best_logit_loss = None
    best_behavior_logit_loss = None
    best_entropy_loss = None
    best_behavior_entropy_loss = None
    behavior_losses = []
    hidden_losses = []
    behavior_hidden_losses = []
    logit_losses = []
    behavior_logit_losses = []
    entropy_losses = []
    behavior_entropy_losses = []
    # Round-robin cursor over the teacher sequences.
    teacher_index = 0
    def _slice_logit_targets(
        targets: Optional[dict[str, torch.Tensor]],
        indices: torch.Tensor,
    ) -> Optional[dict[str, torch.Tensor]]:
        """Select the rows of the (CPU) logit targets matching a sampled batch."""
        if targets is None:
            return None
        idx_cpu = indices.detach().cpu()
        return {
            "indices": targets["indices"].index_select(0, idx_cpu),
            "logits": targets["logits"].index_select(0, idx_cpu),
            "entropy": targets["entropy"].index_select(0, idx_cpu)
            if targets.get("entropy") is not None
            else None,
        }
    def _topk_logit_kl(
        student_logits: torch.Tensor,
        targets: Optional[dict[str, torch.Tensor]],
    ) -> Optional[torch.Tensor]:
        """Temperature-scaled KL between teacher top-k logits and the student's
        logits gathered at the same vocabulary indices; None when inapplicable."""
        if targets is None:
            return None
        target_indices = targets["indices"]
        target_logits = targets["logits"]
        if target_indices.numel() == 0 or target_logits.numel() == 0:
            return None
        # Teacher targets are for next-token positions: drop the last student step.
        seq_len = min(student_logits.shape[1] - 1, target_indices.shape[1])
        if seq_len <= 0:
            return None
        gather_indices = target_indices[:, :seq_len].to(student_logits.device).long()
        teacher_topk_logits = (
            target_logits[:, :seq_len].to(student_logits.device).float()
        )
        student_next_logits = student_logits[:, :seq_len, :].float()
        student_topk_logits = torch.gather(student_next_logits, -1, gather_indices)
        student_log_probs = F.log_softmax(
            student_topk_logits / logit_distill_temperature,
            dim=-1,
        )
        teacher_probs = F.softmax(
            teacher_topk_logits / logit_distill_temperature,
            dim=-1,
        )
        # T^2 factor keeps gradient magnitude comparable across temperatures.
        return (
            F.kl_div(
                student_log_probs,
                teacher_probs,
                reduction="none",
            ).sum(dim=-1).mean()
            * (logit_distill_temperature**2)
        )
    def _entropy_floor_loss(
        student_logits: torch.Tensor,
        targets: Optional[dict[str, torch.Tensor]],
    ) -> Optional[torch.Tensor]:
        """One-sided squared penalty where (normalized) student entropy falls
        below the teacher's; None when no entropy targets are available."""
        if targets is None or targets.get("entropy") is None:
            return None
        target_entropy = targets["entropy"]
        if target_entropy.numel() == 0:
            return None
        seq_len = min(student_logits.shape[1] - 1, target_entropy.shape[1])
        if seq_len <= 0:
            return None
        student_next_logits = student_logits[:, :seq_len, :].float()
        student_log_probs = F.log_softmax(student_next_logits, dim=-1)
        student_probs = student_log_probs.exp()
        # Normalize by log(vocab) so entropy lives in [0, 1].
        normalizer = math.log(max(student_next_logits.shape[-1], 2))
        student_entropy = -(student_probs * student_log_probs).sum(dim=-1) / normalizer
        teacher_entropy = target_entropy[:, :seq_len].to(student_logits.device).float()
        return F.relu(teacher_entropy - student_entropy).square().mean()
    def evaluate_loss(split: torch.Tensor) -> float:
        """Mean LM loss over a split; NaN when the split is empty/missing."""
        if split is None or split.numel() == 0:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for start in range(0, split.shape[0], batch_size):
                batch = split[start : start + batch_size]
                values.append(float(model(batch, labels=batch).loss.item()))
        model.train()
        return sum(values) / max(len(values), 1)
    def evaluate_logit_loss(
        split: Optional[torch.Tensor],
        targets: Optional[dict[str, torch.Tensor]],
    ) -> float:
        """Mean top-k logit KL over a split; NaN when inapplicable."""
        if split is None or targets is None or split.numel() == 0:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for start in range(0, split.shape[0], batch_size):
                batch = split[start : start + batch_size]
                batch_targets = {
                    "indices": targets["indices"][start : start + batch.shape[0]],
                    "logits": targets["logits"][start : start + batch.shape[0]],
                }
                logits = model(batch).logits
                loss = _topk_logit_kl(logits, batch_targets)
                if loss is not None:
                    values.append(float(loss.item()))
        model.train()
        return sum(values) / max(len(values), 1) if values else float("nan")
    def evaluate_entropy_loss(
        split: Optional[torch.Tensor],
        targets: Optional[dict[str, torch.Tensor]],
    ) -> float:
        """Mean entropy-floor penalty over a split; NaN when inapplicable."""
        if split is None or targets is None or targets.get("entropy") is None or split.numel() == 0:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for start in range(0, split.shape[0], batch_size):
                batch = split[start : start + batch_size]
                batch_targets = {
                    "entropy": targets["entropy"][start : start + batch.shape[0]],
                }
                logits = model(batch).logits
                loss = _entropy_floor_loss(logits, batch_targets)
                if loss is not None:
                    values.append(float(loss.item()))
        model.train()
        return sum(values) / max(len(values), 1) if values else float("nan")
    def evaluate_behavior_loss(sequences: list[torch.Tensor]) -> float:
        """Mean LM loss over the teacher behavior sequences; NaN when empty."""
        if not sequences:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for seq in sequences:
                batch = seq.to(module_device)
                values.append(float(model(batch, labels=batch).loss.item()))
        model.train()
        return sum(values) / max(len(values), 1)
    def evaluate_hidden_loss(
        split: Optional[torch.Tensor],
        targets: Optional[torch.Tensor],
    ) -> float:
        """Mean MSE between student and teacher last-layer hidden states."""
        if split is None or targets is None or split.numel() == 0:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for start in range(0, split.shape[0], batch_size):
                batch = split[start : start + batch_size]
                batch_targets = targets[start : start + batch.shape[0]].to(module_device)
                outputs = model(batch, output_hidden_states=True)
                student = outputs.hidden_states[-1].float()
                values.append(float(F.mse_loss(student, batch_targets.float()).item()))
        model.train()
        return sum(values) / max(len(values), 1)
    def evaluate_behavior_hidden_loss(
        sequences: list[torch.Tensor],
        targets: list[Optional[torch.Tensor]],
    ) -> float:
        """Hidden-state MSE over behavior sequences that have hidden targets."""
        pairs = [
            (seq, target)
            for seq, target in zip(sequences, targets)
            if target is not None
        ]
        if not pairs:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for seq, target in pairs:
                batch = seq.to(module_device)
                outputs = model(batch, output_hidden_states=True)
                student = outputs.hidden_states[-1].float()
                values.append(float(F.mse_loss(student, target.to(module_device).float()).item()))
        model.train()
        return sum(values) / max(len(values), 1)
    def evaluate_behavior_logit_loss(
        sequences: list[torch.Tensor],
        targets: list[Optional[dict[str, torch.Tensor]]],
    ) -> float:
        """Top-k logit KL over behavior sequences that have logit targets."""
        pairs = [
            (seq, target)
            for seq, target in zip(sequences, targets)
            if target is not None
        ]
        if not pairs:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for seq, target in pairs:
                logits = model(seq.to(module_device)).logits
                loss = _topk_logit_kl(logits, target)
                if loss is not None:
                    values.append(float(loss.item()))
        model.train()
        return sum(values) / max(len(values), 1) if values else float("nan")
    def evaluate_behavior_entropy_loss(
        sequences: list[torch.Tensor],
        targets: list[Optional[dict[str, torch.Tensor]]],
    ) -> float:
        """Entropy-floor penalty over behavior sequences with entropy targets."""
        pairs = [
            (seq, target)
            for seq, target in zip(sequences, targets)
            if target is not None and target.get("entropy") is not None
        ]
        if not pairs:
            return float("nan")
        model.eval()
        values = []
        with torch.no_grad():
            for seq, target in pairs:
                logits = model(seq.to(module_device)).logits
                loss = _entropy_floor_loss(logits, target)
                if loss is not None:
                    values.append(float(loss.item()))
        model.train()
        return sum(values) / max(len(values), 1) if values else float("nan")
    model.train()
    # Hidden states are only requested when some hidden-distillation term is on.
    need_hidden_outputs = (
        distill_weight > 0.0
        or (teacher_sequences and behavior_hidden_weight > 0.0)
    )
    # ---- Training loop ----
    for step in range(n_steps):
        # Random batch with replacement from the training split.
        idx = torch.randint(0, train_data.shape[0], (batch_size,), device=train_data.device)
        batch = train_data.index_select(0, idx)
        optimizer.zero_grad(set_to_none=True)
        outputs = model(
            batch,
            labels=batch,
            output_hidden_states=need_hidden_outputs,
        )
        lm_loss = outputs.loss
        loss = lm_loss
        hidden_loss = None
        logit_loss = None
        entropy_loss = None
        if train_hidden is not None and distill_weight > 0.0:
            target_hidden = train_hidden.index_select(0, idx.cpu()).to(module_device)
            hidden_loss = F.mse_loss(
                outputs.hidden_states[-1].float(),
                target_hidden.float(),
            )
        if train_logit is not None and logit_distill_weight > 0.0:
            batch_targets = _slice_logit_targets(train_logit, idx)
            logit_loss = _topk_logit_kl(outputs.logits, batch_targets)
        if train_logit is not None and entropy_distill_weight > 0.0:
            batch_targets = _slice_logit_targets(train_logit, idx)
            entropy_loss = _entropy_floor_loss(outputs.logits, batch_targets)
        behavior_loss = None
        behavior_hidden_loss = None
        behavior_logit_loss = None
        behavior_entropy_loss = None
        # One teacher sequence per step (round-robin) when any behavior term is on.
        if teacher_sequences and (
            behavior_weight > 0.0
            or behavior_hidden_weight > 0.0
            or behavior_logit_weight > 0.0
            or behavior_entropy_weight > 0.0
        ):
            behavior_batch = teacher_sequences[teacher_index % len(teacher_sequences)].to(
                module_device
            )
            behavior_hidden_target = teacher_hidden_sequences[
                teacher_index % len(teacher_hidden_sequences)
            ]
            behavior_logit_target = teacher_logit_sequences[
                teacher_index % len(teacher_logit_sequences)
            ]
            teacher_index += 1
            behavior_outputs = model(
                behavior_batch,
                labels=behavior_batch,
                output_hidden_states=behavior_hidden_target is not None
                and behavior_hidden_weight > 0.0,
            )
            behavior_loss = behavior_outputs.loss
            if behavior_hidden_target is not None and behavior_hidden_weight > 0.0:
                behavior_hidden_loss = F.mse_loss(
                    behavior_outputs.hidden_states[-1].float(),
                    behavior_hidden_target.to(module_device).float(),
                )
            if behavior_logit_target is not None and behavior_logit_weight > 0.0:
                behavior_logit_loss = _topk_logit_kl(
                    behavior_outputs.logits,
                    behavior_logit_target,
                )
            if behavior_logit_target is not None and behavior_entropy_weight > 0.0:
                behavior_entropy_loss = _entropy_floor_loss(
                    behavior_outputs.logits,
                    behavior_logit_target,
                )
            # Convex mix of LM and behavior losses (pure behavior at weight 1).
            if behavior_weight >= 1.0 and hidden_loss is None:
                loss = behavior_loss
            else:
                loss = (1.0 - behavior_weight) * lm_loss + behavior_weight * behavior_loss
        # Additive distillation terms.
        if hidden_loss is not None:
            loss = loss + distill_weight * hidden_loss
        if logit_loss is not None:
            loss = loss + logit_distill_weight * logit_loss
        if entropy_loss is not None:
            loss = loss + entropy_distill_weight * entropy_loss
        if behavior_hidden_loss is not None:
            loss = loss + behavior_hidden_weight * behavior_hidden_loss
        if behavior_logit_loss is not None:
            loss = loss + behavior_logit_weight * behavior_logit_loss
        if behavior_entropy_loss is not None:
            loss = loss + behavior_entropy_weight * behavior_entropy_loss
        loss.backward()
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
        optimizer.step()
        # Record per-step loss components for the stats report.
        losses.append(float(loss.item()))
        if behavior_loss is not None:
            behavior_losses.append(float(behavior_loss.item()))
        if hidden_loss is not None:
            hidden_losses.append(float(hidden_loss.item()))
        if behavior_hidden_loss is not None:
            behavior_hidden_losses.append(float(behavior_hidden_loss.item()))
        if logit_loss is not None:
            logit_losses.append(float(logit_loss.item()))
        if behavior_logit_loss is not None:
            behavior_logit_losses.append(float(behavior_logit_loss.item()))
        if entropy_loss is not None:
            entropy_losses.append(float(entropy_loss.item()))
        if behavior_entropy_loss is not None:
            behavior_entropy_losses.append(float(behavior_entropy_loss.item()))
        # ---- Periodic validation; keep the U/V factors with the best monitor ----
        if val_data is not None and ((step + 1) % max(eval_interval, 1) == 0 or step == 0):
            val_loss = evaluate_loss(val_data)
            behavior_val_loss = evaluate_behavior_loss(teacher_sequences)
            hidden_val_loss = evaluate_hidden_loss(val_data, val_hidden)
            behavior_hidden_val_loss = evaluate_behavior_hidden_loss(
                teacher_sequences,
                teacher_hidden_sequences,
            )
            logit_val_loss = evaluate_logit_loss(val_data, val_logit)
            entropy_val_loss = evaluate_entropy_loss(val_data, val_logit)
            behavior_logit_val_loss = evaluate_behavior_logit_loss(
                teacher_sequences,
                teacher_logit_sequences,
            )
            behavior_entropy_val_loss = evaluate_behavior_entropy_loss(
                teacher_sequences,
                teacher_logit_sequences,
            )
            # Compose the selection monitor from whichever terms are available
            # (NaN terms are skipped so they never poison the monitor).
            monitor = val_loss
            if teacher_sequences and not math.isnan(behavior_val_loss):
                if math.isnan(monitor):
                    monitor = behavior_val_loss
                else:
                    monitor = (1.0 - behavior_weight) * monitor + behavior_weight * behavior_val_loss
            if not math.isnan(hidden_val_loss):
                if math.isnan(monitor):
                    monitor = distill_weight * hidden_val_loss
                else:
                    monitor = monitor + distill_weight * hidden_val_loss
            if not math.isnan(behavior_hidden_val_loss):
                if math.isnan(monitor):
                    monitor = behavior_hidden_weight * behavior_hidden_val_loss
                else:
                    monitor = monitor + behavior_hidden_weight * behavior_hidden_val_loss
            if not math.isnan(logit_val_loss):
                if math.isnan(monitor):
                    monitor = logit_distill_weight * logit_val_loss
                else:
                    monitor = monitor + logit_distill_weight * logit_val_loss
            if not math.isnan(entropy_val_loss):
                if math.isnan(monitor):
                    monitor = entropy_distill_weight * entropy_val_loss
                else:
                    monitor = monitor + entropy_distill_weight * entropy_val_loss
            if not math.isnan(behavior_logit_val_loss):
                if math.isnan(monitor):
                    monitor = behavior_logit_weight * behavior_logit_val_loss
                else:
                    monitor = monitor + behavior_logit_weight * behavior_logit_val_loss
            if not math.isnan(behavior_entropy_val_loss):
                if math.isnan(monitor):
                    monitor = behavior_entropy_weight * behavior_entropy_val_loss
                else:
                    monitor = monitor + behavior_entropy_weight * behavior_entropy_val_loss
            if best_metric is None or monitor < best_metric:
                best_metric = monitor
                best_val_loss = val_loss
                if teacher_sequences and not math.isnan(behavior_val_loss):
                    best_behavior_loss = behavior_val_loss
                if not math.isnan(hidden_val_loss):
                    best_hidden_loss = hidden_val_loss
                if not math.isnan(behavior_hidden_val_loss):
                    best_behavior_hidden_loss = behavior_hidden_val_loss
                if not math.isnan(logit_val_loss):
                    best_logit_loss = logit_val_loss
                if not math.isnan(behavior_logit_val_loss):
                    best_behavior_logit_loss = behavior_logit_val_loss
                if not math.isnan(entropy_val_loss):
                    best_entropy_loss = entropy_val_loss
                if not math.isnan(behavior_entropy_val_loss):
                    best_behavior_entropy_loss = behavior_entropy_val_loss
                # Snapshot the low-rank factors at the new best.
                best_state = {
                    name: (
                        wrapper.U.detach().cpu().clone(),
                        wrapper.V.detach().cpu().clone(),
                    )
                    for name, wrapper in wrapped.items()
                }
    model.eval()
    if original_use_cache is not None:
        model.config.use_cache = original_use_cache
    # Restore the best snapshot (if validation ran), then persist factors to fp16.
    if best_state is not None:
        for name, wrapper in wrapped.items():
            best_u, best_v = best_state[name]
            wrapper.U.data.copy_(best_u.to(wrapper.U.device))
            wrapper.V.data.copy_(best_v.to(wrapper.V.device))
    for name, wrapper in wrapped.items():
        result.quantized_params[name].lr_U = wrapper.U.detach().cpu().to(torch.float16)
        result.quantized_params[name].lr_V = wrapper.V.detach().cpu().to(torch.float16)
    # ---- Stats report (also stored on the result) ----
    stats = {
        "steps": n_steps,
        "learning_rate": lr,
        "batch_size": batch_size,
        "max_seq_len": max_seq_len,
        "n_wrapped_modules": len(wrapped),
        "n_trainable_params": int(sum(p.numel() for p in trainable_params)),
        "initial_loss": losses[0],
        "final_loss": losses[-1],
        "best_loss": min(losses),
        "best_val_loss": best_val_loss,
        "best_monitor": best_metric,
        "behavior_weight": behavior_weight,
        "distill_weight": distill_weight,
        "behavior_hidden_weight": behavior_hidden_weight,
        "logit_distill_weight": logit_distill_weight,
        "behavior_logit_weight": behavior_logit_weight,
        "entropy_distill_weight": entropy_distill_weight,
        "behavior_entropy_weight": behavior_entropy_weight,
        "logit_distill_temperature": logit_distill_temperature,
        "n_behavior_sequences": len(teacher_sequences),
        "initial_behavior_loss": behavior_losses[0] if behavior_losses else None,
        "final_behavior_loss": behavior_losses[-1] if behavior_losses else None,
        "best_behavior_loss": best_behavior_loss,
        "initial_hidden_loss": hidden_losses[0] if hidden_losses else None,
        "final_hidden_loss": hidden_losses[-1] if hidden_losses else None,
        "best_hidden_loss": best_hidden_loss,
        "initial_logit_loss": logit_losses[0] if logit_losses else None,
        "final_logit_loss": logit_losses[-1] if logit_losses else None,
        "best_logit_loss": best_logit_loss,
        "initial_entropy_loss": entropy_losses[0] if entropy_losses else None,
        "final_entropy_loss": entropy_losses[-1] if entropy_losses else None,
        "best_entropy_loss": best_entropy_loss,
        "initial_behavior_hidden_loss": behavior_hidden_losses[0]
        if behavior_hidden_losses
        else None,
        "final_behavior_hidden_loss": behavior_hidden_losses[-1]
        if behavior_hidden_losses
        else None,
        "best_behavior_hidden_loss": best_behavior_hidden_loss,
        "initial_behavior_logit_loss": behavior_logit_losses[0]
        if behavior_logit_losses
        else None,
        "final_behavior_logit_loss": behavior_logit_losses[-1]
        if behavior_logit_losses
        else None,
        "best_behavior_logit_loss": best_behavior_logit_loss,
        "initial_behavior_entropy_loss": behavior_entropy_losses[0]
        if behavior_entropy_losses
        else None,
        "final_behavior_entropy_loss": behavior_entropy_losses[-1]
        if behavior_entropy_losses
        else None,
        "best_behavior_entropy_loss": best_behavior_entropy_loss,
    }
    result.calibration_tune = stats
    return stats
def _assign_practical_actions(
candidates: list[RoleAwareModulePolicy],
config: SmallModelQuantizationConfig,
) -> None:
for policy in candidates:
quantize = policy.role in {"attention_inputs", "mlp_inputs"}
if policy.role == "attention_output" and config.quantize_attention_output:
quantize = True
if policy.role == "mlp_output" and config.quantize_mlp_output:
quantize = True
if quantize:
policy.action = "groupwise_ternary"
policy.reason = "practical_role_policy"
else:
policy.action = "fp16"
policy.reason = "protected_role_policy"
def _assign_budgeted_actions(
candidates: list[RoleAwareModulePolicy],
config: SmallModelQuantizationConfig,
total_model_params: int,
) -> float:
predicted_bits = 16.0 * total_model_params
target_bits = float(config.target_average_bits) * total_model_params
ranked = sorted(
candidates,
key=lambda policy: (
policy.sensitivity_score
/ max(16.0 - policy.bits_if_quantized, 1e-6),
policy.sensitivity_score,
policy.role_priority,
),
)
for policy in ranked:
if predicted_bits <= target_bits:
policy.action = "fp16"
policy.reason = "budget_not_needed"
continue
policy.action = "groupwise_ternary"
policy.reason = "budget_allocation"
predicted_bits -= (16.0 - policy.bits_if_quantized) * policy.num_params
for policy in ranked:
if policy.action == "fp16" and policy.reason == "unallocated":
policy.reason = "protected_by_budget"
return predicted_bits
def _assign_salient_fractions(
    candidates: list[RoleAwareModulePolicy],
    config: SmallModelQuantizationConfig,
) -> None:
    """Assign per-module sparse-outlier fractions to the quantized policies.

    With ``adaptive_salient`` off, every quantized module gets the uniform
    ``config.salient_fraction``. Otherwise sensitivity scores are min-max
    normalized into [min_salient_fraction, max_salient_fraction], and the
    fragile output/merge projections are pinned to the maximum. Effective
    bit estimates are refreshed in either case.
    """
    quantized = [p for p in candidates if p.action == "groupwise_ternary"]
    if not quantized:
        return
    if not config.adaptive_salient:
        for p in quantized:
            p.salient_fraction = config.salient_fraction
            p.bits_if_quantized = _estimate_effective_bits(
                p,
                config.group_size,
                p.salient_fraction,
            )
        return
    sensitivities = [p.sensitivity_score for p in quantized]
    lo = min(sensitivities)
    hi = max(sensitivities)
    spread = max(hi - lo, 1e-8)
    for p in quantized:
        normalized = (p.sensitivity_score - lo) / spread
        fraction = config.min_salient_fraction + normalized * (
            config.max_salient_fraction - config.min_salient_fraction
        )
        if p.role in {"attention_output", "mlp_output"}:
            # Merge projections are the most fragile: always give them the
            # largest sparse-residual budget.
            fraction = max(fraction, config.max_salient_fraction)
        p.salient_fraction = fraction
        p.bits_if_quantized = _estimate_effective_bits(
            p,
            config.group_size,
            fraction,
        )
def _estimate_residual_spectrum(
residual: torch.Tensor,
max_rank: int,
) -> torch.Tensor:
if max_rank <= 0:
return torch.zeros(0, dtype=torch.float32)
limit = min(max_rank, residual.shape[0], residual.shape[1])
if limit <= 0:
return torch.zeros(0, dtype=torch.float32)
singular_values = torch.linalg.svdvals(residual.float().cpu())
return singular_values[:limit].detach().cpu()
def _assign_adaptive_low_rank(
    candidates: list[RoleAwareModulePolicy],
    config: SmallModelQuantizationConfig,
    total_model_params: int,
    residual_spectra: dict[str, torch.Tensor],
) -> None:
    """Greedily distribute low-rank-correction rank across quantized modules.

    Starting from rank 0 everywhere, repeatedly buys ``chunk_rank`` more rank
    for the module with the best activation-weighted residual-energy gain per
    bit of storage, until the bit budget implied by the config targets runs
    out. Mutates ``policy.low_rank_rank`` and ``policy.bits_if_quantized``
    in place.
    """
    selected = [policy for policy in candidates if policy.action == "groupwise_ternary"]
    if not selected:
        return
    # Reset ranks so the baseline bit count excludes any low-rank correction.
    for policy in selected:
        policy.low_rank_rank = 0
        policy.bits_if_quantized = _estimate_effective_bits(
            policy,
            config.group_size,
            policy.salient_fraction,
        )
    base_bits = _predict_total_bits(candidates, total_model_params)
    # Cost of giving every module the full uniform rank; used as the fallback
    # budget when no explicit bit target is configured.
    uniform_extra_bits = sum(
        16.0 * config.low_rank_rank * (policy.out_features + policy.in_features)
        for policy in selected
    )
    if config.low_rank_target_average_bits is not None:
        target_bits = float(config.low_rank_target_average_bits) * total_model_params
    elif config.target_average_bits is not None:
        target_bits = float(config.target_average_bits) * total_model_params
    else:
        target_bits = base_bits + uniform_extra_bits
    remaining_bits = max(0.0, target_bits - base_bits)
    chunk_rank = max(1, min(config.low_rank_chunk_rank, config.low_rank_rank))
    current_rank = {policy.name: 0 for policy in selected}
    max_rank = {
        policy.name: min(config.low_rank_rank, policy.out_features, policy.in_features)
        for policy in selected
    }
    # Greedy knapsack: each iteration buys one rank chunk for the module with
    # the highest score; stops when nothing affordable/useful remains.
    while remaining_bits > 0:
        best_policy = None
        best_rank_step = 0
        best_score = 0.0
        best_cost = 0.0
        for policy in selected:
            rank_now = current_rank[policy.name]
            rank_limit = max_rank[policy.name]
            if rank_now >= rank_limit:
                continue
            step = min(chunk_rank, rank_limit - rank_now)
            # fp16 U and V factors: 16 bits per element of both factors.
            cost = 16.0 * step * (policy.out_features + policy.in_features)
            if cost > remaining_bits + 1e-6:
                continue
            spectrum = residual_spectra.get(policy.name)
            if spectrum is None or rank_now >= spectrum.numel():
                continue
            # Residual energy the next rank chunk would absorb (sum of
            # squared singular values).
            gain = float((spectrum[rank_now : rank_now + step] ** 2).sum().item())
            if gain <= 0.0:
                continue
            # Weight by activation RMS so modules seeing larger inputs win.
            weighted_gain = gain * max(policy.activation_rms, 1e-6)
            score = weighted_gain / max(cost, 1e-6)
            if score > best_score:
                best_score = score
                best_policy = policy
                best_rank_step = step
                best_cost = cost
        if best_policy is None or best_rank_step <= 0:
            break
        current_rank[best_policy.name] += best_rank_step
        best_policy.low_rank_rank = current_rank[best_policy.name]
        best_policy.bits_if_quantized = _estimate_effective_bits(
            best_policy,
            config.group_size,
            best_policy.salient_fraction,
        )
        remaining_bits -= best_cost
def _estimate_effective_bits(
policy: RoleAwareModulePolicy,
group_size: int,
salient_fraction: float,
) -> float:
if policy.action != "groupwise_ternary":
return 16.0
num_params = policy.num_params
out_features = policy.out_features
in_features = policy.in_features
n_groups = max(1, math.ceil(in_features / max(group_size, 1)))
code_bits = 2 * num_params
group_param_bits = 16 * 2 * out_features * n_groups
sparse_bits = int(max(1, salient_fraction * num_params)) * (32 + 16)
low_rank_bits = 16 * policy.low_rank_rank * (out_features + in_features)
return (code_bits + group_param_bits + sparse_bits + low_rank_bits) / max(
num_params, 1
)
def _predict_total_bits(
candidates: list[RoleAwareModulePolicy],
total_model_params: int,
) -> float:
predicted_bits = 16.0 * total_model_params
for policy in candidates:
if policy.action == "groupwise_ternary":
predicted_bits -= 16.0 * policy.num_params
predicted_bits += policy.bits_if_quantized * policy.num_params
return predicted_bits
def _collect_role_modules(
    layer: nn.Module,
    layer_prefix: str,
    base_config: QuantizationConfig,
) -> dict[str, nn.Module]:
    """Gather the layer's quantizable linear submodules that map to a known role.

    Returns a dict keyed by the module's fully-qualified dotted name.
    """
    selected: dict[str, nn.Module] = {}
    for child_name, child in layer.named_modules():
        full_name = f"{layer_prefix}.{child_name}" if child_name else layer_prefix
        eligible = _is_linear_layer(child) and _should_quantize(full_name, base_config)
        if eligible and _classify_module(full_name) is not None:
            selected[full_name] = child
    return selected
def _capture_activations(
    model: nn.Module,
    calibration_data: torch.Tensor,
    modules: dict[str, nn.Module],
    batch_size: int,
) -> dict[str, Optional[torch.Tensor]]:
    """Run calibration data through the model, recording each module's inputs.

    Hooks are installed on every module, removed afterwards; returns the
    captured activations (or None) per module name.
    """
    recorders: dict[str, ActivationCapture] = {}
    for module_name, module in modules.items():
        recorder = ActivationCapture()
        recorder.register(module)
        recorders[module_name] = recorder
    _run_model_forward(model, calibration_data, batch_size=batch_size)
    collected: dict[str, Optional[torch.Tensor]] = {}
    for module_name, recorder in recorders.items():
        collected[module_name] = recorder.get_activations()
        recorder.remove()
    return collected
def _capture_module_io(
    model: nn.Module,
    calibration_data: torch.Tensor,
    modules: dict[str, nn.Module],
    batch_size: int,
) -> dict[str, dict[str, Optional[torch.Tensor]]]:
    """Record flattened inputs AND outputs for each module during calibration.

    Returns, per module name, a dict with "inputs" and "outputs" entries
    (each a concatenated tensor or None). Hooks are removed before returning.
    """
    recorders: dict[str, ModuleIOCapture] = {}
    for module_name, module in modules.items():
        recorder = ModuleIOCapture()
        recorder.register(module)
        recorders[module_name] = recorder
    _run_model_forward(model, calibration_data, batch_size=batch_size)
    collected: dict[str, dict[str, Optional[torch.Tensor]]] = {}
    for module_name, recorder in recorders.items():
        collected[module_name] = {
            "inputs": recorder.get_inputs(),
            "outputs": recorder.get_outputs(),
        }
        recorder.remove()
    return collected
def _resolve_parent_module(model: nn.Module, name: str) -> tuple[nn.Module, str]:
parent = model
parts = name.split(".")
for part in parts[:-1]:
parent = getattr(parent, part)
return parent, parts[-1]
def _activation_rms(activations: Optional[torch.Tensor]) -> float:
if activations is None or activations.numel() == 0:
return 1.0
return activations.float().square().mean().sqrt().item()
def _build_dependency_groups_for_plan(
    layer: nn.Module,
    layer_prefix: str,
    layer_idx: int,
    plan: RoleAwareQuantizationPlan,
    config: SmallModelQuantizationConfig,
) -> tuple[list[dict[str, nn.Module]], list[str]]:
    """Split a decoder layer's linear modules into role-ordered quantization groups.

    Modules whose plan policy selects groupwise ternary are bucketed by role
    (empty roles are dropped, ROLE_ORDER preserved); the rest are reported as
    protected. ``layer_idx`` is accepted for signature parity with callers.
    Returns (ordered role groups, protected module names).
    """
    by_role: dict[str, dict[str, nn.Module]] = {role: {} for role in ROLE_ORDER}
    protected_names: list[str] = []
    for child_name, child in layer.named_modules():
        full_name = f"{layer_prefix}.{child_name}" if child_name else layer_prefix
        if not _is_linear_layer(child):
            continue
        if not _should_quantize(full_name, config.base_config):
            continue
        role = _classify_module(full_name)
        if role is None:
            continue
        policy = plan.policies.get(full_name)
        if policy is not None and policy.action == "groupwise_ternary":
            by_role[role][full_name] = child
        else:
            protected_names.append(full_name)
    ordered_groups = [by_role[role] for role in ROLE_ORDER if by_role[role]]
    return ordered_groups, protected_names
def _classify_module(full_name: str) -> Optional[str]:
target_name = full_name.split(".")[-1]
if target_name in {"q_proj", "k_proj", "v_proj", "query_key_value", "c_attn"}:
return "attention_inputs"
if target_name in {"gate_proj", "up_proj", "w1", "w3", "dense_h_to_4h", "c_fc"}:
return "mlp_inputs"
if target_name in {"o_proj", "dense"}:
return "attention_output"
if target_name in {"down_proj", "w2", "dense_4h_to_h"}:
return "mlp_output"
if target_name == "c_proj":
if ".attn." in full_name or ".self_attn." in full_name:
return "attention_output"
if ".mlp." in full_name:
return "mlp_output"
return None