# Origin: biptv3 / server_worktree / tools / quantize.py
# Commit 08cde47 (verified), uploaded by YYYYYYUUU:
#   "Backup FULL poplab work tree (source, configs, libs, scripts) excl. .pth"
# Size: 7.25 kB
# -*- coding: utf-8 -*-
"""
tools/quantize.py
功能:
1) 为 spconv 卷积层安装“零尺寸保护”:若 SparseConvTensor.spatial_shape
任意维 <= 0,则跳过该层 forward,避免
ValueError: your out spatial shape [0, ...] reach zero!!!
2) 预留简易量化占位接口(不强侵入、可与现有配置共存)
用法(建议在模型构建完成后调用一次):
from tools.quantize import install_spconv_zero_shape_guard, apply_quantization
model = build_model(...)
install_spconv_zero_shape_guard(model, verbose=False)
# 如果你已有量化流程,可忽略;若希望最小代价打上占位:
model = apply_quantization(
model,
w_bits=2, a_bits=8,
quantize_first_last=False,
exclude_name_hints=['cls_head', 'embedding.stem']
)
"""
from __future__ import annotations
import types
import warnings
from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
try:
    import spconv.pytorch as spconv
    _HAS_SPCONV = True
except Exception:
    # Deliberately broad: besides ImportError, a broken spconv install can fail
    # while loading its native libraries; either way the guard is just disabled.
    spconv = None
    _HAS_SPCONV = False

# ---------------------------
# spconv zero-size guard (core)
# ---------------------------

# Tuple of spconv layer classes the guard applies to. Covers the common layer
# names; class names differ slightly between spconv versions, so only names
# actually present on the imported module are used. The generator expression
# keeps its loop variable out of the module namespace.
_SPCONV_LAYER_TYPES: Tuple[type, ...] = (
    tuple(
        getattr(spconv, cls_name)
        for cls_name in (
            "SubMConv3d", "SparseConv3d", "SparseInverseConv3d", "SparseConvTranspose3d",
            "SubMConv2d", "SparseConv2d", "SparseInverseConv2d", "SparseConvTranspose2d",
        )
        if hasattr(spconv, cls_name)
    )
    if _HAS_SPCONV
    else tuple()
)
def _is_spconv_layer(m: nn.Module) -> bool:
    """Return True if *m* is an instance of one of the detected spconv layer classes.

    Always False when spconv could not be imported.
    """
    if not _HAS_SPCONV:
        return False
    return isinstance(m, _SPCONV_LAYER_TYPES)
def _already_guarded(m: nn.Module) -> bool:
    """Tell whether the zero-shape guard marker is set on *m*.

    Returns the value of the ``__spconv_zero_guard_installed__`` attribute,
    or False when the attribute is absent.
    """
    marker = "__spconv_zero_guard_installed__"
    return getattr(m, marker, False)
def _mark_guarded(m: nn.Module):
    """Set the ``__spconv_zero_guard_installed__`` marker on *m* to True."""
    m.__spconv_zero_guard_installed__ = True
def _wrap_spconv_forward(m: nn.Module, verbose: bool = False):
    """Monkey-patch ``m.forward`` with a zero-spatial-shape guard.

    If the input's ``spatial_shape`` has any dimension <= 0, the layer is
    skipped and the input is returned unchanged; otherwise the original
    forward runs. Calling this again on an already-guarded module is a no-op.

    Args:
        m: an spconv layer instance (precondition checked with ``assert``).
        verbose: when True, emit a warning every time a layer is skipped.

    NOTE(review):
        - Only the *input* shape is checked. A strided/downsampling conv whose
          input shape is positive but whose computed output shape reaches 0
          is not caught by this guard.
        - Returning the input unchanged keeps the graph connected, but any
          channel change the skipped layer would have made is not applied, so
          the next layer may receive an unexpected feature width.
    """
    assert _is_spconv_layer(m)
    if _already_guarded(m):
        return

    origin_forward = m.forward
    # When forward is an ordinary method bound to ``m``, keep its underlying
    # function and call it with the *current* ``self``. This keeps the patch
    # correct after ``copy.deepcopy(model)``, which re-binds the patched method
    # to the copy (a captured bound method would keep calling the original).
    origin_func = getattr(origin_forward, "__func__", None)
    if getattr(origin_forward, "__self__", None) is not m:
        origin_func = None

    def guarded_forward(self, x, *args, **kwargs):
        # ``self`` is supplied by the MethodType binding below, so ``x`` is the
        # real layer input (usually a SparseConvTensor).
        try:
            spatial_shape = getattr(x, "spatial_shape", None)
            if spatial_shape is not None:
                # spatial_shape may be a torch.Size, a list, or a tensor.
                if isinstance(spatial_shape, torch.Tensor):
                    dims = spatial_shape.detach().to("cpu").tolist()
                else:
                    dims = list(spatial_shape)
                if any(int(d) <= 0 for d in dims):
                    # Skip this layer and return the input unchanged.
                    if verbose:
                        warnings.warn(
                            f"[spconv-zero-guard] Skip {self.__class__.__name__} "
                            f"due to invalid spatial_shape={dims}"
                        )
                    return x
        except Exception as e:
            # Best effort: a failing check must not prevent the real forward.
            warnings.warn(f"[spconv-zero-guard] check failed: {e}")
        if origin_func is not None:
            return origin_func(self, x, *args, **kwargs)
        return origin_forward(x, *args, **kwargs)

    m.forward = types.MethodType(guarded_forward, m)  # bind to this instance
    _mark_guarded(m)
def install_spconv_zero_shape_guard(model: nn.Module, verbose: bool = False) -> int:
    """Install the zero-spatial-shape guard on every spconv layer inside *model*.

    Walks ``model.modules()`` (which includes *model* itself) and patches each
    spconv layer that is not yet guarded.

    Args:
        model: root module to traverse.
        verbose: forwarded to the per-layer patch; also prints a summary line.

    Returns:
        Number of layers newly guarded by this call (0 when spconv is unavailable,
        in which case a warning is emitted).
    """
    if not _HAS_SPCONV:
        warnings.warn("[spconv-zero-guard] spconv not found; guard is disabled.")
        return 0

    pending = [
        layer
        for layer in model.modules()
        if _is_spconv_layer(layer) and not _already_guarded(layer)
    ]
    for layer in pending:
        _wrap_spconv_forward(layer, verbose=verbose)

    if verbose:
        print(f"[spconv-zero-guard] installed on {len(pending)} layers.")
    return len(pending)
# --------------------------------
# 极简量化占位(可选,不强侵入)
# --------------------------------
class IdentityQuant(nn.Module):
    """Placeholder quantizer that hands its input back unchanged.

    It performs no numeric transformation; it only marks a hook point so the
    model's call structure stays consistent if a real quantizer is swapped in
    later. Ignore it if you already have a mature quantization flow.
    """

    def forward(self, x):
        # Identity: return the very same object that was passed in.
        return x
def _name_hit(name: str, hints: Iterable[str]) -> bool:
    """Return True if any hint is a substring of *name* (a falsy name counts as "")."""
    haystack = name or ""
    return any(hint in haystack for hint in hints)
def apply_quantization(
    model: nn.Module,
    w_bits: int = 2,
    a_bits: int = 8,
    quantize_first_last: bool = False,
    exclude_name_hints: Optional[Iterable[str]] = None,
    install_guard: bool = True,
    verbose: bool = False,
) -> nn.Module:
    """Attach placeholder quantization hook points to *model* (numerically a no-op).

    - Weights and computation are unchanged: each matched activation module
      (ReLU / GELU / SiLU / SELU / LeakyReLU) is replaced in its parent by
      ``nn.Sequential(activation, IdentityQuant())``.
    - Optionally installs the spconv zero-spatial-shape guard first (default True).
    - Calling it again on the same model does not wrap activations a second time.

    Args:
        model: module to modify in place.
        w_bits: weight bit-width; currently only reported in the verbose summary.
        a_bits: activation bit-width; currently only reported in the verbose summary.
        quantize_first_last: when False, modules whose qualified name contains
            "stem" or "head" (a naming heuristic for first/last layers) are skipped.
        exclude_name_hints: extra name substrings whose modules are skipped.
        install_guard: install the spconv zero-size guard before inserting hooks.
        verbose: print a summary line.

    Returns:
        The same *model* object, modified in place.
    """
    exclude_name_hints = list(exclude_name_hints or [])

    # 1) (optional) spconv zero-size guard
    if install_guard:
        install_spconv_zero_shape_guard(model, verbose=verbose)

    # 2) Lightweight placeholder quantization (hook points only; values unchanged).
    #    Replace IdentityQuant with a real quantizer here if you have one.
    act_types = (nn.ReLU, nn.GELU, nn.SiLU, nn.SELU, nn.LeakyReLU)
    # Rough naming heuristic for "first/last" layers (example only).
    first_last_hints = ["stem", "head", "cls_head", "embedding.stem"]

    already_wrapped = set()  # ids of activations already inside a placeholder wrapper
    repl = {}
    for name, mod in model.named_modules():
        if name == "":  # skip the root module itself
            continue
        # named_modules() yields a parent before its children, so a wrapper left
        # by a previous call is seen before the activation it contains.
        if (
            isinstance(mod, nn.Sequential)
            and len(mod) == 2
            and isinstance(mod[1], IdentityQuant)
        ):
            already_wrapped.add(id(mod[0]))
            continue
        if not quantize_first_last and _name_hit(name, first_last_hints):
            continue
        if exclude_name_hints and _name_hit(name, exclude_name_hints):
            continue
        # Example placement only: a real flow would insert quantizers more precisely
        # (after conv/linear, before norm, ...).
        if isinstance(mod, act_types) and id(mod) not in already_wrapped:
            repl[name] = IdentityQuant()

    # Shallow replacement: set the wrapper on the parent module under the leaf name.
    for full_name, new_mod in repl.items():
        parent_path, _, leaf = full_name.rpartition(".")
        parent = model
        if parent_path:
            for comp in parent_path.split("."):
                parent = getattr(parent, comp)
        # The leaf was just yielded by named_modules(), so it normally exists;
        # this only guards against the tree having changed in between.
        if hasattr(parent, leaf):
            setattr(parent, leaf, nn.Sequential(getattr(parent, leaf), new_mod))

    if verbose:
        print(f"[quantize] applied placeholder quant: w_bits={w_bits}, a_bits={a_bits}, "
              f"quantize_first_last={quantize_first_last}, exclude={exclude_name_hints}")
    return model
# Public API for star-imports. IdentityQuant is exported too, because
# apply_quantization inserts it into models and callers may want to detect it.
__all__ = [
    "install_spconv_zero_shape_guard",
    "apply_quantization",
    "IdentityQuant",
]