File size: 2,326 Bytes
3499c27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
将 20 类语义头的权重预处理为 S3DIS 可用的初始化权重:
- 丢弃 seg_head.*(分类头),只保留 backbone 等共享部分
- 可选:安全去掉可能的 "module." 前缀
- 不改动其它 shape 匹配的权重

用法:
  python scripts/make_s3dis_init_ckpt.py \
    --src exp/default/model/model_last.pth \
    --dst exp/default/model/model_last_nohead.pth \
    --strip-module
"""
import argparse, os, sys, torch

def maybe_get_statedict(obj):
    if isinstance(obj, dict):
        for k in ("state_dict", "model", "ema_state_dict"):
            if k in obj and isinstance(obj[k], dict):
                return obj[k]
    return obj

def strip_module_prefix(sd):
    # 把 "module." 前缀去掉,避免 DDP 保存的前缀影响加载
    out = {}
    for k, v in sd.items():
        if k.startswith("module."):
            out[k[len("module."):]] = v
        else:
            out[k] = v
    return out

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--src", required=True, help="源权重路径(20类头)")
    ap.add_argument("--dst", required=True, help="输出路径(去掉 seg_head 后)")
    ap.add_argument("--strip-module", action="store_true", help="去除 'module.' 前缀")
    args = ap.parse_args()

    # 新版 PyTorch 建议 weights_only=True;但为兼容历史文件,先保持默认
    ckpt = torch.load(args.src, map_location="cpu")
    sd = maybe_get_statedict(ckpt)
    if not isinstance(sd, dict):
        print("❌ 未能解析到 state_dict;请检查权重文件。", file=sys.stderr); sys.exit(1)

    if args.strip_module:
        sd = strip_module_prefix(sd)

    kept, dropped = {}, []
    for k, v in sd.items():
        if k.startswith("seg_head."):
            dropped.append((k, tuple(v.shape)))
        else:
            kept[k] = v

    os.makedirs(os.path.dirname(args.dst) or ".", exist_ok=True)
    torch.save(kept, args.dst)

    print(f"✅ 完成:保留 {len(kept)} 个张量;丢弃 {len(dropped)} 个 seg_head.*")
    if dropped:
        print("丢弃清单(名称, 形状)前若干:")
        for name, shp in dropped[:8]:
            print("  -", name, shp)
    print("输出文件:", os.path.abspath(args.dst))

if __name__ == "__main__":
    main()