| | from __future__ import annotations |
| |
|
| | from dataclasses import dataclass, field, asdict |
| | from typing import Dict, List |
| |
|
| |
|
@dataclass
class CalibrationConfig:
    """Dataset source and sampling options for calibration-sample collection.

    This class only stores values. The code that reads the dataset decides
    how each option is applied.
    """

    # Where the text comes from: dataset identifier, split, and text column.
    dataset_name: str = "mit-han-lab/pile-val-backup"
    dataset_split: str = "validation"
    text_column: str = "text"

    # Sizing knobs. These are assumed to be sample count, per-sample length
    # cap, and batch size respectively; confirm against the loader.
    sample_count: int = 128
    max_sequence_length: int = 512
    batch_size: int = 8

    # Loading flags plus a seed (presumably for shuffling; verify in caller).
    shuffle: bool = False
    streaming: bool = False
    seed: int = 0

    def to_dict(self) -> Dict[str, int | str | bool]:
        """Serialize every field into a new plain ``dict``.

        Keys appear in field-declaration order.
        """
        serialized = asdict(self)
        return serialized
| |
|
| |
|
@dataclass
class QuantizationConfig:
    """Configuration for quantization pipelines (AWQ, SmoothQuant, etc.).

    This class only stores settings. ``to_dict`` flattens it, including the
    nested calibration settings, into builtin Python values.
    """

    enabled: bool = True
    method: str = "awq"
    weight_bits: int = 4
    activation_bits: int = 8
    group_size: int = 128
    per_channel: bool = True
    # Nested calibration settings; default_factory gives each config its own instance.
    calibration: CalibrationConfig = field(default_factory=CalibrationConfig)
    # NOTE(review): None presumably means "no activation clipping"; confirm with consumer.
    activation_clip: float | None = None
    epsilon: float = 1e-5
    # NOTE(review): likely the SmoothQuant migration strength; confirm.
    alpha: float = 0.5

    # Module filters. Whether entries are exact names or patterns is decided by the caller.
    include_modules: List[str] = field(default_factory=list)
    exclude_modules: List[str] = field(default_factory=list)

    def to_dict(
        self,
    ) -> Dict[
        str,
        str | int | float | bool | None | List[str] | Dict[str, int | str | bool],
    ]:
        """Return a plain ``dict`` snapshot of this configuration.

        The value types are:

        * ``method``: ``str``
        * bit widths and group size: ``int``
        * flags: ``bool``
        * ``epsilon`` and ``alpha``: ``float``
        * ``activation_clip``: ``float`` or ``None``
        * module filters: ``list[str]``
        * ``calibration``: the dict produced by ``CalibrationConfig.to_dict``

        The module lists are copied, so mutating them in the result does not
        change this config.
        """
        return {
            "enabled": self.enabled,
            "method": self.method,
            "weight_bits": self.weight_bits,
            "activation_bits": self.activation_bits,
            "group_size": self.group_size,
            "per_channel": self.per_channel,
            "calibration": self.calibration.to_dict(),
            "activation_clip": self.activation_clip,
            "epsilon": self.epsilon,
            "alpha": self.alpha,
            # Defensive copies: callers may mutate the returned lists freely.
            "include_modules": list(self.include_modules),
            "exclude_modules": list(self.exclude_modules),
        }
| |
|