from __future__ import annotations

from dataclasses import asdict, dataclass, field, fields
from typing import Dict, List


@dataclass
class CalibrationConfig:
    """Settings that govern how calibration samples are collected."""

    # NOTE(review): presumably a Hugging Face Hub dataset id / split pair --
    # confirm against the calibration data loader.
    dataset_name: str = "mit-han-lab/pile-val-backup"
    dataset_split: str = "validation"
    text_column: str = "text"
    sample_count: int = 128
    max_sequence_length: int = 512
    batch_size: int = 8
    shuffle: bool = False
    streaming: bool = False
    seed: int = 0

    def to_dict(self) -> Dict[str, int | str | bool]:
        """Return the settings as a plain ``dict`` keyed by field name."""
        return asdict(self)


@dataclass
class QuantizationConfig:
    """Configuration for quantization pipelines (AWQ, SmoothQuant, etc.)."""

    enabled: bool = True
    method: str = "awq"
    weight_bits: int = 4
    activation_bits: int = 8
    group_size: int = 128
    per_channel: bool = True
    calibration: CalibrationConfig = field(default_factory=CalibrationConfig)
    # None means no explicit clip value is configured.
    activation_clip: float | None = None
    epsilon: float = 1e-5
    # NOTE(review): presumably the SmoothQuant migration strength -- confirm
    # with the pipeline that consumes it.
    alpha: float = 0.5
    # module name filters (glob patterns allowed), applied on qualified names from named_modules()
    include_modules: List[str] = field(default_factory=list)
    exclude_modules: List[str] = field(default_factory=list)

    def to_dict(
        self,
    ) -> Dict[str, int | float | bool | str | None | List[str] | Dict[str, int | str | bool]]:
        """Return the configuration as a plain ``dict`` keyed by field name.

        Keys follow the dataclass field order. ``calibration`` is serialized
        through ``CalibrationConfig.to_dict()``. The module filter sequences
        are copied into fresh ``list`` objects, so mutating the result does
        not affect this config.
        """
        # Start from every declared field so that a newly added field (or one
        # added by a subclass) can never be silently dropped.
        data = {f.name: getattr(self, f.name) for f in fields(self)}
        # Reassigning an existing key keeps its original position in the dict.
        data["calibration"] = self.calibration.to_dict()
        data["include_modules"] = list(self.include_modules)
        data["exclude_modules"] = list(self.exclude_modules)
        return data