# rayf-07's picture
# Upload Ouro-2.6B_smoothquant_W8A8 with bundled source code
# b144856 verified
from __future__ import annotations
from dataclasses import dataclass, field, asdict
from typing import Dict, List
@dataclass
class CalibrationConfig:
    """Options describing how calibration samples are gathered.

    Every field carries a default, so ``CalibrationConfig()`` with no
    arguments is a complete configuration.
    """

    # Which dataset to read and which column holds the text.
    dataset_name: str = "mit-han-lab/pile-val-backup"
    dataset_split: str = "validation"
    text_column: str = "text"

    # Sampling / batching sizes.
    sample_count: int = 128
    max_sequence_length: int = 512
    batch_size: int = 8

    # Loading behaviour flags and the random seed.
    shuffle: bool = False
    streaming: bool = False
    seed: int = 0

    def to_dict(self) -> Dict[str, int | str | bool]:
        """Return the configuration as a plain ``field name -> value`` dict."""
        as_mapping = asdict(self)
        return as_mapping
@dataclass
class QuantizationConfig:
    """Configuration for quantization pipelines (AWQ, SmoothQuant, etc.).

    All fields have defaults. ``calibration`` and the two module-filter
    lists use ``default_factory`` so each instance gets its own object
    rather than sharing one mutable default.
    """

    enabled: bool = True
    method: str = "awq"
    weight_bits: int = 4
    activation_bits: int = 8
    group_size: int = 128
    per_channel: bool = True
    calibration: CalibrationConfig = field(default_factory=CalibrationConfig)
    # ``None`` is the default for the clip value.
    activation_clip: float | None = None
    epsilon: float = 1e-5
    alpha: float = 0.5
    # module name filters (glob patterns allowed), applied on qualified names from named_modules()
    include_modules: List[str] = field(default_factory=list)
    exclude_modules: List[str] = field(default_factory=list)

    def to_dict(
        self,
    ) -> Dict[
        str,
        int | float | bool | str | None | List[str] | Dict[str, int | str | bool],
    ]:
        """Return the configuration as a plain dictionary.

        The nested calibration settings are expanded through
        ``self.calibration.to_dict()``, and the module-filter lists are
        copied so that mutating the result does not touch this instance.

        Returns:
            A dict keyed by field name. Values are scalars (``str``,
            ``int``, ``float``, ``bool`` or ``None``), lists of strings
            for the module filters, and a nested dict under
            ``"calibration"``.
        """
        return {
            "enabled": self.enabled,
            "method": self.method,
            "weight_bits": self.weight_bits,
            "activation_bits": self.activation_bits,
            "group_size": self.group_size,
            "per_channel": self.per_channel,
            "calibration": self.calibration.to_dict(),
            "activation_clip": self.activation_clip,
            "epsilon": self.epsilon,
            "alpha": self.alpha,
            # Shallow copies keep the caller from aliasing our lists.
            "include_modules": list(self.include_modules),
            "exclude_modules": list(self.exclude_modules),
        }