| """ | |
| Layer-wise post-training ternary quantization pipeline. | |
| Uses full model forward passes with hooks to capture activations at each | |
| linear layer. This approach is robust across all HuggingFace architectures | |
| (handles position embeddings, rotary embeddings, etc. automatically). | |
| Quantization proceeds layer-by-layer through the decoder stack: | |
| 1. Run full model forward, capture activations at current layer's linears | |
| 2. Quantize those linears using captured activations | |
| 3. Replace weights with dequantized ternary approximation | |
| 4. Repeat for next layer (subsequent layers see quantized prior layers) | |
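Example usage (illustrative sketch; the model ID and settings are placeholders,
not recommendations):

    config = QuantizationConfig(n_samples=64, seq_len=512)
    result = quantize_model("org/model-id", config=config)
    result.ternary_params  # name -> TernaryParameter
    result.stats           # per-layer quantization error statistics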
| """ | |
| import torch | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
| from typing import Optional | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig | |
| from dataclasses import dataclass, field | |
| from ternary_quant.quantizer import TernaryQuantizer, compute_quantization_error | |


@dataclass
class QuantizationConfig:
    """Configuration for the quantization pipeline."""
    n_iter: int = 10
    use_activation_aware: bool = True
    block_size: int = 0
    importance_threshold_scale: float = 0.0
    n_samples: int = 128
    seq_len: int = 2048
    dataset: str = "wikitext"
    dataset_config: str = "wikitext-2-raw-v1"
    seed: int = 42
    # Layers to skip (keep in original precision)
    skip_modules: list = field(
        default_factory=lambda: ["lm_head"]
    )
    # Only quantize modules matching these name patterns
    target_modules: list = field(
        default_factory=lambda: [
            # LLaMA / Mistral / Qwen / Gemma
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
            # GPT-NeoX / Falcon
            "query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h",
            # Mixtral experts and similar MLP naming
            "w1", "w2", "w3",
            # GPT-2
            "c_attn", "c_proj", "c_fc",
        ]
    )


@dataclass
class QuantizationResult:
    """Result of quantizing a full model."""
    ternary_params: dict  # name -> TernaryParameter
    config: QuantizationConfig
    model_config: AutoConfig
    model_name: str
    stats: dict  # per-layer quantization statistics


class ActivationCapture:
    """Hook-based activation capture for a single linear layer."""

    def __init__(self):
        self.activations = []
        self.hook = None

    def register(self, module: nn.Module):
        self.hook = module.register_forward_hook(self._hook_fn)

    def _hook_fn(self, module, input, output):
        # Capture the layer input, flattening [batch, seq, hidden] to [tokens, hidden].
        inp = input[0].detach()
        if inp.dim() == 3:
            inp = inp.reshape(-1, inp.shape[-1])
        self.activations.append(inp.cpu())

    def get_activations(self) -> Optional[torch.Tensor]:
        if not self.activations:
            return None
        return torch.cat(self.activations, dim=0)

    def remove(self):
        if self.hook is not None:
            self.hook.remove()
            self.hook = None
        self.activations = []
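
# Typical standalone use (illustrative sketch, mirroring how the pipeline below uses it):
#   cap = ActivationCapture()
#   cap.register(linear_module)   # any module whose forward inputs we want to record
#   model(input_ids)              # forward pass; the hook buffers inputs to linear_module
#   acts = cap.get_activations()  # [total_tokens, in_features], or None if nothing ran
#   cap.remove()                  # detach the hook and free buffered activations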


@dataclass
class _LinearLikeSpec:
    """Resolved view of a linear-like module or wrapper."""
    target_module: nn.Module
    target_path: tuple[str, ...]
    transpose_weight: bool


# Attribute names under which common wrapper modules (e.g. adapter layers)
# expose the underlying linear module.
_WRAPPED_LINEAR_ATTRS = (
    "linear",
    "base_layer",
)


def _as_weight_tensor(module: nn.Module) -> Optional[torch.Tensor]:
    weight = getattr(module, "weight", None)
    if isinstance(weight, (torch.Tensor, nn.Parameter)) and weight.ndim == 2:
        return weight
    return None


def _infer_direct_linear_spec(module: nn.Module) -> Optional[_LinearLikeSpec]:
    """Recognize direct linear-like modules and their weight layout."""
    weight = _as_weight_tensor(module)
    if weight is None:
        return None
    cls_name = type(module).__name__
    if isinstance(module, nn.Linear):
        return _LinearLikeSpec(module, (), False)
    if cls_name == "Conv1D":
        # transformers' Conv1D (GPT-2 family) stores weight as [in_features, out_features].
        return _LinearLikeSpec(module, (), True)
    for in_attr, out_attr in (("in_features", "out_features"), ("input_size", "output_size")):
        in_features = getattr(module, in_attr, None)
        out_features = getattr(module, out_attr, None)
        if not isinstance(in_features, int) or not isinstance(out_features, int):
            continue
        if tuple(weight.shape) == (out_features, in_features):
            return _LinearLikeSpec(module, (), False)
        if tuple(weight.shape) == (in_features, out_features):
            return _LinearLikeSpec(module, (), True)
    # Fall back for direct custom linear classes that expose a conventional
    # 2D weight matrix but are not subclasses of nn.Linear.
    if "Linear" in cls_name:
        return _LinearLikeSpec(module, (), False)
    return None


def _resolve_linear_like_spec(
    module: nn.Module,
    _seen: Optional[set[int]] = None,
) -> Optional[_LinearLikeSpec]:
    """Resolve direct or wrapped linear-like modules to a canonical spec."""
    if _seen is None:
        _seen = set()
    if id(module) in _seen:
        return None
    _seen.add(id(module))
    direct = _infer_direct_linear_spec(module)
    if direct is not None:
        return direct
    for attr in _WRAPPED_LINEAR_ATTRS:
        child = getattr(module, attr, None)
        if not isinstance(child, nn.Module):
            continue
        child_spec = _resolve_linear_like_spec(child, _seen)
        if child_spec is not None:
            return _LinearLikeSpec(
                target_module=child_spec.target_module,
                target_path=(attr, *child_spec.target_path),
                transpose_weight=child_spec.transpose_weight,
            )
    return None
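
# What _resolve_linear_like_spec produces (illustrative; the wrapper example is
# hypothetical but follows the _WRAPPED_LINEAR_ATTRS convention above):
#   nn.Linear                                -> spec(module itself, path=(), transpose_weight=False)
#   transformers Conv1D (GPT-2 style)        -> spec(module itself, path=(), transpose_weight=True)
#   wrapper exposing .base_layer = nn.Linear -> spec(inner linear, path=("base_layer",), transpose_weight=False)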


def _is_linear_layer(module: nn.Module) -> bool:
    """Check if a module is direct linear math or a thin wrapper around it."""
    return _resolve_linear_like_spec(module) is not None


def _get_weight(module: nn.Module) -> torch.Tensor:
    """Get the weight matrix in [out_features, in_features] format."""
    spec = _resolve_linear_like_spec(module)
    if spec is None:
        raise TypeError(f"Unsupported linear-like module: {type(module).__name__}")
    weight = spec.target_module.weight.data
    if spec.transpose_weight:
        return weight.T.contiguous()
    return weight


def _set_weight(module: nn.Module, weight: torch.Tensor):
    """Set the weight matrix, handling transposed layouts such as Conv1D."""
    spec = _resolve_linear_like_spec(module)
    if spec is None:
        raise TypeError(f"Unsupported linear-like module: {type(module).__name__}")
    if spec.transpose_weight:
        spec.target_module.weight.data = weight.T.contiguous()
    else:
        spec.target_module.weight.data = weight


def _get_bias(module: nn.Module) -> Optional[torch.Tensor]:
    """Get the bias tensor from the resolved linear target, if present."""
    spec = _resolve_linear_like_spec(module)
    if spec is None:
        raise TypeError(f"Unsupported linear-like module: {type(module).__name__}")
    bias = getattr(spec.target_module, "bias", None)
    if isinstance(bias, (torch.Tensor, nn.Parameter)):
        return bias.data
    return None


def _install_linear_replacement(module: nn.Module, replacement: nn.Module) -> nn.Module:
    """Install a quantized replacement while preserving wrapper modules when possible."""
    spec = _resolve_linear_like_spec(module)
    if spec is None:
        raise TypeError(f"Unsupported linear-like module: {type(module).__name__}")
    if not spec.target_path:
        return replacement
    parent = module
    for attr in spec.target_path[:-1]:
        parent = getattr(parent, attr)
    setattr(parent, spec.target_path[-1], replacement)
    return module
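
# _install_linear_replacement behaviour (illustrative):
#   plain nn.Linear              -> returns `replacement`; the caller swaps the module itself
#   wrapper exposing .base_layer -> sets wrapper.base_layer = replacement and returns the wrapper,
#                                   preserving any adapter state around the linear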


def _should_quantize(name: str, config: QuantizationConfig) -> bool:
    """Check if a module should be quantized based on the config."""
    for skip in config.skip_modules:
        if skip in name:
            return False
    module_name = name.split(".")[-1]
    return module_name in config.target_modules
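
# _should_quantize examples under the default config (module names are illustrative):
#   "model.layers.0.self_attn.q_proj" -> True   (last segment is in target_modules)
#   "lm_head"                         -> False  (matches skip_modules)
#   "model.layers.0.input_layernorm"  -> False  (last segment is not targeted)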


def _get_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList, str]:
    """Find the main decoder layer stack."""
    candidates = [
        "model.layers",          # LLaMA, Mistral, Qwen, Gemma, SmolLM
        "transformer.h",         # GPT-2, GPT-Neo
        "transformer.layers",    # some models
        "gpt_neox.layers",       # GPT-NeoX
        "model.decoder.layers",  # OPT
    ]
    for path in candidates:
        obj = model
        try:
            for attr in path.split("."):
                obj = getattr(obj, attr)
            if isinstance(obj, nn.ModuleList) and len(obj) > 0:
                return obj, path
        except AttributeError:
            continue
    raise ValueError(
        "Could not find decoder layers. Supported architectures: "
        "LLaMA, Mistral, Qwen, Gemma, GPT-2, GPT-NeoX, OPT, Phi, Falcon."
    )
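
# For example, a LLaMA-style checkpoint resolves to the ModuleList at "model.layers",
# and GPT-2 to the one at "transformer.h"; any other layout raises the ValueError above.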


def _run_model_forward(model: nn.Module, input_ids: torch.Tensor, batch_size: int = 4):
    """Run the full model forward in batches (for hook-based activation capture)."""
    with torch.no_grad():
        for i in range(0, input_ids.shape[0], batch_size):
            batch = input_ids[i : i + batch_size]
            model(batch)


def quantize_model(
    model_name_or_path: str,
    config: Optional[QuantizationConfig] = None,
    device: str = "auto",
    dtype: torch.dtype = torch.float16,
    calibration_data: Optional[torch.Tensor] = None,
) -> QuantizationResult:
    """
    Quantize a HuggingFace model to ternary weights.

    Uses full model forward passes with hooks to capture activations.
    Processes decoder layers sequentially: after quantizing each layer's
    linear modules, replaces their weights with dequantized ternary
    approximations so subsequent layers see quantized inputs.

    Args:
        model_name_or_path: HuggingFace model ID or local path.
        config: Quantization configuration.
        device: Device to run calibration on ("auto", "cuda", "cpu").
        dtype: Model dtype for loading.
        calibration_data: Pre-tokenized calibration data [n_samples, seq_len].

    Returns:
        QuantizationResult with all quantized parameters and statistics.
    """
    if config is None:
        config = QuantizationConfig()

    print(f"Loading model: {model_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model_config = AutoConfig.from_pretrained(model_name_or_path)

    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=dtype,
        device_map=device if device != "cpu" else None,
        low_cpu_mem_usage=True,
    )
    if device == "cpu":
        model = model.to(device)
    model.eval()

    # Load calibration data
    if calibration_data is None:
        print(f"Loading calibration data from {config.dataset}...")
        from ternary_quant.data import get_calibration_data
        calibration_data = get_calibration_data(
            model_name_or_path,
            dataset_name=config.dataset,
            dataset_config=config.dataset_config,
            n_samples=config.n_samples,
            seq_len=config.seq_len,
            seed=config.seed,
            tokenizer=tokenizer,
        )
    calibration_data = calibration_data.to(model.device)

    # Find decoder layers
    decoder_layers, layer_path = _get_decoder_layers(model)
    n_layers = len(decoder_layers)
    print(f"Found {n_layers} decoder layers at '{layer_path}'")

    # Build a map from each decoder layer to its target linear modules
    layer_linear_map = {}  # layer_idx -> {full_name: module}
    for layer_idx in range(n_layers):
        layer = decoder_layers[layer_idx]
        layer_prefix = f"{layer_path}.{layer_idx}"
        linears = {}
        for name, module in layer.named_modules():
            full_name = f"{layer_prefix}.{name}" if name else layer_prefix
            if _is_linear_layer(module) and _should_quantize(full_name, config):
                linears[full_name] = module
        layer_linear_map[layer_idx] = linears

    total_linears = sum(len(v) for v in layer_linear_map.values())
    print(f"Found {total_linears} linear layers to quantize")

    quantizer = TernaryQuantizer(
        n_iter=config.n_iter,
        use_activation_aware=config.use_activation_aware,
        block_size=config.block_size,
        importance_threshold_scale=config.importance_threshold_scale,
    )

    ternary_params = {}
    stats = {}

    # --- Process each decoder layer ---
    # For each layer:
    #   1. Register hooks on its linear modules
    #   2. Run a full model forward pass to capture activations
    #   3. Quantize each linear using the captured activations
    #   4. Replace weights with dequantized versions
    # This propagates quantization effects: later layers see quantized earlier layers.
    print("Quantizing layers...")
    for layer_idx in tqdm(range(n_layers), desc="Quantizing layers"):
        linears = layer_linear_map[layer_idx]
        if not linears:
            continue

        # Register activation captures
        captures = {}
        for name, module in linears.items():
            cap = ActivationCapture()
            cap.register(module)
            captures[name] = cap

        # Run full model forward to capture activations at this layer
        _run_model_forward(model, calibration_data)

        # Quantize each linear in this layer
        for name, module in linears.items():
            cap = captures[name]
            acts = cap.get_activations()
            cap.remove()

            weight = _get_weight(module)
            tp = quantizer.quantize(weight, activations=acts)
            ternary_params[name] = tp

            error = compute_quantization_error(weight, tp)
            stats[name] = error

            # Replace weight with dequantized version for subsequent layers
            dequant = tp.dequantize().to(weight.dtype).to(weight.device)
            _set_weight(module, dequant)

            del acts

    # --- Quantize any remaining target linears outside the decoder stack ---
    for name, module in model.named_modules():
        if (
            _is_linear_layer(module)
            and _should_quantize(name, config)
            and name not in ternary_params
            and not any(name.startswith(f"{layer_path}.{i}") for i in range(n_layers))
        ):
            weight = _get_weight(module)
            tp = quantizer.quantize(weight, activations=None)
            ternary_params[name] = tp
            stats[name] = compute_quantization_error(weight, tp)

    if ternary_params:
        _print_summary(ternary_params, stats, model)
    else:
        print("WARNING: No layers were quantized!")

    return QuantizationResult(
        ternary_params=ternary_params,
        config=config,
        model_config=model_config,
        model_name=model_name_or_path,
        stats=stats,
    )


def _print_summary(
    ternary_params: dict,
    stats: dict,
    model: nn.Module,
):
    """Print quantization summary statistics."""
    total_params = sum(tp.num_params for tp in ternary_params.values())
    total_model_params = sum(p.numel() for p in model.parameters())

    avg_sparsity = sum(s["sparsity"] for s in stats.values()) / len(stats)
    avg_rel_error = sum(s["relative_error"] for s in stats.values()) / len(stats)
    avg_bits = sum(s["bits_per_param"] for s in stats.values()) / len(stats)

    # Estimated storage: FP16 uses 2 bytes per parameter; packed ternary uses
    # 2 bits per parameter plus one 4-byte scale per weight row.
    fp16_bytes = total_params * 2
    ternary_bytes = total_params * 2 / 8
    n_rows = sum(tp.original_shape[0] for tp in ternary_params.values())
    ternary_bytes += n_rows * 4

    print("\n" + "=" * 60)
    print("QUANTIZATION SUMMARY")
    print("=" * 60)
    print(f"Quantized parameters: {total_params:,} / {total_model_params:,} "
          f"({100 * total_params / total_model_params:.1f}%)")
    print(f"Quantized layers: {len(ternary_params)}")
    print(f"Average bits/param: {avg_bits:.2f}")
    print(f"Average sparsity: {avg_sparsity:.1%}")
    print(f"Average rel. error: {avg_rel_error:.4f}")
    print(f"FP16 size: {fp16_bytes / 1e9:.2f} GB")
    print(f"Ternary size: {ternary_bytes / 1e9:.2f} GB")
    print(f"Compression ratio: {fp16_bytes / ternary_bytes:.1f}x")
    print("=" * 60)