"""
Layer-wise post-training ternary quantization pipeline.
Uses full model forward passes with hooks to capture activations at each
linear layer. This approach is robust across common HuggingFace decoder
architectures, since the model's own forward pass handles position embeddings,
rotary embeddings, and similar details automatically.
Quantization proceeds layer-by-layer through the decoder stack:
1. Run full model forward, capture activations at current layer's linears
2. Quantize those linears using captured activations
3. Replace weights with dequantized ternary approximation
4. Repeat for next layer (subsequent layers see quantized prior layers)
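
Typical usage (an illustrative sketch; the model ID and settings are examples,
not requirements of this module):

    config = QuantizationConfig(n_samples=64, seq_len=1024)
    result = quantize_model("HuggingFaceTB/SmolLM-135M", config=config)
    print(result.stats)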
"""
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from dataclasses import dataclass, field
from ternary_quant.quantizer import TernaryQuantizer, compute_quantization_error
@dataclass
class QuantizationConfig:
"""Configuration for the quantization pipeline."""
n_iter: int = 10
use_activation_aware: bool = True
block_size: int = 0
importance_threshold_scale: float = 0.0
n_samples: int = 128
seq_len: int = 2048
dataset: str = "wikitext"
dataset_config: str = "wikitext-2-raw-v1"
seed: int = 42
# Layers to skip (keep in original precision)
skip_modules: list = field(
default_factory=lambda: ["lm_head"]
)
# Only quantize modules matching these name patterns
target_modules: list = field(
default_factory=lambda: [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
"query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h",
"w1", "w2", "w3",
"c_attn", "c_proj", "c_fc",
]
)
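# Example (illustrative) of overriding QuantizationConfig defaults: quantize only
# the MLP projections and keep the first decoder layer in full precision.
#
#   cfg = QuantizationConfig(
#       target_modules=["gate_proj", "up_proj", "down_proj"],
#       skip_modules=["lm_head", "layers.0."],
#   )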
@dataclass
class QuantizationResult:
"""Result of quantizing a full model."""
ternary_params: dict # name -> TernaryParameter
config: QuantizationConfig
model_config: AutoConfig
model_name: str
stats: dict # per-layer quantization statistics
class ActivationCapture:
"""Hook-based activation capture for a single linear layer."""
def __init__(self):
self.activations = []
self.hook = None
def register(self, module: nn.Module):
self.hook = module.register_forward_hook(self._hook_fn)
def _hook_fn(self, module, input, output):
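        # Keep only the layer's input tensor; flatten [batch, seq, hidden] to
        # [batch*seq, hidden] and stash on CPU to limit GPU memory growth.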
inp = input[0].detach()
if inp.dim() == 3:
inp = inp.reshape(-1, inp.shape[-1])
self.activations.append(inp.cpu())
    def get_activations(self) -> Optional[torch.Tensor]:
if not self.activations:
return None
return torch.cat(self.activations, dim=0)
def remove(self):
if self.hook is not None:
self.hook.remove()
self.hook = None
self.activations = []
@dataclass
class _LinearLikeSpec:
"""Resolved view of a linear-like module or wrapper."""
target_module: nn.Module
target_path: tuple[str, ...]
transpose_weight: bool
_WRAPPED_LINEAR_ATTRS = (
"linear",
"base_layer",
)
def _as_weight_tensor(module: nn.Module) -> Optional[torch.Tensor]:
weight = getattr(module, "weight", None)
if isinstance(weight, (torch.Tensor, nn.Parameter)) and weight.ndim == 2:
return weight
return None
def _infer_direct_linear_spec(module: nn.Module) -> Optional[_LinearLikeSpec]:
"""Recognize direct linear-like modules and their weight layout."""
weight = _as_weight_tensor(module)
if weight is None:
return None
cls_name = type(module).__name__
if isinstance(module, nn.Linear):
return _LinearLikeSpec(module, (), False)
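    # HuggingFace's Conv1D (used by GPT-2-style models) stores its weight as
    # [in_features, out_features], so it must be transposed for quantization.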
if cls_name == "Conv1D":
return _LinearLikeSpec(module, (), True)
for in_attr, out_attr in (("in_features", "out_features"), ("input_size", "output_size")):
in_features = getattr(module, in_attr, None)
out_features = getattr(module, out_attr, None)
if not isinstance(in_features, int) or not isinstance(out_features, int):
continue
if tuple(weight.shape) == (out_features, in_features):
return _LinearLikeSpec(module, (), False)
if tuple(weight.shape) == (in_features, out_features):
return _LinearLikeSpec(module, (), True)
# Fall back for direct custom linear classes that expose a conventional
# 2D weight matrix but are not subclasses of nn.Linear.
if "Linear" in cls_name:
return _LinearLikeSpec(module, (), False)
return None
def _resolve_linear_like_spec(
module: nn.Module,
_seen: Optional[set[int]] = None,
) -> Optional[_LinearLikeSpec]:
"""Resolve direct or wrapped linear-like modules to a canonical spec."""
if _seen is None:
_seen = set()
if id(module) in _seen:
return None
_seen.add(id(module))
direct = _infer_direct_linear_spec(module)
if direct is not None:
return direct
for attr in _WRAPPED_LINEAR_ATTRS:
child = getattr(module, attr, None)
if not isinstance(child, nn.Module):
continue
child_spec = _resolve_linear_like_spec(child, _seen)
if child_spec is not None:
return _LinearLikeSpec(
target_module=child_spec.target_module,
target_path=(attr, *child_spec.target_path),
transpose_weight=child_spec.transpose_weight,
)
return None
def _is_linear_layer(module: nn.Module) -> bool:
"""Check if a module is direct linear math or a thin wrapper around it."""
return _resolve_linear_like_spec(module) is not None
def _get_weight(module: nn.Module) -> torch.Tensor:
"""Get weight matrix in [out_features, in_features] format."""
spec = _resolve_linear_like_spec(module)
if spec is None:
raise TypeError(f"Unsupported linear-like module: {type(module).__name__}")
weight = spec.target_module.weight.data
if spec.transpose_weight:
return weight.T.contiguous()
return weight
def _set_weight(module: nn.Module, weight: torch.Tensor):
"""Set weight matrix, handling Conv1D transpose."""
spec = _resolve_linear_like_spec(module)
if spec is None:
raise TypeError(f"Unsupported linear-like module: {type(module).__name__}")
if spec.transpose_weight:
spec.target_module.weight.data = weight.T.contiguous()
else:
spec.target_module.weight.data = weight
def _get_bias(module: nn.Module) -> Optional[torch.Tensor]:
"""Get bias tensor from the resolved linear target if present."""
spec = _resolve_linear_like_spec(module)
if spec is None:
raise TypeError(f"Unsupported linear-like module: {type(module).__name__}")
bias = getattr(spec.target_module, "bias", None)
if isinstance(bias, (torch.Tensor, nn.Parameter)):
return bias.data
return None
def _install_linear_replacement(module: nn.Module, replacement: nn.Module) -> nn.Module:
"""Install a quantized replacement while preserving wrapper modules when possible."""
spec = _resolve_linear_like_spec(module)
if spec is None:
raise TypeError(f"Unsupported linear-like module: {type(module).__name__}")
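    # Direct linear: the caller must install the returned replacement itself.
    # Wrapped linear: swap the inner target in place and return the original wrapper.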
if not spec.target_path:
return replacement
parent = module
for attr in spec.target_path[:-1]:
parent = getattr(parent, attr)
setattr(parent, spec.target_path[-1], replacement)
return module
def _should_quantize(name: str, config: QuantizationConfig) -> bool:
"""Check if a module should be quantized based on config."""
for skip in config.skip_modules:
if skip in name:
return False
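    # Match on the leaf module name, e.g. "model.layers.0.self_attn.q_proj" -> "q_proj".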
module_name = name.split(".")[-1]
return module_name in config.target_modules
def _get_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList, str]:
"""Find the main decoder layer stack."""
candidates = [
"model.layers", # LLaMA, Mistral, Qwen, Gemma, SmolLM
"transformer.h", # GPT-2, GPT-Neo
"transformer.layers", # Some models
"gpt_neox.layers", # GPT-NeoX
"model.decoder.layers", # OPT
]
for path in candidates:
obj = model
try:
for attr in path.split("."):
obj = getattr(obj, attr)
if isinstance(obj, nn.ModuleList) and len(obj) > 0:
return obj, path
except AttributeError:
continue
raise ValueError(
"Could not find decoder layers. Supported architectures: "
"LLaMA, Mistral, Qwen, Gemma, GPT-2, GPT-NeoX, OPT, Phi, Falcon."
)
def _run_model_forward(model: nn.Module, input_ids: torch.Tensor, batch_size: int = 4):
"""Run full model forward in batches (for hook-based activation capture)."""
with torch.no_grad():
for i in range(0, input_ids.shape[0], batch_size):
batch = input_ids[i : i + batch_size]
model(batch)
def quantize_model(
model_name_or_path: str,
config: Optional[QuantizationConfig] = None,
device: str = "auto",
dtype: torch.dtype = torch.float16,
calibration_data: Optional[torch.Tensor] = None,
) -> QuantizationResult:
"""
Quantize a HuggingFace model to ternary weights.
Uses full model forward passes with hooks to capture activations.
Processes decoder layers sequentially: after quantizing each layer's
linear modules, replaces their weights with dequantized ternary
approximations so subsequent layers see quantized inputs.
Args:
model_name_or_path: HuggingFace model ID or local path
config: Quantization configuration
device: Device to run calibration on ("auto", "cuda", "cpu")
dtype: Model dtype for loading
calibration_data: Pre-tokenized calibration data [n_samples, seq_len].
Returns:
QuantizationResult with all quantized parameters and statistics.
"""
if config is None:
config = QuantizationConfig()
print(f"Loading model: {model_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model_config = AutoConfig.from_pretrained(model_name_or_path)
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
torch_dtype=dtype,
device_map=device if device != "cpu" else None,
low_cpu_mem_usage=True,
)
if device == "cpu":
model = model.to(device)
model.eval()
# Load calibration data
if calibration_data is None:
print(f"Loading calibration data from {config.dataset}...")
from ternary_quant.data import get_calibration_data
calibration_data = get_calibration_data(
model_name_or_path,
dataset_name=config.dataset,
dataset_config=config.dataset_config,
n_samples=config.n_samples,
seq_len=config.seq_len,
seed=config.seed,
tokenizer=tokenizer,
)
calibration_data = calibration_data.to(model.device)
# Find decoder layers
decoder_layers, layer_path = _get_decoder_layers(model)
n_layers = len(decoder_layers)
print(f"Found {n_layers} decoder layers at '{layer_path}'")
# Build map of layers to their target linear modules
layer_linear_map = {} # layer_idx -> {full_name: module}
for layer_idx in range(n_layers):
layer = decoder_layers[layer_idx]
layer_prefix = f"{layer_path}.{layer_idx}"
linears = {}
for name, module in layer.named_modules():
full_name = f"{layer_prefix}.{name}" if name else layer_prefix
if _is_linear_layer(module) and _should_quantize(full_name, config):
linears[full_name] = module
layer_linear_map[layer_idx] = linears
total_linears = sum(len(v) for v in layer_linear_map.values())
print(f"Found {total_linears} linear layers to quantize")
quantizer = TernaryQuantizer(
n_iter=config.n_iter,
use_activation_aware=config.use_activation_aware,
block_size=config.block_size,
importance_threshold_scale=config.importance_threshold_scale,
)
ternary_params = {}
stats = {}
# --- Process each decoder layer ---
# For each layer:
# 1. Register hooks on its linear modules
# 2. Run full model forward to capture activations
# 3. Quantize each linear using captured activations
# 4. Replace weights with dequantized versions
# This propagates quantization effects: later layers see quantized earlier layers.
print("Quantizing layers...")
for layer_idx in tqdm(range(n_layers), desc="Quantizing layers"):
linears = layer_linear_map[layer_idx]
if not linears:
continue
# Register activation captures
captures = {}
for name, module in linears.items():
cap = ActivationCapture()
cap.register(module)
captures[name] = cap
# Run full model forward to capture activations at this layer
_run_model_forward(model, calibration_data)
# Quantize each linear in this layer
for name, module in linears.items():
cap = captures[name]
acts = cap.get_activations()
cap.remove()
weight = _get_weight(module)
tp = quantizer.quantize(weight, activations=acts)
ternary_params[name] = tp
error = compute_quantization_error(weight, tp)
stats[name] = error
# Replace weight with dequantized version for subsequent layers
dequant = tp.dequantize().to(weight.dtype).to(weight.device)
_set_weight(module, dequant)
del acts
# --- Quantize any remaining target linears outside the decoder stack ---
for name, module in model.named_modules():
if (
_is_linear_layer(module)
and _should_quantize(name, config)
and name not in ternary_params
and not any(name.startswith(f"{layer_path}.{i}") for i in range(n_layers))
):
weight = _get_weight(module)
tp = quantizer.quantize(weight, activations=None)
ternary_params[name] = tp
stats[name] = compute_quantization_error(weight, tp)
if ternary_params:
_print_summary(ternary_params, stats, model)
else:
print("WARNING: No layers were quantized!")
return QuantizationResult(
ternary_params=ternary_params,
config=config,
model_config=model_config,
model_name=model_name_or_path,
stats=stats,
)
def _print_summary(
ternary_params: dict,
stats: dict,
model: nn.Module,
):
"""Print quantization summary statistics."""
total_params = sum(tp.num_params for tp in ternary_params.values())
total_model_params = sum(p.numel() for p in model.parameters())
avg_sparsity = sum(s["sparsity"] for s in stats.values()) / len(stats)
avg_rel_error = sum(s["relative_error"] for s in stats.values()) / len(stats)
avg_bits = sum(s["bits_per_param"] for s in stats.values()) / len(stats)
fp16_bytes = total_params * 2
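    # Ternary size estimate: 2 bits per packed weight plus one fp32 scale per output row.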
ternary_bytes = total_params * 2 / 8
n_rows = sum(tp.original_shape[0] for tp in ternary_params.values())
ternary_bytes += n_rows * 4
print("\n" + "=" * 60)
print("QUANTIZATION SUMMARY")
print("=" * 60)
print(f"Quantized parameters: {total_params:,} / {total_model_params:,} "
f"({100 * total_params / total_model_params:.1f}%)")
print(f"Quantized layers: {len(ternary_params)}")
print(f"Average bits/param: {avg_bits:.2f}")
print(f"Average sparsity: {avg_sparsity:.1%}")
print(f"Average rel. error: {avg_rel_error:.4f}")
print(f"FP16 size: {fp16_bytes / 1e9:.2f} GB")
print(f"Ternary size: {ternary_bytes / 1e9:.2f} GB")
print(f"Compression ratio: {fp16_bytes / ternary_bytes:.1f}x")
print("=" * 60)