"""
tools/quantize.py

Features:
1) Install a "zero-size guard" on spconv convolution layers: if any dimension
   of SparseConvTensor.spatial_shape is <= 0, the layer's forward is skipped,
   which avoids
       ValueError: your out spatial shape [0, ...] reach zero!!!
2) Provide a simple placeholder quantization hook (non-intrusive; it can
   coexist with an existing configuration).

Usage (recommended: call once after the model has been built):
    from tools.quantize import install_spconv_zero_shape_guard, apply_quantization

    model = build_model(...)
    install_spconv_zero_shape_guard(model, verbose=False)

    # Ignore this if you already have a quantization pipeline; to attach the
    # placeholder at minimal cost:
    model = apply_quantization(
        model,
        w_bits=2, a_bits=8,
        quantize_first_last=False,
        exclude_name_hints=['cls_head', 'embedding.stem']
    )
"""
|
|
| from __future__ import annotations |
| import types |
| import warnings |
| from typing import Iterable, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
|
|
# spconv is an optional dependency. Any failure while importing it (not only
# ImportError) leaves ``spconv`` as None and ``_HAS_SPCONV`` False, which the
# rest of this module uses to disable the zero-shape guard.
try:
    import spconv.pytorch as spconv
    _HAS_SPCONV = True
except Exception:
    spconv = None
    _HAS_SPCONV = False
|
|
|
|
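# ---------------------------------------------------------------------------
# spconv zero-shape guard
# ---------------------------------------------------------------------------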
|
|
# Classes treated as "spconv convolution layers" by the zero-shape guard.
# Each candidate is probed by name on the imported spconv module, so names the
# installed build does not provide are simply skipped. The tuple stays empty
# when spconv could not be imported.
_SPCONV_LAYER_TYPES: Tuple[type, ...] = ()
if _HAS_SPCONV:
    _SPCONV_LAYER_TYPES = tuple(
        layer_cls
        for layer_cls in (
            getattr(spconv, layer_name, None)
            for layer_name in (
                "SubMConv3d", "SparseConv3d", "SparseInverseConv3d", "SparseConvTranspose3d",
                "SubMConv2d", "SparseConv2d", "SparseInverseConv2d", "SparseConvTranspose2d",
            )
        )
        # isinstance() only accepts classes, so anything else found under one
        # of these names is dropped rather than poisoning the tuple.
        if isinstance(layer_cls, type)
    )
|
|
|
|
def _is_spconv_layer(m: nn.Module) -> bool:
    """Return True when ``m`` is an instance of a discovered spconv conv class.

    Always False when spconv could not be imported.
    """
    if not _HAS_SPCONV:
        return False
    return isinstance(m, _SPCONV_LAYER_TYPES)
|
|
|
|
def _already_guarded(m: nn.Module) -> bool:
    """Report whether the zero-shape guard marker is present on ``m``.

    Returns the value of the ``__spconv_zero_guard_installed__`` attribute, or
    ``False`` when the attribute does not exist.
    """
    return getattr(m, "__spconv_zero_guard_installed__", False)
|
|
|
|
def _mark_guarded(m: nn.Module) -> None:
    """Set the ``__spconv_zero_guard_installed__`` marker on ``m`` to True."""
    setattr(m, "__spconv_zero_guard_installed__", True)
|
|
|
|
def _wrap_spconv_forward(m: nn.Module, verbose: bool = False):
    """Monkey-patch ``m.forward`` so the layer is skipped for zero-sized input.

    The replacement forward reads the ``spatial_shape`` attribute of its first
    argument. If any dimension is <= 0, the input is returned unchanged and
    the original forward is not called; otherwise the call is delegated to
    the original forward with the same arguments. Inputs without a
    ``spatial_shape`` attribute are always delegated.

    Args:
        m: The layer to patch; must satisfy ``_is_spconv_layer``.
        verbose: When True, emit a warning every time the layer is skipped.

    Calling this on a layer that is already guarded does nothing.
    """
    assert _is_spconv_layer(m), "zero-shape guard only applies to spconv conv layers"
    if _already_guarded(m):
        return

    # Bound method captured before patching; it already carries ``m`` as self.
    origin_forward = m.forward
    layer_name = m.__class__.__name__

    def guarded_forward(x, *args, **kwargs):
        dims = None
        skip = False
        try:
            spatial_shape = getattr(x, "spatial_shape", None)
            if spatial_shape is not None:
                # spatial_shape may be a tensor or a plain sequence.
                if isinstance(spatial_shape, torch.Tensor):
                    dims = spatial_shape.detach().to("cpu").tolist()
                else:
                    dims = list(spatial_shape)
                skip = any(int(d) <= 0 for d in dims)
        except Exception as e:
            # Best effort: a failed inspection must not block the real
            # forward, so report it and fall through to the original layer.
            warnings.warn(f"[spconv-zero-guard] check failed: {e}")
            skip = False

        if skip:
            if verbose:
                warnings.warn(
                    f"[spconv-zero-guard] Skip {layer_name} "
                    f"due to invalid spatial_shape={dims}"
                )
            # NOTE(review): the input is passed through untouched, so whatever
            # this layer would have changed is not applied -- confirm that
            # downstream layers tolerate that.
            return x

        return origin_forward(x, *args, **kwargs)

    # Install the closure as a plain instance attribute. Binding it with
    # types.MethodType would pass ``m`` itself as ``x``: the shape check would
    # then inspect the layer instead of its input, and the already-bound
    # original forward would receive one positional argument too many.
    m.forward = guarded_forward
    _mark_guarded(m)
|
|
|
|
def install_spconv_zero_shape_guard(model: nn.Module, verbose: bool = False) -> int:
    """Install the zero-shape guard on every spconv layer inside ``model``.

    Walks ``model.modules()`` and patches each spconv convolution layer that
    has not been guarded yet.

    Args:
        model: Root module to traverse.
        verbose: Forwarded to each guard (warn on every skipped forward) and
            also prints the number of layers patched by this call.

    Returns:
        The number of layers newly guarded by this call; 0 (with a warning)
        when spconv is not available.
    """
    if not _HAS_SPCONV:
        warnings.warn("[spconv-zero-guard] spconv not found; guard is disabled.")
        return 0

    count = 0
    for module in model.modules():
        if _is_spconv_layer(module) and not _already_guarded(module):
            _wrap_spconv_forward(module, verbose=verbose)
            count += 1

    if verbose:
        print(f"[spconv-zero-guard] installed on {count} layers.")
    return count
|
|
|
|
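# ---------------------------------------------------------------------------
# placeholder quantization
# ---------------------------------------------------------------------------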
|
|
class IdentityQuant(nn.Module):
    """Minimal placeholder quantizer that returns its input unchanged.

    It performs no numeric transformation and exists only to keep the call
    structure consistent. It can be ignored when a real quantization pipeline
    is already in place.
    """

    def forward(self, x):
        # Pass-through: the very same object is handed back.
        return x
|
|
|
|
def _name_hit(name: str, hints: Iterable[str]) -> bool:
    """Return True if any entry of ``hints`` is a substring of ``name``.

    A falsy ``name`` (``None`` or empty) is treated as the empty string.
    """
    haystack = name or ""
    return any(hint in haystack for hint in hints)
|
|
|
|
def apply_quantization(
    model: nn.Module,
    w_bits: int = 2,
    a_bits: int = 8,
    quantize_first_last: bool = False,
    exclude_name_hints: Optional[Iterable[str]] = None,
    install_guard: bool = True,
    verbose: bool = False,
) -> nn.Module:
    """Attach placeholder quantization to ``model`` in place and return it.

    Weights and computation are left untouched. The only structural change is
    that selected activation modules (ReLU, GELU, SiLU, SELU, LeakyReLU) are
    replaced by ``nn.Sequential(activation, IdentityQuant())``. Optionally the
    spconv zero-shape guard is installed as well.

    Args:
        model: Module to modify.
        w_bits: Weight bit width; only reported in the verbose message.
        a_bits: Activation bit width; only reported in the verbose message.
        quantize_first_last: When False, activations whose qualified name
            contains "stem" or "head" are left alone.
        exclude_name_hints: Substrings; an activation whose qualified name
            contains any of them is left alone. A single string is treated
            as one hint.
        install_guard: When True, call ``install_spconv_zero_shape_guard``.
        verbose: Print a summary line (and pass verbosity to the guard).

    Returns:
        The same ``model`` object.

    An activation that is already the first element of a two-element
    ``nn.Sequential`` ending in ``IdentityQuant`` is not wrapped again, so
    calling this function repeatedly does not nest wrappers.
    """
    # A bare string is itself an iterable of characters; exploding it into
    # one-character hints would exclude almost every module.
    if isinstance(exclude_name_hints, str):
        exclude_name_hints = [exclude_name_hints]
    else:
        exclude_name_hints = list(exclude_name_hints or [])

    if install_guard:
        install_spconv_zero_shape_guard(model, verbose=verbose)

    activation_types = (nn.ReLU, nn.GELU, nn.SiLU, nn.SELU, nn.LeakyReLU)
    first_last_hints = ["stem", "head", "cls_head", "embedding.stem"]

    # Snapshot first, mutate afterwards: the module tree must not change
    # while named_modules() is being iterated.
    modules_by_name = dict(model.named_modules())
    pending = []
    for name, mod in modules_by_name.items():
        if name == "":
            # The root module has no parent to re-attach a wrapper to.
            continue
        if not isinstance(mod, activation_types):
            continue
        if not quantize_first_last and _name_hit(name, first_last_hints):
            continue
        if exclude_name_hints and _name_hit(name, exclude_name_hints):
            continue

        parent_name, _, leaf = name.rpartition(".")
        parent = modules_by_name[parent_name]

        # Skip activations that already sit in an (activation, IdentityQuant)
        # wrapper, e.g. from an earlier call.
        if isinstance(parent, nn.Sequential):
            siblings = list(parent.children())
            if (
                len(siblings) == 2
                and siblings[0] is mod
                and isinstance(siblings[1], IdentityQuant)
            ):
                continue

        pending.append((parent, leaf, mod))

    for parent, leaf, activation in pending:
        setattr(parent, leaf, nn.Sequential(activation, IdentityQuant()))

    if verbose:
        print(f"[quantize] applied placeholder quant: w_bits={w_bits}, a_bits={a_bits}, "
              f"quantize_first_last={quantize_first_last}, exclude={exclude_name_hints}")
    return model
|
|
|
|
# Names exported by ``from tools.quantize import *``.
__all__ = [
    "install_spconv_zero_shape_guard",
    "apply_quantization",
]
|
|