# -*- coding: utf-8 -*- """ tools/quantize.py 功能: 1) 为 spconv 卷积层安装“零尺寸保护”:若 SparseConvTensor.spatial_shape 任意维 <= 0,则跳过该层 forward,避免 ValueError: your out spatial shape [0, ...] reach zero!!! 2) 预留简易量化占位接口(不强侵入、可与现有配置共存) 用法(建议在模型构建完成后调用一次): from tools.quantize import install_spconv_zero_shape_guard, apply_quantization model = build_model(...) install_spconv_zero_shape_guard(model, verbose=False) # 如果你已有量化流程,可忽略;若希望最小代价打上占位: model = apply_quantization( model, w_bits=2, a_bits=8, quantize_first_last=False, exclude_name_hints=['cls_head', 'embedding.stem'] ) """ from __future__ import annotations import types import warnings from typing import Iterable, Optional, Tuple import torch import torch.nn as nn try: import spconv.pytorch as spconv _HAS_SPCONV = True except Exception: spconv = None _HAS_SPCONV = False # --------------------------- # spconv 零尺寸保护(核心) # --------------------------- _SPCONV_LAYER_TYPES: Tuple[type, ...] = tuple() if _HAS_SPCONV: # 尽可能覆盖常见层;不同版本的 spconv 提供的类名可能略有差异 cand = [] for name in [ "SubMConv3d", "SparseConv3d", "SparseInverseConv3d", "SparseConvTranspose3d", "SubMConv2d", "SparseConv2d", "SparseInverseConv2d", "SparseConvTranspose2d", ]: if hasattr(spconv, name): cand.append(getattr(spconv, name)) _SPCONV_LAYER_TYPES = tuple(cand) def _is_spconv_layer(m: nn.Module) -> bool: return _HAS_SPCONV and isinstance(m, _SPCONV_LAYER_TYPES) def _already_guarded(m: nn.Module) -> bool: return getattr(m, "__spconv_zero_guard_installed__", False) def _mark_guarded(m: nn.Module): setattr(m, "__spconv_zero_guard_installed__", True) def _wrap_spconv_forward(m: nn.Module, verbose: bool = False): """ monkey-patch spconv layer.forward: 若输入 SparseConvTensor 的 spatial_shape 任一维 <= 0,则原样返回输入,跳过该层。 """ assert _is_spconv_layer(m) if _already_guarded(m): return origin_forward = m.forward def guarded_forward(x, *args, **kwargs): try: # x 通常是 SparseConvTensor spatial_shape = getattr(x, "spatial_shape", None) if spatial_shape is not None: # spatial_shape 可能是 torch.Size、list 或 tensor if isinstance(spatial_shape, torch.Tensor): dims = spatial_shape.detach().to("cpu").tolist() else: dims = list(spatial_shape) if any(int(d) <= 0 for d in dims): # 直接跳过该层,返回输入(保持梯度图连通) if verbose: warnings.warn( f"[spconv-zero-guard] Skip {m.__class__.__name__} " f"due to invalid spatial_shape={dims}" ) return x except Exception as e: # 保护逻辑不影响原 forward 的正常执行 warnings.warn(f"[spconv-zero-guard] check failed: {e}") return origin_forward(x, *args, **kwargs) m.forward = types.MethodType(guarded_forward, m) # 绑定到实例 _mark_guarded(m) def install_spconv_zero_shape_guard(model: nn.Module, verbose: bool = False) -> int: """ 递归遍历并给所有 spconv 层安装零尺寸保护。 返回被保护的层数量。 """ if not _HAS_SPCONV: warnings.warn("[spconv-zero-guard] spconv not found; guard is disabled.") return 0 count = 0 for module in model.modules(): if _is_spconv_layer(module) and not _already_guarded(module): _wrap_spconv_forward(module, verbose=verbose) count += 1 if verbose: print(f"[spconv-zero-guard] installed on {count} layers.") return count # -------------------------------- # 极简量化占位(可选,不强侵入) # -------------------------------- class IdentityQuant(nn.Module): """ 最小占位:不做任何数值变换,仅用于维持调用结构一致性。 如果你已有成熟量化流程,可以无视本类。 """ def __init__(self): super().__init__() def forward(self, x): return x def _name_hit(name: str, hints: Iterable[str]) -> bool: name = name or "" for h in hints: if h in name: return True return False def apply_quantization( model: nn.Module, w_bits: int = 2, a_bits: int = 8, quantize_first_last: bool = False, exclude_name_hints: Optional[Iterable[str]] = None, install_guard: bool = True, verbose: bool = False, ) -> nn.Module: """ 一个“零侵入”的占位量化装配: - 不改变权重与计算,仅可选地在激活处插入 IdentityQuant - 同时可选择安装 spconv 零尺寸保护(默认 True) 你可以在现有配置/构建流程中直接调用此函数,不会破坏原有行为。 """ exclude_name_hints = list(exclude_name_hints or []) # 1)(可选)spconv 零尺寸保护 if install_guard: install_spconv_zero_shape_guard(model, verbose=verbose) # 2) 轻量占位量化(仅作为挂点;不改变数值) # 若已有量化逻辑,可把这里改为真实量化模块 repl = {} for name, mod in model.named_modules(): # 跳过顶层模块自己 if name == "": continue if not quantize_first_last: # 粗略地跳过“第一层/最后一层”的常见命名(示例) if _name_hit(name, ["stem", "head", "cls_head", "embedding.stem"]): continue if exclude_name_hints and _name_hit(name, exclude_name_hints): continue # 仅示例:在常见的激活模块后插入 IdentityQuant # 注意:真实工程应更精细地插入到指定位置(conv/linear 之后、norm 之前等) if isinstance(mod, (nn.ReLU, nn.GELU, nn.SiLU, nn.SELU, nn.LeakyReLU)): repl[name] = IdentityQuant() # 把要替换的模块真正替换掉 # 这里采用“父模块 setattr”的方式做浅替换 if repl: for full_name, new_mod in repl.items(): # 找到父模块和叶子名 comps = full_name.split(".") parent = model for c in comps[:-1]: parent = getattr(parent, c) leaf = comps[-1] # 只在没有同名子模块的情况下替换(安全保护) if hasattr(parent, leaf): setattr(parent, leaf, nn.Sequential(getattr(parent, leaf), new_mod)) if verbose: print(f"[quantize] applied placeholder quant: w_bits={w_bits}, a_bits={a_bits}, " f"quantize_first_last={quantize_first_last}, exclude={exclude_name_hints}") return model __all__ = [ "install_spconv_zero_shape_guard", "apply_quantization", ]