| |
import math
import os
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
| |
| |
| |
def _round_ste(x: torch.Tensor) -> torch.Tensor:
    """Round ``x`` to the nearest integer with a straight-through gradient.

    The forward value is ``x.round()``; the rounding residual is detached,
    so the backward pass sees the identity function.
    """
    residual = torch.round(x) - x
    return x + residual.detach()
|
|
def _clamp_ste(x: torch.Tensor, minv: float, maxv: float) -> torch.Tensor:
    """Clamp ``x`` to ``[minv, maxv]`` with a straight-through gradient.

    The forward value is the clamped tensor; the clipping residual is
    detached, so gradients flow to ``x`` unchanged even where it was clipped.
    """
    residual = torch.clamp(x, minv, maxv) - x
    return x + residual.detach()
|
|
| |
| |
| |
class EMAMinMaxObserver_0920(nn.Module):
    """Track an exponential moving average of a tensor's global min and max.

    The running statistics live in the ``min_val`` / ``max_val`` buffers, which
    start at ``+inf`` / ``-inf`` until the first non-empty tensor is observed.

    Args:
        momentum: Weight kept from the previous statistic on each update; the
            new observation contributes ``1 - momentum``.
        symmetric: If true, ``get_range`` returns a range centred on zero.
    """

    def __init__(self, momentum: float = 0.95, symmetric: bool = True):
        super().__init__()
        self.momentum = momentum
        self.symmetric = symmetric
        self.register_buffer("min_val", torch.tensor(float("+inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))
        # Plain attribute (not a buffer), so it is not part of the state_dict.
        self.frozen = False

    @torch.no_grad()
    def update(self, x: torch.Tensor):
        """Fold the min/max of ``x`` into the running statistics.

        Does nothing when the observer is frozen or when ``x`` has no elements
        (``min``/``max`` of an empty tensor would raise).
        """
        if self.frozen or x.numel() == 0:
            return
        cur_min = x.min().detach()
        cur_max = x.max().detach()
        if torch.isinf(self.min_val):
            # First observation: seed the statistics directly instead of
            # averaging against the infinite sentinels.
            self.min_val.copy_(cur_min)
            self.max_val.copy_(cur_max)
        else:
            self.min_val.mul_(self.momentum).add_(cur_min * (1 - self.momentum))
            self.max_val.mul_(self.momentum).add_(cur_max * (1 - self.momentum))

    def freeze(self, flag: bool = True):
        """Enable (``True``) or disable (``False``) the freeze on ``update``."""
        self.frozen = flag

    def get_range(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the current ``(min, max)`` range as 0-dim tensors.

        In symmetric mode the range is ``(-m, m)`` with ``m`` the larger
        magnitude of the two statistics. Before the first update the returned
        bounds are infinite.
        """
        if self.symmetric:
            m = torch.max(self.max_val.abs(), self.min_val.abs())
            return -m, m
        return self.min_val, self.max_val
|
|
| |
| |
| |
class ActQuantizer_0920(nn.Module):
    """Fake-quantize activations onto a signed ``bits``-bit grid (STE gradients).

    The clipping range comes from an ``EMAMinMaxObserver_0920`` that is updated
    on every forward pass while the module is in training mode.

    Args:
        bits: Activation bit-width; only 4 and 8 are accepted.
        symmetric: Use a zero-centred range (zero point fixed at 0). When
            false, an affine mapping with a non-zero zero point is used.
        momentum: EMA momentum forwarded to the observer.
    """

    def __init__(self, bits: int = 8, symmetric: bool = True, momentum: float = 0.95):
        super().__init__()
        assert bits in [4, 8], "Activation bits support 4 or 8."
        self.bits = bits
        self.symmetric = symmetric
        self.qmin = -(2 ** (bits - 1))
        self.qmax = (2 ** (bits - 1)) - 1
        self.observer = EMAMinMaxObserver_0920(momentum=momentum, symmetric=symmetric)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the quantize-dequantize image of ``x``.

        If the observer has not recorded a finite range yet (for example when
        evaluating before any training step), ``x`` is returned unchanged: an
        infinite range would give an infinite scale and ``0 * inf = NaN``
        outputs.
        """
        if self.training:
            self.observer.update(x.detach())
        min_v, max_v = self.observer.get_range()

        # Uncalibrated observer: behave as the identity instead of emitting NaNs.
        if not bool(torch.isfinite(min_v) & torch.isfinite(max_v)):
            return x

        if self.symmetric:
            scale = ((max_v - min_v) / (self.qmax - self.qmin)).clamp_min(1e-8)
            zero_point = torch.zeros_like(scale)
        else:
            # Stretch the range so it contains 0; real zero is then exactly
            # representable and the zero point stays inside [qmin, qmax].
            min_v = torch.clamp(min_v, max=0.0)
            max_v = torch.clamp(max_v, min=0.0)
            scale = ((max_v - min_v) / (self.qmax - self.qmin)).clamp_min(1e-8)
            # Integer that min_v maps to qmin; a zero point of 0 would clip
            # one-sided ranges to half of the signed grid.
            zero_point = self.qmin - torch.round(min_v / scale)

        x_bar = x / scale + zero_point
        x_bar = _clamp_ste(x_bar, self.qmin, self.qmax)
        x_q = _round_ste(x_bar)
        return (x_q - zero_point) * scale

    def eval(self):
        """Switch to eval mode and freeze the observer; returns ``self``.

        The freeze is not undone by a later ``train()`` call; use
        ``self.observer.freeze(False)`` to resume range tracking.

        NOTE(review): ``nn.Module.eval()`` on a parent propagates to children
        through ``train(False)``, so this override presumably only runs when
        ``eval()`` is called on this module directly - confirm against callers.
        """
        super().eval()
        self.observer.freeze(True)
        return self
|
|
| |
| |
| |
| |
class WeightQuantizer_0920(nn.Module):
    """Fake-quantize weights to 1, 2, 4 or 8 bits with STE gradients.

    * ``bits == 1``: weights become ``alpha * sign(w)`` where ``alpha`` is the
      mean absolute weight (per channel or per tensor).
    * ``bits > 1``: weights are divided by a learnable ``scale``, clamped to
      the signed integer range, rounded and rescaled.

    Args:
        bits: Weight bit-width; one of 1, 2, 4, 8.
        per_channel: Compute statistics per slice along ``channel_dim`` instead
            of over the whole tensor.
        channel_dim: Channel axis of the weight; negative values count from
            the end.

    NOTE: for ``bits > 1`` the ``scale`` parameter is created lazily on the
    first forward call, so it is missing from ``parameters()`` and
    ``state_dict()`` until then.
    """

    def __init__(self, bits: int = 1, per_channel: bool = True, channel_dim: int = 0):
        super().__init__()
        assert bits in [1, 2, 4, 8], "Weight bits support 1/2/4/8."
        self.bits = bits
        self.per_channel = per_channel
        self.channel_dim = channel_dim
        if bits > 1:
            qmin, qmax = self._q_bounds()
            # Kept as buffers so existing state_dict keys stay unchanged.
            self.register_buffer("qmin", torch.tensor(qmin, dtype=torch.float32))
            self.register_buffer("qmax", torch.tensor(qmax, dtype=torch.float32))
        self.scale = None

    def _q_bounds(self) -> Tuple[float, float]:
        """Return the signed integer range for ``self.bits`` as Python floats."""
        half = 2 ** (self.bits - 1)
        return float(-half), float(half - 1)

    def _reduce_dims(self, w: torch.Tensor) -> List[int]:
        """Return the axes of ``w`` that the statistics are averaged over."""
        ndim = w.dim()
        if not self.per_channel:
            return list(range(ndim))
        # Normalise negative indices; otherwise no axis would match and the
        # statistic would silently become per-tensor.
        ch = self.channel_dim + ndim if self.channel_dim < 0 else self.channel_dim
        return [i for i in range(ndim) if i != ch]

    def _build_scale(self, w: torch.Tensor):
        """Create the learnable ``scale`` as ``2 * mean(|w|) / sqrt(qmax)``."""
        if self.bits == 1:
            return
        _, qmax = self._q_bounds()
        mean_abs = w.abs().mean(dim=self._reduce_dims(w), keepdim=self.per_channel)
        init_s = 2 * mean_abs / math.sqrt(qmax)
        init_s = torch.clamp(init_s, 1e-8, 1e8)
        self.scale = nn.Parameter(init_s.detach())

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        """Return the fake-quantized version of ``w`` (same shape as ``w``)."""
        if self.bits == 1:
            alpha = w.abs().mean(dim=self._reduce_dims(w), keepdim=self.per_channel).detach()
            # torch.sign(0) is 0, which would add a third level; map exact
            # zeros to +1 so the result is strictly two-valued.
            ones = torch.ones_like(w)
            w_b = alpha * torch.where(w >= 0, ones, -ones)
            return (w_b - w).detach() + w

        if self.scale is None:
            self._build_scale(w)
        s = self.scale
        # Python-float bounds avoid a device sync from ``.item()`` on the
        # buffers at every call.
        qmin, qmax = self._q_bounds()
        w_bar = _clamp_ste(w / s, qmin, qmax)
        return _round_ste(w_bar) * s
|
|
| |
| |
| |
class QuantLinear_0920(nn.Module):
    """Wrap an ``nn.Linear`` so it runs with fake-quantized weights and inputs.

    Args:
        org: The linear layer whose ``weight`` and ``bias`` are reused.
        w_bits: Bit-width passed to the weight quantizer.
        a_bits: Bit-width passed to the activation quantizer; ``None``
            disables activation quantization.
        enable_act: Set to false to skip activation quantization.
    """

    def __init__(self, org: nn.Linear, w_bits: int = 1, a_bits: int = 8, enable_act: bool = True):
        super().__init__()
        self.org = org
        self.w_bits = w_bits
        self.a_bits = a_bits
        self.enable_act = enable_act and (a_bits is not None)
        # Weights are quantized per output feature (dim 0 of the weight).
        self.wq = WeightQuantizer_0920(bits=w_bits, per_channel=True, channel_dim=0)
        act_quantizer = None
        if self.enable_act:
            act_quantizer = ActQuantizer_0920(bits=a_bits)
        self.aq = act_quantizer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear map using quantized input (optional) and weight."""
        inputs = x if self.aq is None else self.aq(x)
        weight = self.wq(self.org.weight)
        return F.linear(inputs, weight, self.org.bias)
|
|
class _QuantConvNd_0920(nn.Module):
    """Shared wrapper that runs a convolution with fake-quantized tensors.

    Subclasses set ``conv_fn`` to the matching functional convolution.

    Args:
        org: The convolution layer whose weight, bias and hyper-parameters
            (stride, padding, dilation, groups, padding mode) are reused.
        w_bits: Bit-width passed to the weight quantizer.
        a_bits: Bit-width passed to the activation quantizer; ``None``
            disables activation quantization.
        enable_act: Set to false to skip activation quantization.
    """

    conv_fn = None

    def __init__(self, org, w_bits: int = 1, a_bits: int = 8, enable_act: bool = True):
        super().__init__()
        self.org = org
        self.w_bits = w_bits
        self.a_bits = a_bits
        self.enable_act = enable_act and (a_bits is not None)
        # Weights are quantized per output channel (dim 0 of the weight).
        self.wq = WeightQuantizer_0920(bits=w_bits, per_channel=True, channel_dim=0)
        self.aq = ActQuantizer_0920(bits=a_bits) if self.enable_act else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` using quantized input (optional) and weight."""
        if self.aq is not None:
            x = self.aq(x)
        wq = self.wq(self.org.weight)

        padding = self.org.padding
        padding_mode = getattr(self.org, "padding_mode", "zeros")
        if padding_mode != "zeros":
            # The functional convolutions only pad with zeros. Mirror what
            # nn.ConvNd does for reflect/replicate/circular: pad explicitly,
            # then convolve without additional padding.
            x = F.pad(x, self.org._reversed_padding_repeated_twice, mode=padding_mode)
            padding = 0

        return self.conv_fn(
            x, wq, self.org.bias,
            stride=self.org.stride, padding=padding,
            dilation=self.org.dilation, groups=self.org.groups
        )
|
|
class QuantConv1d_0920(_QuantConvNd_0920):
    """Quantized wrapper for ``nn.Conv1d`` layers."""

    # staticmethod keeps the functional from being bound as an instance method.
    conv_fn = staticmethod(F.conv1d)
|
|
class QuantConv2d_0920(_QuantConvNd_0920):
    """Quantized wrapper for ``nn.Conv2d`` layers."""

    # staticmethod keeps the functional from being bound as an instance method.
    conv_fn = staticmethod(F.conv2d)
|
|
class QuantConv3d_0920(_QuantConvNd_0920):
    """Quantized wrapper for ``nn.Conv3d`` layers."""

    # staticmethod keeps the functional from being bound as an instance method.
    conv_fn = staticmethod(F.conv3d)
|
|
| |
| |
| |
# Module types that are skipped when collecting quantization candidates.
_EXCLUDE_TYPES_0920 = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm,
)
# Module types that are eligible for replacement by a quantized wrapper.
_INCLUDE_TYPES_0920 = (
    nn.Linear,
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
)
|
|
def _name_excluded_0920(name: str, exclude_hints: List[str]) -> bool:
    """Return True if any hint occurs in ``name``, ignoring case.

    Matching is a plain substring test; an empty hint list excludes nothing.
    """
    lowered = name.lower()
    for hint in exclude_hints:
        if hint.lower() in lowered:
            return True
    return False
|
|
@torch.no_grad()
def apply_quantization_0920(model: nn.Module,
                            enable: bool = False,
                            w_bits: int = 2,
                            a_bits: int = 8,
                            quantize_first_last: bool = False,
                            exclude_name_hints: Optional[List[str]] = None) -> nn.Module:
    """Replace eligible Linear/Conv layers of ``model`` with quantized wrappers.

    The model is modified in place and also returned.

    Args:
        model: Root module to rewrite.
        enable: When false, ``model`` is returned untouched.
        w_bits: Weight bit-width given to every wrapper.
        a_bits: Activation bit-width given to every wrapper.
        quantize_first_last: When false and at least two candidates exist, the
            first and last candidates (in ``named_modules`` order, after name
            filtering) are left in full precision. A single candidate is
            always quantized.
        exclude_name_hints: Case-insensitive substrings; modules whose
            qualified name contains one are skipped. ``None`` or an empty list
            selects the defaults ``["cls_head", "embedding.stem"]``.

    Returns:
        The same ``model`` object.
    """
    if not enable:
        return model

    exclude_name_hints = exclude_name_hints or ["cls_head", "embedding.stem"]

    wrapper_types = (QuantLinear_0920, _QuantConvNd_0920)
    if isinstance(model, wrapper_types):
        # The root is already a wrapper; its only Linear/Conv child is the
        # wrapped ``org`` layer, which must not be wrapped a second time.
        return model

    # Pass 1: collect the qualified names of the layers to replace.
    wrapped_prefixes: List[str] = []
    candidates = []
    for name, module in model.named_modules():
        if name == "":
            continue
        if isinstance(module, wrapper_types):
            # Remember the subtree so the inner ``org`` layer is skipped;
            # wrapping it again would break the wrapper's forward pass.
            wrapped_prefixes.append(name + ".")
            continue
        if name.startswith(tuple(wrapped_prefixes)):
            continue
        if isinstance(module, _EXCLUDE_TYPES_0920):
            continue
        if not isinstance(module, _INCLUDE_TYPES_0920):
            continue
        if _name_excluded_0920(name, exclude_name_hints):
            continue
        candidates.append(name)

    if not quantize_first_last and len(candidates) >= 2:
        candidates = candidates[1:-1]

    # Pass 2: swap each candidate on its parent module. Order matters only for
    # the isinstance checks, which follow the original Linear -> Conv3d order.
    wrappers = (
        (nn.Linear, QuantLinear_0920),
        (nn.Conv1d, QuantConv1d_0920),
        (nn.Conv2d, QuantConv2d_0920),
        (nn.Conv3d, QuantConv3d_0920),
    )
    for full_name in candidates:
        *parent_path, leaf = full_name.split(".")
        parent = model
        for attr in parent_path:
            parent = getattr(parent, attr)
        old = getattr(parent, leaf)

        new_m = old
        for layer_type, wrapper_cls in wrappers:
            if isinstance(old, layer_type):
                new_m = wrapper_cls(old, w_bits=w_bits, a_bits=a_bits, enable_act=True)
                break
        setattr(parent, leaf, new_m)

    return model
|
|
| |
| |
| |
def compute_avg_weight_bits_0920(model: nn.Module, include_fp32: bool = True, fp32_bits: int = 32):
    """Compute average bits per weight for a (partially) quantized model.

    Only the ``org.weight`` tensor of each quantized wrapper counts as
    quantized; every other parameter counts as full precision.

    Args:
        model: Module to inspect.
        include_fp32: Add the bits of non-quantized parameters to the overall
            total. When false, the overall total holds quantized bits only
            (still divided by the full parameter count).
        fp32_bits: Bits charged per non-quantized parameter.

    Returns:
        ``(avg_q_only, avg_all, (n_q_params, n_total))``. An average is NaN
        when its denominator is zero.
    """
    quant_types = (QuantLinear_0920, QuantConv1d_0920, QuantConv2d_0920, QuantConv3d_0920)

    n_q_params = 0
    sum_q_bits = 0
    for module in model.modules():
        if not isinstance(module, quant_types):
            continue
        count = module.org.weight.numel()
        n_q_params += count
        sum_q_bits += count * int(module.w_bits)

    n_total = sum(param.numel() for param in model.parameters())

    sum_all_bits = sum_q_bits
    if include_fp32:
        sum_all_bits += (n_total - n_q_params) * fp32_bits

    nan = float("nan")
    avg_q_only = sum_q_bits / n_q_params if n_q_params > 0 else nan
    avg_all = sum_all_bits / n_total if n_total > 0 else nan
    return avg_q_only, avg_all, (n_q_params, n_total)
|
|
def print_avg_bits_0920(model: nn.Module):
    """Print a three-line summary of quantized parameter share and average bits."""
    stats = compute_avg_weight_bits_0920(model, include_fp32=True, fp32_bits=32)
    avg_q, avg_all, counts = stats
    nq, nt = counts
    pct = 100.0 * nq / max(1, nt)  # max() guards against a parameter-free model
    report = (
        f"[QAT-0920] Quantized weight params: {nq}/{nt} ({pct:.2f}%)",
        f"[QAT-0920] Avg weight bits (quantized-only): {avg_q:.3f}",
        f"[QAT-0920] Avg weight bits (including FP32): {avg_all:.3f}",
    )
    for line in report:
        print(line)
|
|
| |
| |
| |
def install_qat_from_cfg_or_env_0920(model: nn.Module, cfg) -> nn.Module:
    """Enable QAT on ``model`` based on ``cfg.quant0920`` or environment variables.

    Each setting is read from ``cfg.quant0920.<key>`` first. If ``cfg`` does
    not provide it, the environment is consulted, and finally the built-in
    default is used. Two environment names are accepted per setting (the
    long form wins if both are set):

        enable               QAT_0920_ENABLE=1
        w_bits               QAT_0920_W_BITS=2               or QAT_0920_WBITS=2
        a_bits               QAT_0920_A_BITS=8               or QAT_0920_ABITS=8
        quantize_first_last  QAT_0920_QUANTIZE_FIRST_LAST=0/1 or QAT_0920_FIRST_LAST=0/1
        exclude_name_hints   QAT_0920_EXCLUDE_NAME_HINTS="cls_head,embedding.stem"
                             or QAT_0920_EXCLUDE="cls_head,embedding.stem"

    Booleans accept 1/true/yes/on in any letter case; anything else is false.

    Args:
        model: Module to quantize in place.
        cfg: Configuration object; may lack a ``quant0920`` attribute.

    Returns:
        ``model`` (quantized when enabled, otherwise unchanged).
    """
    # Short names promised by the documentation; the code historically only
    # read the long ``QAT_0920_<KEY>`` form, so both are honoured.
    short_env_names = {
        "w_bits": "QAT_0920_WBITS",
        "a_bits": "QAT_0920_ABITS",
        "quantize_first_last": "QAT_0920_FIRST_LAST",
        "exclude_name_hints": "QAT_0920_EXCLUDE",
    }
    truthy = {"1", "true", "yes", "on"}

    def _get(key, default):
        # Best effort: config objects with unusual attribute access must not
        # prevent the fallback to environment variables.
        try:
            if hasattr(cfg, "quant0920") and hasattr(cfg.quant0920, key):
                return getattr(cfg.quant0920, key)
        except Exception:
            pass

        env_key = f"QAT_0920_{key.upper()}"
        v = os.environ.get(env_key, None)
        if v is None and key in short_env_names:
            env_key = short_env_names[key]
            v = os.environ.get(env_key, None)
        if v is None:
            return default

        # bool must be tested before int because bool is a subclass of int.
        if isinstance(default, bool):
            return v.strip().lower() in truthy
        if isinstance(default, int):
            try:
                return int(v)
            except ValueError:
                print(f"[QAT-0920] ignoring non-integer {env_key}={v!r}; using {default}")
                return default
        if isinstance(default, list):
            return [x.strip() for x in v.split(",") if x.strip()]
        return v

    enable = _get("enable", False)
    w_bits = _get("w_bits", 2)
    a_bits = _get("a_bits", 8)
    quantize_first_last = _get("quantize_first_last", False)
    exclude_name_hints = _get("exclude_name_hints", ["cls_head", "embedding.stem"])

    if enable:
        model = apply_quantization_0920(model,
                                        enable=True,
                                        w_bits=w_bits,
                                        a_bits=a_bits,
                                        quantize_first_last=quantize_first_last,
                                        exclude_name_hints=exclude_name_hints)
        print_avg_bits_0920(model)
    else:
        print("[QAT-0920] disabled")
    return model
|
|