# -*- coding: utf-8 -*-
"""
tools/quantize.py
功能:
  1) 为 spconv 卷积层安装“零尺寸保护”:若 SparseConvTensor.spatial_shape
     任意维 <= 0,则跳过该层 forward(原样返回输入),避免
       ValueError: your out spatial shape [0, ...] reach zero!!!
  2) 预留简易量化占位接口(不强侵入、可与现有配置共存)
用法(建议在模型构建完成后调用一次):
    from tools.quantize import install_spconv_zero_shape_guard, apply_quantization
    model = build_model(...)
    install_spconv_zero_shape_guard(model, verbose=False)
    # 如果你已有量化流程,可忽略;若希望最小代价打上占位:
    model = apply_quantization(
        model,
        w_bits=2, a_bits=8,
        quantize_first_last=False,
        exclude_name_hints=['cls_head', 'embedding.stem'],
    )
"""
from __future__ import annotations
import types
import warnings
from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
try:
import spconv.pytorch as spconv
_HAS_SPCONV = True
except Exception:
spconv = None
_HAS_SPCONV = False
# ---------------------------
# spconv 零尺寸保护(核心)
# ---------------------------
_SPCONV_LAYER_TYPES: Tuple[type, ...] = tuple()
if _HAS_SPCONV:
# 尽可能覆盖常见层;不同版本的 spconv 提供的类名可能略有差异
cand = []
for name in [
"SubMConv3d", "SparseConv3d", "SparseInverseConv3d", "SparseConvTranspose3d",
"SubMConv2d", "SparseConv2d", "SparseInverseConv2d", "SparseConvTranspose2d",
]:
if hasattr(spconv, name):
cand.append(getattr(spconv, name))
_SPCONV_LAYER_TYPES = tuple(cand)
def _is_spconv_layer(m: nn.Module) -> bool:
    """Return whether ``m`` is an instance of one of the collected spconv layer classes.

    Always False when spconv could not be imported.
    """
    if not _HAS_SPCONV:
        return False
    return isinstance(m, _SPCONV_LAYER_TYPES)
def _already_guarded(m: nn.Module) -> bool:
return getattr(m, "__spconv_zero_guard_installed__", False)
def _mark_guarded(m: nn.Module):
setattr(m, "__spconv_zero_guard_installed__", True)
def _wrap_spconv_forward(m: nn.Module, verbose: bool = False):
"""
monkey-patch spconv layer.forward:
若输入 SparseConvTensor 的 spatial_shape 任一维 <= 0,则原样返回输入,跳过该层。
"""
assert _is_spconv_layer(m)
if _already_guarded(m):
return
origin_forward = m.forward
def guarded_forward(x, *args, **kwargs):
try:
# x 通常是 SparseConvTensor
spatial_shape = getattr(x, "spatial_shape", None)
if spatial_shape is not None:
# spatial_shape 可能是 torch.Size、list 或 tensor
if isinstance(spatial_shape, torch.Tensor):
dims = spatial_shape.detach().to("cpu").tolist()
else:
dims = list(spatial_shape)
if any(int(d) <= 0 for d in dims):
# 直接跳过该层,返回输入(保持梯度图连通)
if verbose:
warnings.warn(
f"[spconv-zero-guard] Skip {m.__class__.__name__} "
f"due to invalid spatial_shape={dims}"
)
return x
except Exception as e:
# 保护逻辑不影响原 forward 的正常执行
warnings.warn(f"[spconv-zero-guard] check failed: {e}")
return origin_forward(x, *args, **kwargs)
m.forward = types.MethodType(guarded_forward, m) # 绑定到实例
_mark_guarded(m)
def install_spconv_zero_shape_guard(model: nn.Module, verbose: bool = False) -> int:
    """Install the zero-size guard on every spconv layer inside ``model``.

    Walks ``model.modules()`` recursively and skips layers that already carry
    the guard.

    Args:
        model: Root module to scan.
        verbose: Forwarded to each installed guard. When True, also prints how
            many layers were guarded by this call.

    Returns:
        Number of layers the guard was installed on by this call. Returns 0
        (after a warning) when spconv could not be imported.
    """
    if not _HAS_SPCONV:
        warnings.warn("[spconv-zero-guard] spconv not found; guard is disabled.")
        return 0

    # modules() yields each module only once, so collecting first and wrapping
    # afterwards guards the same set of layers as wrapping during the walk.
    targets = [
        layer
        for layer in model.modules()
        if _is_spconv_layer(layer) and not _already_guarded(layer)
    ]
    for layer in targets:
        _wrap_spconv_forward(layer, verbose=verbose)

    if verbose:
        print(f"[spconv-zero-guard] installed on {len(targets)} layers.")
    return len(targets)
# --------------------------------
# 极简量化占位(可选,不强侵入)
# --------------------------------
class IdentityQuant(nn.Module):
    """Pass-through placeholder for a quantizer.

    Applies no numeric transformation and holds no parameters. It only
    provides a module slot where a real quantizer could later be plugged in,
    keeping the call structure unchanged. Ignore it if a full quantization
    pipeline already exists.
    """

    def forward(self, x):
        """Return ``x`` itself, untouched."""
        return x
def _name_hit(name: str, hints: Iterable[str]) -> bool:
name = name or ""
for h in hints:
if h in name:
return True
return False
def apply_quantization(
    model: nn.Module,
    w_bits: int = 2,
    a_bits: int = 8,
    quantize_first_last: bool = False,
    exclude_name_hints: Optional[Iterable[str]] = None,
    install_guard: bool = True,
    verbose: bool = False,
) -> nn.Module:
    """Attach a non-intrusive placeholder "quantization" to ``model``.

    Weights and numerics are left untouched. Every matched activation module
    (ReLU / GELU / SiLU / SELU / LeakyReLU) is replaced on its parent by
    ``nn.Sequential(activation, IdentityQuant())``. That wrapper computes the
    same output and only serves as a hook point for a real quantizer.
    Optionally, the spconv zero-size guard is installed first.

    Repeated calls do not stack extra placeholders. An activation that already
    sits in a two-element ``nn.Sequential`` whose second item is an
    ``IdentityQuant`` is left alone.

    Args:
        model: Model to modify in place.
        w_bits: Weight bit-width. Only reported when ``verbose``; no weight is changed.
        a_bits: Activation bit-width. Only reported when ``verbose``.
        quantize_first_last: When False, activations whose qualified name
            contains "stem", "head", "cls_head" or "embedding.stem" are skipped.
            This is a rough name heuristic for the first/last layers.
        exclude_name_hints: Extra substrings. Activations whose qualified name
            contains any of them are skipped.
        install_guard: Call ``install_spconv_zero_shape_guard(model)`` first.
        verbose: Passed to the guard installer; also prints a summary line.

    Returns:
        The same ``model`` object, modified in place.
    """
    exclude_name_hints = list(exclude_name_hints or [])
    # Rough, name-based guess at "first/last layer" modules (example only).
    first_last_hints = ("stem", "head", "cls_head", "embedding.stem")
    act_types = (nn.ReLU, nn.GELU, nn.SiLU, nn.SELU, nn.LeakyReLU)

    # 1) (optional) spconv zero-size guard
    if install_guard:
        install_spconv_zero_shape_guard(model, verbose=verbose)

    # Activations already wrapped as Sequential(act, IdentityQuant()) by an
    # earlier call. Wrapping them again would nest one more Sequential.
    already_wrapped = {
        id(seq[0])
        for seq in model.modules()
        if isinstance(seq, nn.Sequential)
        and len(seq) == 2
        and isinstance(seq[1], IdentityQuant)
    }

    # 2) Lightweight placeholder quantization (hook points only; numerics unchanged).
    #    Replace with real quantizer modules if a quantization flow exists.
    #    NOTE: a real pipeline should insert quantizers more precisely
    #    (after conv/linear, before norm, ...).
    repl = {}
    for name, mod in model.named_modules():
        if name == "":  # the root module itself
            continue
        if not isinstance(mod, act_types) or id(mod) in already_wrapped:
            continue
        if not quantize_first_last and _name_hit(name, first_last_hints):
            continue
        if exclude_name_hints and _name_hit(name, exclude_name_hints):
            continue
        repl[name] = IdentityQuant()

    # Apply the replacements as a shallow swap via setattr on the parent module.
    # Activations are leaves, so one swap never invalidates another's parent path.
    for full_name, new_mod in repl.items():
        *parent_path, leaf = full_name.split(".")
        parent = model
        for comp in parent_path:
            parent = getattr(parent, comp)
        # Safety check: only wrap when the child is still reachable under this name.
        if hasattr(parent, leaf):
            setattr(parent, leaf, nn.Sequential(getattr(parent, leaf), new_mod))

    if verbose:
        print(f"[quantize] applied placeholder quant: w_bits={w_bits}, a_bits={a_bits}, "
              f"quantize_first_last={quantize_first_last}, exclude={exclude_name_hints}")
    return model
# Public API: the names exported by ``from tools.quantize import *``.
__all__ = ["install_spconv_zero_shape_guard", "apply_quantization"]