YYYYYYUUU's picture
Add core reproduction code (binarization layers, PTv3, superpoint ops, min-repro pack)
7b95dc2 verified
Raw
History Blame Contribute Delete
12.2 kB
# pointcept/models/utils/quant_0920.py
import math
from typing import List, Tuple
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
# -----------------------------
# 基础:STE
# -----------------------------
def _round_ste(x: torch.Tensor) -> torch.Tensor:
    """Round ``x`` to the nearest integer with a straight-through gradient.

    The forward value equals ``torch.round(x)``. The rounding residual is
    detached from the autograd graph, so the backward pass treats the op as
    identity (d out / d x == 1).
    """
    residual = torch.round(x) - x
    return x + residual.detach()
def _clamp_ste(x: torch.Tensor, minv: float, maxv: float) -> torch.Tensor:
    """Clamp ``x`` to ``[minv, maxv]`` in the forward pass only.

    The clamping residual is detached, so the backward pass is identity
    everywhere, including for elements that were clipped.
    """
    clipped = torch.clamp(x, min=minv, max=maxv)
    return x + (clipped - x).detach()
# -----------------------------
# 激活观察者(EMA)
# -----------------------------
class EMAMinMaxObserver_0920(nn.Module):
    """Track an exponential moving average of a tensor's global min and max.

    ``min_val`` / ``max_val`` are registered buffers (saved in ``state_dict``)
    that start at +inf / -inf. The first non-empty observation copies the batch
    min/max directly. Later observations blend them as
    ``stat = momentum * stat + (1 - momentum) * current``.

    Args:
        momentum: EMA weight given to the running statistic.
        symmetric: if True, :meth:`get_range` returns ``(-m, m)`` with
            ``m = max(|min_val|, |max_val|)``; otherwise ``(min_val, max_val)``.
    """

    def __init__(self, momentum: float = 0.95, symmetric: bool = True):
        super().__init__()
        self.momentum = momentum
        self.symmetric = symmetric
        self.register_buffer("min_val", torch.tensor(float("+inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))
        # Plain attribute, not a buffer: it is not saved in state_dict.
        self.frozen = False

    @torch.no_grad()
    def update(self, x: torch.Tensor):
        """Fold the min/max of ``x`` into the running statistics.

        Does nothing when the observer is frozen or when ``x`` is empty;
        ``Tensor.min()`` raises on an empty tensor.
        """
        if self.frozen or x.numel() == 0:
            return
        cur_min = x.min().detach()
        cur_max = x.max().detach()
        if torch.isinf(self.min_val):
            # First observation: there is no history to blend with yet.
            self.min_val.copy_(cur_min)
            self.max_val.copy_(cur_max)
        else:
            self.min_val.mul_(self.momentum).add_(cur_min * (1 - self.momentum))
            self.max_val.mul_(self.momentum).add_(cur_max * (1 - self.momentum))

    def freeze(self, flag: bool = True):
        """Stop (``flag=True``) or resume (``flag=False``) statistic updates."""
        self.frozen = flag

    def get_range(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(low, high)`` range implied by the running statistics.

        Before any non-empty update this is infinite (+/-inf).
        """
        if self.symmetric:
            m = torch.max(self.max_val.abs(), self.min_val.abs())
            return -m, m
        else:
            return self.min_val, self.max_val
# -----------------------------
# 激活量化器(对称,per-tensor)
# -----------------------------
class ActQuantizer_0920(nn.Module):
    """Per-tensor activation fake quantizer calibrated by an EMA min/max observer.

    In training mode each forward first feeds the batch to the observer. The
    observer's current range then defines a grid on the integers
    ``[qmin, qmax]``. Clamping and rounding use straight-through estimators,
    so gradients pass through unchanged, including for clipped values.

    Until the observer has seen at least one non-empty batch, its range is
    infinite. In that case the input is returned unchanged rather than NaN.

    Args:
        bits: activation bit-width, 4 or 8.
        symmetric: True (default) uses zero point 0 and the range ``[-m, m]``.
            False uses an affine grid over the observed range, widened to
            include 0, with an integer zero point.
        momentum: EMA momentum passed to the observer.
    """

    def __init__(self, bits: int = 8, symmetric: bool = True, momentum: float = 0.95):
        super().__init__()
        assert bits in [4, 8], "Activation bits support 4 or 8."
        self.bits = bits
        self.symmetric = symmetric
        self.qmin = - (2 ** (bits - 1))
        self.qmax = (2 ** (bits - 1)) - 1
        self.observer = EMAMinMaxObserver_0920(momentum=momentum, symmetric=symmetric)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            self.observer.update(x.detach())
        min_v, max_v = self.observer.get_range()
        # An uncalibrated observer still holds +/-inf. An infinite scale would
        # turn every output into NaN (0 * inf), so skip quantization instead.
        if not bool(torch.isfinite(min_v) & torch.isfinite(max_v)):
            return x
        if not self.symmetric:
            # Widen the range to include 0 so the zero point stays within [qmin, qmax].
            min_v = torch.clamp(min_v, max=0.0)
            max_v = torch.clamp(max_v, min=0.0)
        # Lower-bound the scale to avoid dividing by zero on a constant input.
        scale = torch.maximum((max_v - min_v) / (self.qmax - self.qmin), torch.tensor(1e-8, device=x.device))
        if self.symmetric:
            zero_point = torch.tensor(0.0, device=x.device)
        else:
            # Integer zero point that maps min_v onto qmin.
            zero_point = (self.qmin - torch.round(min_v / scale)).detach()
        # Fake quantization with straight-through estimators.
        x_bar = x / scale
        if not self.symmetric:
            x_bar = x_bar + zero_point
        x_bar = _clamp_ste(x_bar, self.qmin, self.qmax)
        x_q = _round_ste(x_bar)
        x_deq = (x_q - zero_point) * scale
        return x_deq

    def eval(self):
        """Switch this module to eval mode and freeze its observer.

        NOTE: ``nn.Module.eval()`` on a parent model switches children via
        ``train(False)`` and does not call this override. The freeze therefore
        only happens when this module's own ``eval()`` is called. A later
        ``train()`` does not unfreeze the observer.
        """
        super().eval()
        self.observer.freeze(True)
        return self
# -----------------------------
# 权重量化器(对称,per-channel)
# 1bit: alpha*sign(W);2/4/8bit: LSQ 风格
# -----------------------------
class WeightQuantizer_0920(nn.Module):
    """Weight fake quantizer.

    * ``bits == 1``: ``alpha * sign(w)``, where ``alpha`` is the (detached)
      mean ``|w|`` per channel (or over the whole tensor when
      ``per_channel=False``). The gradient w.r.t. ``w`` is straight-through.
    * ``bits in (2, 4, 8)``: learnable step size ``scale`` (LSQ-style). The
      forward computes ``round(clamp(w / scale, qmin, qmax)) * scale`` with
      signed ``qmin = -2**(bits-1)`` and ``qmax = 2**(bits-1) - 1``.

    ``scale`` is created on the first forward from ``2 * mean|w| / sqrt(qmax)``
    (per channel along ``channel_dim``), or on ``load_state_dict`` when the
    checkpoint contains it.

    NOTE(review): an optimizer built from ``parameters()`` before the first
    forward will not include ``scale``, so it would not be trained.
    Clamping also uses a straight-through estimator, so clipped weights still
    receive an identity gradient; confirm this deviation from the LSQ paper
    is intended.

    Args:
        bits: 1, 2, 4 or 8.
        per_channel: compute statistics per slice along ``channel_dim``.
        channel_dim: the channel axis kept when ``per_channel`` is True.
    """

    def __init__(self, bits: int = 1, per_channel: bool = True, channel_dim: int = 0):
        super().__init__()
        assert bits in [1, 2, 4, 8], "Weight bits support 1/2/4/8."
        self.bits = bits
        self.per_channel = per_channel
        self.channel_dim = channel_dim
        if bits > 1:
            qmin = - (2 ** (bits - 1))
            qmax = (2 ** (bits - 1)) - 1
            # Buffers keep the existing state_dict layout ("qmin"/"qmax" keys).
            self.register_buffer("qmin", torch.tensor(qmin, dtype=torch.float32))
            self.register_buffer("qmax", torch.tensor(qmax, dtype=torch.float32))
            # Python copies for the forward pass. Calling .item() on a CUDA buffer
            # would force a host-device sync on every call.
            self._qmin_f = float(qmin)
            self._qmax_f = float(qmax)
        # Declared parameter slot (initially empty), filled lazily by
        # _build_scale or _load_from_state_dict.
        self.register_parameter("scale", None)

    def _reduce_dims(self, w: torch.Tensor) -> List[int]:
        """Dimensions to average over when computing per-channel/per-tensor stats."""
        if self.per_channel:
            return [i for i in range(w.dim()) if i != self.channel_dim]
        return list(range(w.dim()))

    def _build_scale(self, w: torch.Tensor):
        """Initialise ``scale`` from ``w`` as ``2 * mean|w| / sqrt(qmax)``."""
        if self.bits == 1:
            return
        mean_abs = w.abs().mean(dim=self._reduce_dims(w), keepdim=self.per_channel)
        init_s = 2 * mean_abs / math.sqrt(self._qmax_f)
        init_s = torch.clamp(init_s, 1e-8, 1e8)
        self.scale = nn.Parameter(init_s.detach())

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Materialise the lazily-created scale before the default loader runs.
        # Otherwise a checkpoint saved after training would report "scale" as
        # an unexpected key in a freshly constructed quantizer.
        key = prefix + "scale"
        if self.bits > 1 and self.scale is None and key in state_dict:
            saved = state_dict[key]
            self.scale = nn.Parameter(torch.empty_like(saved, device=self.qmin.device))
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        if self.bits == 1:
            alpha = w.abs().mean(dim=self._reduce_dims(w), keepdim=self.per_channel).detach()
            w_b = alpha * torch.sign(w)
            return (w_b - w).detach() + w  # straight-through estimator
        if self.scale is None:
            self._build_scale(w)
        s = self.scale
        w_bar = _clamp_ste(w / s, self._qmin_f, self._qmax_f)
        w_q = _round_ste(w_bar) * s
        return w_q
# -----------------------------
# 包裹层:Linear / Conv{1,2,3}d
# -----------------------------
class QuantLinear_0920(nn.Module):
    """Fake-quantized wrapper around an existing ``nn.Linear``.

    The wrapped layer is kept as ``self.org``; its weight and bias remain the
    trainable parameters. Each forward does the following:

    1. Optionally fake-quantizes the input with an ``ActQuantizer_0920``.
    2. Fake-quantizes ``org.weight`` per output channel (dim 0) with a
       ``WeightQuantizer_0920``.
    3. Applies ``F.linear`` with ``org.bias``.

    Args:
        org: linear layer to wrap.
        w_bits: weight bit-width.
        a_bits: activation bit-width, or None to disable activation quantization.
        enable_act: False disables activation quantization.
    """

    def __init__(self, org: nn.Linear, w_bits: int = 1, a_bits: int = 8, enable_act: bool = True):
        super().__init__()
        self.org = org
        self.w_bits = w_bits
        self.a_bits = a_bits
        self.enable_act = enable_act and (a_bits is not None)
        self.wq = WeightQuantizer_0920(bits=w_bits, per_channel=True, channel_dim=0)
        if self.enable_act:
            self.aq = ActQuantizer_0920(bits=a_bits)
        else:
            self.aq = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inputs = x if self.aq is None else self.aq(x)
        weight_q = self.wq(self.org.weight)
        return F.linear(inputs, weight_q, self.org.bias)
class _QuantConvNd_0920(nn.Module):
    """Shared fake-quantized wrapper for an existing ``nn.Conv{1,2,3}d``.

    The wrapped conv is kept as ``self.org``; its weight and bias remain the
    trainable parameters. Each forward does the following:

    1. Optionally fake-quantizes the input with an ``ActQuantizer_0920``.
    2. Fake-quantizes ``org.weight`` per output channel (dim 0) with a
       ``WeightQuantizer_0920``.
    3. Convolves using ``org``'s stride/padding/dilation/groups and its
       ``padding_mode``.

    Subclasses set ``conv_fn`` to the matching ``F.conv{1,2,3}d``.

    Args:
        org: convolution layer to wrap.
        w_bits: weight bit-width.
        a_bits: activation bit-width, or None to disable activation quantization.
        enable_act: False disables activation quantization.
    """

    conv_fn = None

    def __init__(self, org, w_bits: int = 1, a_bits: int = 8, enable_act: bool = True):
        super().__init__()
        self.org = org
        self.w_bits = w_bits
        self.a_bits = a_bits
        self.enable_act = enable_act and (a_bits is not None)
        self.wq = WeightQuantizer_0920(bits=w_bits, per_channel=True, channel_dim=0)
        self.aq = ActQuantizer_0920(bits=a_bits) if self.enable_act else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.aq is not None:
            x = self.aq(x)
        wq = self.wq(self.org.weight)
        if getattr(self.org, "padding_mode", "zeros") != "zeros":
            # The functional conv only zero-pads. For reflect/replicate/circular
            # padding, the wrapped nn.Conv*d's own _conv_forward pads the input
            # explicitly before convolving, so reuse it with the quantized weight.
            return self.org._conv_forward(x, wq, self.org.bias)
        return self.conv_fn(
            x, wq, self.org.bias,
            stride=self.org.stride, padding=self.org.padding,
            dilation=self.org.dilation, groups=self.org.groups,
        )
class QuantConv1d_0920(_QuantConvNd_0920):
    """Fake-quantized wrapper for ``nn.Conv1d``; convolves with ``conv1d``."""

    conv_fn = staticmethod(torch.nn.functional.conv1d)
class QuantConv2d_0920(_QuantConvNd_0920):
    """Fake-quantized wrapper for ``nn.Conv2d``; convolves with ``conv2d``."""

    conv_fn = staticmethod(torch.nn.functional.conv2d)
class QuantConv3d_0920(_QuantConvNd_0920):
    """Fake-quantized wrapper for ``nn.Conv3d``; convolves with ``conv3d``."""

    conv_fn = staticmethod(torch.nn.functional.conv3d)
# -----------------------------
# 模型替换
# -----------------------------
# Module types that are never wrapped for quantization (normalization layers).
_EXCLUDE_TYPES_0920 = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm,
)
# Module types eligible for replacement by the fake-quant wrappers.
_INCLUDE_TYPES_0920 = (
    nn.Linear,
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
)
def _name_excluded_0920(name: str, exclude_hints: List[str]) -> bool:
    """Return True if any hint occurs (case-insensitively) as a substring of ``name``.

    Args:
        name: dotted module name, e.g. ``"backbone.cls_head.fc"``.
        exclude_hints: substrings to look for. A comma-separated string is
            also accepted; ``None`` or empty means "exclude nothing".

    Blank hints are ignored. ``"" in s`` is always True, so a single empty
    hint would otherwise exclude every module.
    """
    if not exclude_hints:
        return False
    if isinstance(exclude_hints, str):
        # Iterating a str would test single characters instead of whole hints.
        exclude_hints = exclude_hints.split(",")
    lname = name.lower()
    return any(h.strip().lower() in lname for h in exclude_hints if h.strip())
@torch.no_grad()
def apply_quantization_0920(model: nn.Module,
                            enable: bool = False,
                            w_bits: int = 2,
                            a_bits: int = 8,
                            quantize_first_last: bool = False,
                            exclude_name_hints: List[str] = None) -> nn.Module:
    """Swap eligible ``nn.Linear`` / ``nn.Conv{1,2,3}d`` layers for fake-quant wrappers, in place.

    Candidates are collected in ``model.named_modules()`` order. A candidate is
    a module whose type is in ``_INCLUDE_TYPES_0920`` (and not in
    ``_EXCLUDE_TYPES_0920``) and whose dotted name contains none of
    ``exclude_name_hints``. Layers already nested inside a quant wrapper are
    skipped, so calling this twice does not wrap ``wrapper.org`` again.

    Args:
        model: model to modify.
        enable: when False, return ``model`` untouched.
        w_bits: weight bit-width for every wrapper.
        a_bits: activation bit-width for every wrapper.
        quantize_first_last: when False and there are at least two candidates,
            the first and last candidates stay in full precision.
        exclude_name_hints: name substrings to skip; defaults to
            ``["cls_head", "embedding.stem"]`` when None/empty.

    Returns:
        ``model`` itself, with candidate layers replaced.
    """
    if not enable:
        return model
    exclude_name_hints = exclude_name_hints or ["cls_head", "embedding.stem"]
    quant_wrapper_types = (QuantLinear_0920, _QuantConvNd_0920)

    # Collect candidates.
    candidates = []
    wrapped_prefixes = []
    for name, module in model.named_modules():
        if name == "":
            continue
        if isinstance(module, quant_wrapper_types):
            # named_modules() is pre-order, so the wrapper is seen before its `.org`.
            wrapped_prefixes.append(name + ".")
            continue
        if name.startswith(tuple(wrapped_prefixes)):
            continue
        if isinstance(module, _EXCLUDE_TYPES_0920):
            continue
        if not isinstance(module, _INCLUDE_TYPES_0920):
            continue
        if _name_excluded_0920(name, exclude_name_hints):
            continue
        candidates.append(name)

    # Keep the first/last layers in full precision by default.
    if not quantize_first_last and len(candidates) >= 2:
        candidates = candidates[1:-1]

    # Replace each candidate on its parent module.
    wrapper_for = (
        (nn.Linear, QuantLinear_0920),
        (nn.Conv1d, QuantConv1d_0920),
        (nn.Conv2d, QuantConv2d_0920),
        (nn.Conv3d, QuantConv3d_0920),
    )
    for full_name in candidates:
        path = full_name.split(".")
        parent = model
        for p in path[:-1]:
            parent = getattr(parent, p)
        old = getattr(parent, path[-1])
        new_m = old
        for base_type, wrapper_type in wrapper_for:
            if isinstance(old, base_type):
                new_m = wrapper_type(old, w_bits=w_bits, a_bits=a_bits, enable_act=True)
                # The wrappers create new buffers (observer min/max, qmin/qmax)
                # on the default device. Move them next to the wrapped weight in
                # case the model was already moved to an accelerator.
                new_m.to(old.weight.device)
                break
        setattr(parent, path[-1], new_m)
    return model
# -----------------------------
# 统计 Average Bits(按参数量加权)
# -----------------------------
def compute_avg_weight_bits_0920(model: nn.Module, include_fp32: bool = True, fp32_bits: int = 32):
    """Compute parameter-count-weighted average weight bit-widths of ``model``.

    Only ``org.weight`` of each quant wrapper counts as quantized (at the
    wrapper's ``w_bits``). Every other parameter of ``model``, including
    biases and quantizer scales, is treated as non-quantized.

    Args:
        model: model possibly containing quant wrappers.
        include_fp32: when True, non-quantized parameters contribute
            ``fp32_bits`` each to ``avg_all``. When False they contribute 0 bits
            but still count in its denominator.
        fp32_bits: bit-width assumed for non-quantized parameters.

    Returns:
        ``(avg_q_only, avg_all, (n_q_params, n_total))``. Each average is NaN
        when its denominator is 0.
    """
    quant_types = (QuantLinear_0920, QuantConv1d_0920, QuantConv2d_0920, QuantConv3d_0920)
    wrappers = [m for m in model.modules() if isinstance(m, quant_types)]
    n_q_params = sum(w.org.weight.numel() for w in wrappers)
    sum_q_bits = sum(w.org.weight.numel() * int(w.w_bits) for w in wrappers)
    n_total = sum(p.numel() for p in model.parameters())

    if include_fp32:
        sum_all_bits = sum_q_bits + (n_total - n_q_params) * fp32_bits
    else:
        sum_all_bits = sum_q_bits

    avg_q_only = sum_q_bits / n_q_params if n_q_params > 0 else float("nan")
    avg_all = sum_all_bits / n_total if n_total > 0 else float("nan")
    return avg_q_only, avg_all, (n_q_params, n_total)
def print_avg_bits_0920(model: nn.Module):
    """Print quantized-weight coverage and average weight bits of ``model``.

    Non-quantized parameters are counted as 32-bit for the "including FP32"
    figure.
    """
    stats = compute_avg_weight_bits_0920(model, include_fp32=True, fp32_bits=32)
    avg_q, avg_all, (nq, nt) = stats
    pct = 100.0 * nq / max(1, nt)
    report = (
        f"[QAT-0920] Quantized weight params: {nq}/{nt} ({pct:.2f}%)",
        f"[QAT-0920] Avg weight bits (quantized-only): {avg_q:.3f}",
        f"[QAT-0920] Avg weight bits (including FP32): {avg_all:.3f}",
    )
    for line in report:
        print(line)
# -----------------------------
# 从 cfg 或环境变量读取设置(便于最小侵入集成)
# -----------------------------
def install_qat_from_cfg_or_env_0920(model: nn.Module, cfg) -> nn.Module:
    """Optionally wrap ``model`` with fake-quant layers, configured by cfg or env vars.

    Each setting is read from ``cfg.quant0920.<key>``. Attribute access works,
    and so does ``cfg["quant0920"][key]`` for mapping-style configs. When a
    setting is absent there, the environment is consulted. Both the documented
    short names and the names derived from the key are accepted:

        enable               QAT_0920_ENABLE=1
        w_bits               QAT_0920_WBITS=2          (or QAT_0920_W_BITS)
        a_bits               QAT_0920_ABITS=8          (or QAT_0920_A_BITS)
        quantize_first_last  QAT_0920_FIRST_LAST=0/1   (or QAT_0920_QUANTIZE_FIRST_LAST)
        exclude_name_hints   QAT_0920_EXCLUDE="cls_head,embedding.stem"
                             (or QAT_0920_EXCLUDE_NAME_HINTS)

    Boolean env values "1"/"true"/"yes"/"y"/"on" (any case) mean True.
    Unparseable integers fall back to the default.

    Returns:
        ``model``, quantized in place when enabled.
    """
    from collections.abc import Mapping

    missing = object()
    # Documented env names, checked after the name derived from the key.
    env_aliases = {
        "w_bits": ("QAT_0920_WBITS",),
        "a_bits": ("QAT_0920_ABITS",),
        "quantize_first_last": ("QAT_0920_FIRST_LAST",),
        "exclude_name_hints": ("QAT_0920_EXCLUDE",),
    }

    def _from_cfg(key):
        # Best effort: unusual config objects must not break training startup.
        try:
            section = getattr(cfg, "quant0920", None)
            if section is None and isinstance(cfg, Mapping):
                section = cfg.get("quant0920")
            if section is None:
                return missing
            if isinstance(section, Mapping):
                return section[key] if key in section else missing
            return getattr(section, key, missing)
        except Exception as err:
            print(f"[QAT-0920] could not read cfg.quant0920.{key} ({err!r}); falling back to env/default")
            return missing

    def _from_env(key):
        for env_key in (f"QAT_0920_{key.upper()}",) + env_aliases.get(key, ()):
            value = os.environ.get(env_key)
            if value is not None:
                return env_key, value
        return None, None

    def _get(key, default):
        value = _from_cfg(key)
        if value is not missing:
            return value
        env_key, v = _from_env(key)
        if v is None:
            return default
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(default, bool):
            return v.strip().lower() in ("1", "true", "yes", "y", "on")
        if isinstance(default, int):
            try:
                return int(v)
            except ValueError:
                print(f"[QAT-0920] ignoring non-integer {env_key}={v!r}; using {default}")
                return default
        if isinstance(default, list):
            return [x.strip() for x in v.split(",") if x.strip()]
        return v

    enable = _get("enable", False)
    w_bits = _get("w_bits", 2)
    a_bits = _get("a_bits", 8)
    quantize_first_last = _get("quantize_first_last", False)
    exclude_name_hints = _get("exclude_name_hints", ["cls_head", "embedding.stem"])
    if isinstance(exclude_name_hints, str):
        # A plain string from cfg would otherwise be matched character by character.
        exclude_name_hints = [h.strip() for h in exclude_name_hints.split(",") if h.strip()]
    elif exclude_name_hints is not None:
        exclude_name_hints = list(exclude_name_hints)

    if enable:
        model = apply_quantization_0920(model,
                                        enable=True,
                                        w_bits=w_bits,
                                        a_bits=a_bits,
                                        quantize_first_last=quantize_first_last,
                                        exclude_name_hints=exclude_name_hints)
        print_avg_bits_0920(model)
    else:
        print("[QAT-0920] disabled")
    return model