# pointcept/models/utils/quant_0920.py
"""Quantization-aware-training (fake quantization) helpers.

Contents: straight-through rounding/clamping, an EMA min/max activation observer,
activation and weight fake-quantizers, drop-in wrappers for ``nn.Linear`` /
``nn.Conv{1,2,3}d``, in-place model surgery, average-bit statistics, and an
installer driven by ``cfg.quant0920`` or environment variables.
"""
import math
import os
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# -----------------------------
# Basics: straight-through estimators (STE)
# -----------------------------
def _round_ste(x: torch.Tensor) -> torch.Tensor:
    """Round to nearest in the forward pass; pass the gradient through unchanged."""
    return (x.round() - x).detach() + x


def _clamp_ste(x: torch.Tensor, minv: float, maxv: float) -> torch.Tensor:
    """Clamp to ``[minv, maxv]`` in the forward pass; pass the gradient through unchanged.

    Unlike ``torch.clamp``, clipped elements still receive a gradient.
    """
    return (x.clamp(minv, maxv) - x).detach() + x


# -----------------------------
# Activation observer (EMA)
# -----------------------------
class EMAMinMaxObserver_0920(nn.Module):
    """Track an exponential moving average of a tensor's global min and max.

    Args:
        momentum: weight of the previous estimate in the EMA update.
        symmetric: if True, ``get_range`` returns ``(-m, m)`` with
            ``m = max(|min|, |max|)``; otherwise ``(min, max)``.
    """

    def __init__(self, momentum: float = 0.95, symmetric: bool = True):
        super().__init__()
        self.momentum = momentum
        self.symmetric = symmetric
        # +inf / -inf mark "no statistics collected yet".
        self.register_buffer("min_val", torch.tensor(float("+inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))
        self.frozen = False

    def is_initialized(self) -> bool:
        """Return True once ``min_val`` holds a finite statistic."""
        return not bool(torch.isinf(self.min_val))

    @torch.no_grad()
    def update(self, x: torch.Tensor):
        """Fold the min/max of ``x`` into the running estimate.

        No-op when frozen or when ``x`` is empty (``min``/``max`` of an empty
        tensor would raise).
        """
        if self.frozen or x.numel() == 0:
            return
        cur_min = x.min().detach()
        cur_max = x.max().detach()
        if not self.is_initialized():
            # First observation: take the batch statistics directly.
            self.min_val.copy_(cur_min)
            self.max_val.copy_(cur_max)
        else:
            self.min_val.mul_(self.momentum).add_(cur_min * (1 - self.momentum))
            self.max_val.mul_(self.momentum).add_(cur_max * (1 - self.momentum))

    def freeze(self, flag: bool = True):
        """Stop (``flag=True``) or resume (``flag=False``) statistic updates."""
        self.frozen = flag

    def get_range(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(min, max)`` range used for quantization."""
        if self.symmetric:
            m = torch.max(self.max_val.abs(), self.min_val.abs())
            return -m, m
        return self.min_val, self.max_val


# -----------------------------
# Activation quantizer (per-tensor)
# -----------------------------
class ActQuantizer_0920(nn.Module):
    """Per-tensor fake quantizer for activations, driven by an EMA min/max observer.

    Statistics are updated only in training mode. ``train()`` / ``eval()``,
    including when called recursively from a parent module, unfreeze/freeze the
    observer. Until any statistics exist, inputs are returned unchanged.

    Args:
        bits: 4 or 8.
        symmetric: symmetric range with zero_point 0 (default); otherwise an
            affine range widened to contain 0, with an integer zero point.
        momentum: EMA momentum of the observer.
    """

    def __init__(self, bits: int = 8, symmetric: bool = True, momentum: float = 0.95):
        super().__init__()
        assert bits in [4, 8], "Activation bits support 4 or 8."
        self.bits = bits
        self.symmetric = symmetric
        self.qmin = -(2 ** (bits - 1))
        self.qmax = (2 ** (bits - 1)) - 1
        self.observer = EMAMinMaxObserver_0920(momentum=momentum, symmetric=symmetric)

    def train(self, mode: bool = True):
        # nn.Module.eval() and a parent's .train()/.eval() all go through
        # train(mode), so overriding train (not eval) keeps the observer in sync.
        super().train(mode)
        self.observer.freeze(not mode)
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            self.observer.update(x.detach())
        if not self.observer.is_initialized():
            # No range yet (e.g. evaluated before any training step); the
            # infinite placeholder range would produce NaN outputs.
            return x
        min_v, max_v = self.observer.get_range()
        if self.symmetric:
            scale = ((max_v - min_v) / (self.qmax - self.qmin)).clamp_min(1e-8)
            zero_point = torch.zeros_like(scale)
        else:
            # Affine quantization: widen the range to contain 0 so that 0 is exact.
            min_v = min_v.clamp(max=0.0)
            max_v = max_v.clamp(min=0.0)
            scale = ((max_v - min_v) / (self.qmax - self.qmin)).clamp_min(1e-8)
            zero_point = self.qmin - torch.round(min_v / scale)
        # Fake quantization with STE.
        x_bar = _clamp_ste(x / scale + zero_point, self.qmin, self.qmax)
        x_q = _round_ste(x_bar)
        return (x_q - zero_point) * scale


# -----------------------------
# Weight quantizer (symmetric, per-channel)
# 1 bit: alpha * sign(W); 2/4/8 bits: LSQ-style learnable scale
# -----------------------------
class WeightQuantizer_0920(nn.Module):
    """Symmetric (optionally per-channel) fake quantizer for weights.

    * 1 bit: ``alpha * sign(W)`` with ``alpha = mean|W|`` over the non-channel
      dims, straight-through gradient.
    * 2/4/8 bits: learnable scale ``s`` and ``W_q = round(clamp(W / s)) * s``.
      ``s`` is initialized to ``2 * mean|W| / sqrt(qmax)``.

    The scale is created by ``init_scale`` (or lazily on the first forward). It
    is re-initialized from the live weights on the first forward, unless a scale
    was loaded from a state_dict.

    NOTE(review): both clamp and round use STE here, so clipped weights still
    receive gradients and the scale gradient is not rescaled. This differs from
    the original LSQ formulation.
    """

    def __init__(self, bits: int = 1, per_channel: bool = True, channel_dim: int = 0):
        super().__init__()
        assert bits in [1, 2, 4, 8], "Weight bits support 1/2/4/8."
        self.bits = bits
        self.per_channel = per_channel
        self.channel_dim = channel_dim
        if bits > 1:
            qmin = -(2 ** (bits - 1))
            qmax = (2 ** (bits - 1)) - 1
            # Buffers kept for state_dict compatibility; forward uses the float copies
            # to avoid a device sync (.item()) on every call.
            self.register_buffer("qmin", torch.tensor(qmin, dtype=torch.float32))
            self.register_buffer("qmax", torch.tensor(qmax, dtype=torch.float32))
            self._qmin_f = float(qmin)
            self._qmax_f = float(qmax)
        # Learnable scale slot (stays None for 1 bit).
        self.register_parameter("scale", None)
        # When True, the next forward refreshes the scale from the current weights,
        # so weights loaded after wrapping (e.g. a pretrained checkpoint) are used.
        self._scale_needs_init = True

    def _reduce_dims(self, w: torch.Tensor) -> List[int]:
        """Dims to average over: all but the channel dim (per-channel) or all dims."""
        if not self.per_channel:
            return list(range(w.dim()))
        channel = self.channel_dim % w.dim()  # also accept a negative channel_dim
        return [i for i in range(w.dim()) if i != channel]

    @torch.no_grad()
    def _initial_scale(self, w: torch.Tensor) -> torch.Tensor:
        """LSQ initialization ``2 * mean|w| / sqrt(qmax)``, clamped to [1e-8, 1e8]."""
        mean_abs = w.abs().mean(dim=self._reduce_dims(w), keepdim=self.per_channel)
        init_s = 2 * mean_abs / math.sqrt(self._qmax_f)
        return torch.clamp(init_s, 1e-8, 1e8)

    def _build_scale(self, w: torch.Tensor):
        """Create the learnable scale from ``w``, or refresh the existing one in place."""
        if self.bits == 1:
            return
        init_s = self._initial_scale(w)
        if self.scale is None:
            self.scale = nn.Parameter(init_s.detach().clone())
        else:
            # In place, so optimizers that already hold this Parameter keep working.
            with torch.no_grad():
                self.scale.copy_(init_s)

    def init_scale(self, w: torch.Tensor):
        """Register the scale now, so optimizers and state_dicts built later include it."""
        if self.bits > 1 and self.scale is None:
            self._build_scale(w)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        key = prefix + "scale"
        if self.bits > 1 and key in state_dict:
            if self.scale is None:
                # Materialize the slot so the checkpointed scale can be loaded into it.
                src = state_dict[key]
                self.scale = nn.Parameter(
                    torch.empty(src.shape, dtype=src.dtype, device=self.qmin.device)
                )
            # A checkpointed scale is authoritative: do not overwrite it on first forward.
            self._scale_needs_init = False
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        if self.bits == 1:
            alpha = w.abs().mean(dim=self._reduce_dims(w), keepdim=self.per_channel).detach()
            w_b = alpha * torch.sign(w)
            return (w_b - w).detach() + w  # STE
        if self.scale is None or self._scale_needs_init:
            self._build_scale(w)
            self._scale_needs_init = False
        s = self.scale
        w_bar = _clamp_ste(w / s, self._qmin_f, self._qmax_f)
        return _round_ste(w_bar) * s


# -----------------------------
# Wrappers: Linear / Conv{1,2,3}d
# -----------------------------
def _make_quantizers_0920(org: nn.Module, w_bits: int, a_bits: Optional[int], enable_act: bool):
    """Build the quantizers shared by every wrapper.

    Returns ``(enable_act, weight_quantizer, act_quantizer_or_None)``. Activation
    quantization is off when ``enable_act`` is falsy or ``a_bits`` is None or 0.
    """
    device = org.weight.device
    act_on = bool(enable_act) and a_bits is not None and a_bits > 0
    wq = WeightQuantizer_0920(bits=w_bits, per_channel=True, channel_dim=0).to(device)
    # Register the learnable scale at wrap time; see WeightQuantizer_0920.init_scale.
    wq.init_scale(org.weight)
    aq = ActQuantizer_0920(bits=a_bits).to(device) if act_on else None
    return act_on, wq, aq


class QuantLinear_0920(nn.Module):
    """Fake-quantized replacement for ``nn.Linear``; the original layer is kept as ``org``."""

    def __init__(self, org: nn.Linear, w_bits: int = 1, a_bits: int = 8, enable_act: bool = True):
        super().__init__()
        self.org = org
        self.w_bits = w_bits
        self.a_bits = a_bits
        self.enable_act, self.wq, self.aq = _make_quantizers_0920(org, w_bits, a_bits, enable_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.aq is not None:
            x = self.aq(x)
        return F.linear(x, self.wq(self.org.weight), self.org.bias)


class _QuantConvNd_0920(nn.Module):
    """Fake-quantized replacement for ``nn.ConvNd``; subclasses set ``conv_fn``."""

    conv_fn = None

    def __init__(self, org, w_bits: int = 1, a_bits: int = 8, enable_act: bool = True):
        super().__init__()
        self.org = org
        self.w_bits = w_bits
        self.a_bits = a_bits
        self.enable_act, self.wq, self.aq = _make_quantizers_0920(org, w_bits, a_bits, enable_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.aq is not None:
            x = self.aq(x)
        org = self.org
        wq = self.wq(org.weight)
        padding = org.padding
        if getattr(org, "padding_mode", "zeros") != "zeros":
            # Mirror nn.ConvNd: explicit F.pad for reflect/replicate/circular, then no conv padding.
            pad = getattr(org, "_reversed_padding_repeated_twice", None)
            if pad is None:
                pad = [p for p in reversed(org.padding) for _ in range(2)]
            x = F.pad(x, pad, mode=org.padding_mode)
            padding = 0
        return self.conv_fn(
            x, wq, org.bias,
            stride=org.stride, padding=padding,
            dilation=org.dilation, groups=org.groups,
        )


class QuantConv1d_0920(_QuantConvNd_0920):
    conv_fn = staticmethod(F.conv1d)


class QuantConv2d_0920(_QuantConvNd_0920):
    conv_fn = staticmethod(F.conv2d)


class QuantConv3d_0920(_QuantConvNd_0920):
    conv_fn = staticmethod(F.conv3d)


# -----------------------------
# Model surgery
# -----------------------------
_EXCLUDE_TYPES_0920 = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)
_INCLUDE_TYPES_0920 = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
# Wrapper class per supported layer type (isinstance match, checked in order).
_WRAPPERS_0920 = (
    (nn.Linear, QuantLinear_0920),
    (nn.Conv1d, QuantConv1d_0920),
    (nn.Conv2d, QuantConv2d_0920),
    (nn.Conv3d, QuantConv3d_0920),
)


def _name_excluded_0920(name: str, exclude_hints: List[str]) -> bool:
    """True if any hint occurs (case-insensitively) as a substring of ``name``."""
    lname = name.lower()
    return any(h.lower() in lname for h in exclude_hints)


def _wrap_module_0920(old: nn.Module, w_bits: int, a_bits: int) -> nn.Module:
    """Return the quantized wrapper for ``old``, or ``old`` itself if unsupported."""
    for base, wrapper in _WRAPPERS_0920:
        if isinstance(old, base):
            return wrapper(old, w_bits=w_bits, a_bits=a_bits, enable_act=True)
    return old


@torch.no_grad()
def apply_quantization_0920(model: nn.Module, enable: bool = False, w_bits: int = 2, a_bits: int = 8,
                            quantize_first_last: bool = False,
                            exclude_name_hints: Optional[List[str]] = None) -> nn.Module:
    """Replace Linear/Conv{1,2,3}d submodules of ``model`` in place with fake-quantized wrappers.

    Args:
        model: model to modify (the root module itself is never replaced).
        enable: if False, return ``model`` untouched.
        w_bits: weight bits (1/2/4/8).
        a_bits: activation bits (4/8), or None/0 to disable activation quantization.
        quantize_first_last: if False, the first and last candidate layers, in
            ``named_modules()`` registration order, stay in full precision.
        exclude_name_hints: case-insensitive substrings of module names to skip.
            None means ``["cls_head", "embedding.stem"]``; pass ``[]`` for no exclusions.

    Returns:
        The same ``model`` object.
    """
    if not enable:
        return model
    if exclude_name_hints is None:
        exclude_name_hints = ["cls_head", "embedding.stem"]

    # nn.MultiheadAttention's forward reads ``out_proj.weight`` directly, so its
    # inner Linear must not be wrapped.
    mha_names = {n for n, m in model.named_modules() if isinstance(m, nn.MultiheadAttention)}

    # Collect candidates.
    candidates = []
    for name, module in model.named_modules():
        if name == "":
            continue
        if isinstance(module, _EXCLUDE_TYPES_0920) or not isinstance(module, _INCLUDE_TYPES_0920):
            continue
        if name.rpartition(".")[0] in mha_names:
            continue
        if _name_excluded_0920(name, exclude_name_hints):
            continue
        candidates.append(name)

    # Keep first/last layers in full precision (default). A single candidate is
    # both first and last, so it is skipped as well.
    if not quantize_first_last:
        candidates = candidates[1:-1]

    # Replace.
    for full_name in candidates:
        path = full_name.split(".")
        parent = model
        for p in path[:-1]:
            parent = getattr(parent, p)
        setattr(parent, path[-1], _wrap_module_0920(getattr(parent, path[-1]), w_bits, a_bits))
    return model


# -----------------------------
# Average bits (weighted by parameter count)
# -----------------------------
def compute_avg_weight_bits_0920(model: nn.Module, include_fp32: bool = True, fp32_bits: int = 32):
    """Compute parameter-count-weighted average weight bit-widths.

    Only the wrapped layers' ``org.weight`` counts as quantized. Every other
    parameter (biases, norms, learnable quantizer scales, ...) counts as
    ``fp32_bits`` when ``include_fp32`` is True.

    Returns:
        ``(avg_q_only, avg_all, (n_q_params, n_total))``. The averages are NaN
        when their denominator is 0.
    """
    n_q_params, sum_q_bits = 0, 0
    for m in model.modules():
        if isinstance(m, (QuantLinear_0920, _QuantConvNd_0920)):
            p = m.org.weight.numel()
            n_q_params += p
            sum_q_bits += p * int(m.w_bits)
    n_total = sum(p.numel() for p in model.parameters())

    sum_all_bits = sum_q_bits + (n_total - n_q_params) * fp32_bits if include_fp32 else sum_q_bits
    avg_q_only = sum_q_bits / n_q_params if n_q_params > 0 else float("nan")
    avg_all = sum_all_bits / n_total if n_total > 0 else float("nan")
    return avg_q_only, avg_all, (n_q_params, n_total)


def print_avg_bits_0920(model: nn.Module):
    """Print quantized-parameter coverage and average weight bits of ``model``."""
    avg_q, avg_all, (nq, nt) = compute_avg_weight_bits_0920(model, include_fp32=True, fp32_bits=32)
    pct = 100.0 * nq / max(1, nt)
    print(f"[QAT-0920] Quantized weight params: {nq}/{nt} ({pct:.2f}%)")
    print(f"[QAT-0920] Avg weight bits (quantized-only): {avg_q:.3f}")
    print(f"[QAT-0920] Avg weight bits (including FP32): {avg_all:.3f}")


# -----------------------------
# Settings from cfg or environment variables (minimal-intrusion integration)
# -----------------------------
# Accepted environment variable names per setting. The short, documented name
# comes first; the ``QAT_0920_<KEY>`` form is kept for backward compatibility.
_ENV_NAMES_0920 = {
    "enable": ("QAT_0920_ENABLE",),
    "w_bits": ("QAT_0920_WBITS", "QAT_0920_W_BITS"),
    "a_bits": ("QAT_0920_ABITS", "QAT_0920_A_BITS"),
    "quantize_first_last": ("QAT_0920_FIRST_LAST", "QAT_0920_QUANTIZE_FIRST_LAST"),
    "exclude_name_hints": ("QAT_0920_EXCLUDE", "QAT_0920_EXCLUDE_NAME_HINTS"),
}
_TRUE_STRINGS_0920 = {"1", "true", "yes", "y", "on"}
_MISSING_0920 = object()


def _cfg_lookup_0920(cfg, key: str):
    """Return ``cfg.quant0920.<key>`` (attribute or mapping access), else ``_MISSING_0920``."""
    try:
        section = getattr(cfg, "quant0920", None)
        if section is None and isinstance(cfg, dict):
            section = cfg.get("quant0920")
        if section is None:
            return _MISSING_0920
        if hasattr(section, key):
            return getattr(section, key)
        if isinstance(section, dict) and key in section:
            return section[key]
    except Exception:
        # Best effort: unusual cfg objects fall back to the environment.
        pass
    return _MISSING_0920


def _parse_env_0920(raw: str, default):
    """Convert an environment string to the type of ``default``."""
    if isinstance(default, bool):  # checked before int: bool is a subclass of int
        return raw.strip().lower() in _TRUE_STRINGS_0920
    if isinstance(default, int):
        try:
            return int(raw)
        except ValueError:
            print(f"[QAT-0920] ignoring non-integer value {raw!r}; using {default}")
            return default
    if isinstance(default, list):
        return [x.strip() for x in raw.split(",") if x.strip()]
    return raw


def _env_lookup_0920(key: str, default):
    """Read ``key`` from the first set, non-empty environment variable, else ``default``."""
    for env_key in _ENV_NAMES_0920.get(key, (f"QAT_0920_{key.upper()}",)):
        raw = os.environ.get(env_key)
        if raw is not None and raw.strip():
            return _parse_env_0920(raw, default)
    return default


def install_qat_from_cfg_or_env_0920(model: nn.Module, cfg) -> nn.Module:
    """Install QAT wrappers on ``model`` from ``cfg.quant0920`` or environment variables.

    Each setting is read from ``cfg.quant0920.<key>`` (attribute or mapping
    access) first. If absent, it is read from the environment, otherwise the
    default is used.

    Settings (cfg key: environment variable(s), default):
        enable:              QAT_0920_ENABLE=1, False
        w_bits:              QAT_0920_WBITS or QAT_0920_W_BITS, 2
        a_bits:              QAT_0920_ABITS or QAT_0920_A_BITS, 8
        quantize_first_last: QAT_0920_FIRST_LAST=0/1 or QAT_0920_QUANTIZE_FIRST_LAST, False
        exclude_name_hints:  QAT_0920_EXCLUDE="cls_head,embedding.stem" or
                             QAT_0920_EXCLUDE_NAME_HINTS, ["cls_head", "embedding.stem"]

    Boolean values accept 1/true/yes/y/on (case-insensitive). Empty environment
    values are treated as unset.

    Returns:
        The model, modified in place when enabled.
    """
    def _get(key, default):
        value = _cfg_lookup_0920(cfg, key)
        return _env_lookup_0920(key, default) if value is _MISSING_0920 else value

    enable = _get("enable", False)
    w_bits = _get("w_bits", 2)
    a_bits = _get("a_bits", 8)
    quantize_first_last = _get("quantize_first_last", False)
    exclude_name_hints = _get("exclude_name_hints", ["cls_head", "embedding.stem"])

    if enable:
        model = apply_quantization_0920(model, enable=True, w_bits=w_bits, a_bits=a_bits,
                                        quantize_first_last=quantize_first_last,
                                        exclude_name_hints=exclude_name_hints)
        print_avg_bits_0920(model)
    else:
        print("[QAT-0920] disabled")
    return model