File size: 7,253 Bytes
08cde47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# -*- coding: utf-8 -*-
"""
tools/quantize.py

功能:
1) 为 spconv 卷积层安装“零尺寸保护”:若 SparseConvTensor.spatial_shape
   任意维 <= 0,则跳过该层 forward,避免
   ValueError: your out spatial shape [0, ...] reach zero!!!
2) 预留简易量化占位接口(不强侵入、可与现有配置共存)

用法(建议在模型构建完成后调用一次):
    from tools.quantize import install_spconv_zero_shape_guard, apply_quantization

    model = build_model(...)
    install_spconv_zero_shape_guard(model, verbose=False)

    # 如果你已有量化流程,可忽略;若希望最小代价打上占位:
    model = apply_quantization(
        model,
        w_bits=2, a_bits=8,
        quantize_first_last=False,
        exclude_name_hints=['cls_head', 'embedding.stem']
    )
"""

from __future__ import annotations
import types
import warnings
from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn

try:
    import spconv.pytorch as spconv
    _HAS_SPCONV = True
except Exception:
    spconv = None
    _HAS_SPCONV = False


# ---------------------------
# spconv 零尺寸保护(核心)
# ---------------------------

# spconv layer classes that are eligible for the zero-shape guard.
# Stays an empty tuple when spconv could not be imported.
_SPCONV_LAYER_TYPES: Tuple[type, ...] = tuple()
if _HAS_SPCONV:
    # Cover as many common layers as possible; the class names exposed by
    # spconv may differ slightly between versions, so each name is probed
    # with hasattr() and silently dropped when it is missing.
    cand = []
    for name in [
        "SubMConv3d", "SparseConv3d", "SparseInverseConv3d", "SparseConvTranspose3d",
        "SubMConv2d", "SparseConv2d", "SparseInverseConv2d", "SparseConvTranspose2d",
    ]:
        if hasattr(spconv, name):
            cand.append(getattr(spconv, name))
    # Frozen as a tuple so it can be passed directly to isinstance().
    _SPCONV_LAYER_TYPES = tuple(cand)


def _is_spconv_layer(m: nn.Module) -> bool:
    """Tell whether ``m`` is an instance of one of ``_SPCONV_LAYER_TYPES``.

    Always ``False`` when spconv is not available.
    """
    if not _HAS_SPCONV:
        return False
    return isinstance(m, _SPCONV_LAYER_TYPES)


def _already_guarded(m: nn.Module) -> bool:
    """Return the guard marker stored on ``m``, or ``False`` if it was never set."""
    try:
        return getattr(m, "__spconv_zero_guard_installed__")
    except AttributeError:
        # No marker present: the guard has not been installed on this module.
        return False


def _mark_guarded(m: nn.Module):
    """Flag ``m`` so later passes can see that the zero-shape guard is installed."""
    m.__spconv_zero_guard_installed__ = True


def _wrap_spconv_forward(m: nn.Module, verbose: bool = False):
    """Monkey-patch ``m.forward`` with a zero-size guard.

    The patched forward reads the ``spatial_shape`` attribute of its first
    argument (normally a ``SparseConvTensor``). If any dimension is ``<= 0``
    the layer is skipped and the input object is returned unchanged;
    otherwise the call is delegated to the original forward.

    Args:
        m: The spconv layer to patch; must satisfy ``_is_spconv_layer``.
        verbose: When true, emit a warning each time the layer is skipped.

    Returns:
        None. The function does nothing if ``m`` already carries the guard
        marker.
    """
    assert _is_spconv_layer(m)
    if _already_guarded(m):
        return

    # Captured before patching. This is a *bound* method, so it already
    # carries ``m`` as ``self`` and must be called with the input only.
    origin_forward = m.forward

    def guarded_forward(x, *args, **kwargs):
        try:
            # x is normally a SparseConvTensor.
            spatial_shape = getattr(x, "spatial_shape", None)
            if spatial_shape is not None:
                # spatial_shape may be a torch.Size, a list or a tensor.
                if isinstance(spatial_shape, torch.Tensor):
                    dims = spatial_shape.detach().to("cpu").tolist()
                else:
                    dims = list(spatial_shape)

                if any(int(d) <= 0 for d in dims):
                    # Skip the layer and hand the input back untouched.
                    if verbose:
                        warnings.warn(
                            f"[spconv-zero-guard] Skip {m.__class__.__name__} "
                            f"due to invalid spatial_shape={dims}"
                        )
                    return x
        except Exception as e:
            # Best effort: a failing check must never block the real forward.
            warnings.warn(f"[spconv-zero-guard] check failed: {e}")

        return origin_forward(x, *args, **kwargs)

    # Install the closure as a plain instance attribute. Binding it with
    # types.MethodType would prepend ``m`` to every call, so ``x`` would be
    # the module itself: the shape check would never fire and the original
    # forward would receive the module as its input.
    m.forward = guarded_forward
    _mark_guarded(m)


def install_spconv_zero_shape_guard(model: nn.Module, verbose: bool = False) -> int:
    """Install the zero-shape guard on every spconv layer inside ``model``.

    Walks ``model.modules()`` and patches each spconv layer that does not
    carry the guard marker yet.

    Args:
        model: Root module to scan.
        verbose: Forwarded to the per-layer wrapper; also prints a summary.

    Returns:
        Number of layers patched by this call (0 when spconv is missing).
    """
    if not _HAS_SPCONV:
        warnings.warn("[spconv-zero-guard] spconv not found; guard is disabled.")
        return 0

    installed = 0
    for layer in model.modules():
        if not _is_spconv_layer(layer) or _already_guarded(layer):
            continue
        _wrap_spconv_forward(layer, verbose=verbose)
        installed += 1

    if verbose:
        print(f"[spconv-zero-guard] installed on {installed} layers.")
    return installed


# --------------------------------
# 极简量化占位(可选,不强侵入)
# --------------------------------

class IdentityQuant(nn.Module):
    """Minimal quantization placeholder.

    Performs no numerical transformation; it only keeps the call structure
    in place where a real quantizer could later be plugged in. It can be
    ignored entirely if a full quantization pipeline already exists.
    """

    def forward(self, x):
        # Pass-through: the very same object is returned.
        return x


def _name_hit(name: str, hints: Iterable[str]) -> bool:
    """Return True when at least one hint is a substring of ``name``.

    A falsy ``name`` (``None`` or empty) is treated as the empty string.
    """
    haystack = name or ""
    return any(hint in haystack for hint in hints)


def apply_quantization(
    model: nn.Module,
    w_bits: int = 2,
    a_bits: int = 8,
    quantize_first_last: bool = False,
    exclude_name_hints: Optional[Iterable[str]] = None,
    install_guard: bool = True,
    verbose: bool = False,
) -> nn.Module:
    """Attach placeholder quantization hooks to ``model`` in place.

    Weights and computation are left untouched. Each selected activation
    module (ReLU, GELU, SiLU, SELU, LeakyReLU) is replaced by
    ``nn.Sequential(activation, IdentityQuant())`` so a real quantizer can
    later be dropped in at the same spot. Optionally the spconv zero-shape
    guard is installed as well.

    Args:
        model: Model to modify; it is mutated and returned.
        w_bits: Weight bit width. Only reported in the verbose message.
        a_bits: Activation bit width. Only reported in the verbose message.
        quantize_first_last: When false, modules whose qualified name
            contains "stem", "head", "cls_head" or "embedding.stem" are
            skipped.
        exclude_name_hints: Substrings of qualified module names to skip.
            A single string is accepted and treated as one hint.
        install_guard: Install the spconv zero-shape guard first.
        verbose: Print a summary line (and forward to the guard installer).

    Returns:
        The same ``model`` object.
    """
    # A bare string is itself an iterable of characters; without this check
    # list("head") would become ['h', 'e', 'a', 'd'] and exclude almost
    # every module.
    if isinstance(exclude_name_hints, str):
        exclude_name_hints = [exclude_name_hints]
    exclude_name_hints = list(exclude_name_hints or [])

    # 1) (optional) spconv zero-shape guard
    if install_guard:
        install_spconv_zero_shape_guard(model, verbose=verbose)

    # 2) Lightweight placeholder quantization: a hook point only, numerics
    #    are unchanged. Swap IdentityQuant for a real module if desired.
    activation_types = (nn.ReLU, nn.GELU, nn.SiLU, nn.SELU, nn.LeakyReLU)
    # Rough naming heuristics for "first layer / last layer" (example only).
    first_last_hints = ["stem", "head", "cls_head", "embedding.stem"]

    # Collect first, replace afterwards, so the module tree is not mutated
    # while named_modules() is still iterating over it.
    targets = []
    for name, mod in model.named_modules():
        # Skip the root module itself.
        if name == "":
            continue
        if not isinstance(mod, activation_types):
            continue
        if not quantize_first_last and _name_hit(name, first_last_hints):
            continue
        if exclude_name_hints and _name_hit(name, exclude_name_hints):
            continue

        # Resolve the parent module and the leaf attribute name.
        *parent_path, leaf = name.split(".")
        parent = model
        for part in parent_path:
            parent = getattr(parent, part)

        # Idempotence: an activation that already sits inside a
        # Sequential(activation, IdentityQuant) wrapper must not be wrapped
        # again, otherwise repeated calls nest wrappers ever deeper.
        if (
            isinstance(parent, nn.Sequential)
            and len(parent) == 2
            and isinstance(parent[1], IdentityQuant)
        ):
            continue

        targets.append((parent, leaf, mod))

    # Shallow replacement via setattr on the parent module.
    for parent, leaf, activation in targets:
        setattr(parent, leaf, nn.Sequential(activation, IdentityQuant()))

    if verbose:
        print(f"[quantize] applied placeholder quant: w_bits={w_bits}, a_bits={a_bits}, "
              f"quantize_first_last={quantize_first_last}, exclude={exclude_name_hints}")
    return model


# Names exported by a star-import of this module.
__all__ = [
    "install_spconv_zero_shape_guard",
    "apply_quantization",
]