File size: 2,513 Bytes
3499c27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
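Normalize weight files of many shapes (a bare state_dict, a training checkpoint,
keys carrying a "module." prefix, a checkpoint holding ema_state_dict, etc.) into
the format expected by Pointcept's test.py: {"state_dict": <name -> Tensor>}

Additionally:
- If seg_head.* exists and its class count differs from your num_classes, it is
  dropped (common when transferring to S3DIS).
- The "module." prefix is removed when --strip-module is passed.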
"""

import os, argparse, torch

def pick_statedict(obj):
    """Extract the ``name -> tensor`` mapping from an object returned by ``torch.load``.

    Lookup order:
      1. Checkpoint style: the first of ``"state_dict"``, ``"model"``,
         ``"ema_state_dict"`` whose value is a dict is returned as-is.
      2. Bare state_dict: a *non-empty* dict whose values all expose a
         ``shape`` attribute is returned as-is.
      3. Model object: a non-dict object with a callable ``state_dict()``
         method has the result of that method returned.

    Args:
        obj: whatever ``torch.load`` produced.

    Returns:
        The selected mapping. It is not copied, so later in-place edits
        affect ``obj``.

    Raises:
        KeyError: if none of the cases above applies. This includes an
            empty dict.
    """
    if isinstance(obj, dict):
        # 1) checkpoint-style wrapper
        for key in ("state_dict", "model", "ema_state_dict"):
            inner = obj.get(key)
            if isinstance(inner, dict):
                return inner
        # 2) bare name->tensor dict. The emptiness check matters because
        #    all() on an empty iterable is True, which would otherwise let
        #    an empty dict through as a "valid" state_dict.
        if obj and all(hasattr(v, "shape") for v in obj.values()):
            return obj
    elif callable(getattr(obj, "state_dict", None)):
        # 3) a whole model object was pickled instead of a dict
        return obj.state_dict()
    raise KeyError("找不到可用的 state_dict(既无 'state_dict'/'model'/'ema_state_dict',也不是纯字典)")

def strip_module(sd):
    """Return a copy of ``sd`` with one leading ``"module."`` removed from each key.

    Keys that do not start with ``"module."`` are kept unchanged. Only a
    single leading occurrence is removed. This prefix is typically added by
    PyTorch's DataParallel / DistributedDataParallel wrappers.

    Args:
        sd: mapping of parameter name -> value.

    Returns:
        A new dict with the same values. ``sd`` itself is not modified.
    """
    prefix = "module."
    # Bug fix: the prefix must be stripped only from keys that actually carry
    # it. The previous code sliced 7 characters off every key unconditionally.
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd.items()
    }

def drop_seg_head_if_mismatch(sd, expected_num_classes=13, head_name="seg_head"):
    """Remove the classifier head from ``sd`` when its class count is wrong.

    The head's class count is read from dimension 0 of ``<head_name>.weight``.
    If that entry is missing, or has no ``shape``, dimension 0 of
    ``<head_name>.bias`` is used instead. When the count differs from
    ``expected_num_classes``, both ``<head_name>.weight`` and
    ``<head_name>.bias`` are removed. If no usable shape is found, ``sd`` is
    returned unchanged.

    Args:
        sd: state-dict mapping. It is modified in place.
        expected_num_classes: number of output classes the target model uses.
        head_name: key prefix of the classifier head.

    Returns:
        ``sd`` itself (the same object, possibly with head entries removed).
    """
    weight_key = f"{head_name}.weight"
    bias_key = f"{head_name}.bias"

    # Prefer the weight. Fall back to the bias so that a checkpoint carrying
    # only a mismatched bias is also caught.
    ref = sd.get(weight_key)
    if not hasattr(ref, "shape"):
        ref = sd.get(bias_key)
    # Nothing to compare against (missing, shapeless, or 0-dim).
    if not hasattr(ref, "shape") or len(ref.shape) == 0:
        return sd

    if ref.shape[0] != expected_num_classes:
        sd.pop(weight_key, None)
        sd.pop(bias_key, None)
    return sd

def main():
    """Command-line entry point: read ``--src``, normalize it, write ``--dst``.

    Steps: pick the state_dict out of whatever was loaded, optionally strip
    the ``"module."`` key prefix, drop ``seg_head`` if its class count differs
    from ``--num-classes``, then save ``{"state_dict": ..., "epoch": -1}``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", required=True, help="源权重路径(任意格式)")
    parser.add_argument("--dst", required=True, help="输出checkpoint路径(含 state_dict 键)")
    parser.add_argument("--num-classes", type=int, default=13, help="期望的类别数(S3DIS=13)")
    parser.add_argument("--strip-module", action="store_true", help="去除 'module.' 前缀")
    opts = parser.parse_args()

    # NOTE(review): torch.load unpickles the file, so only use trusted inputs.
    # Newer torch releases may default to weights_only=True, which can reject
    # checkpoints that store non-tensor objects. Confirm against the installed
    # version.
    loaded = torch.load(opts.src, map_location="cpu")
    state = pick_statedict(loaded)
    state = strip_module(state) if opts.strip_module else state
    state = drop_seg_head_if_mismatch(state, expected_num_classes=opts.num_classes)

    # The output always has a top-level 'state_dict'. An empty dirname means
    # the current directory.
    out_dir = os.path.dirname(opts.dst)
    os.makedirs(out_dir if out_dir else ".", exist_ok=True)
    torch.save({"state_dict": state, "epoch": -1}, opts.dst)
    print("✅ Wrote unified checkpoint to:", os.path.abspath(opts.dst))

    # Short echo of the classifier head's fate (None if it was dropped/absent).
    head = state.get("seg_head.weight")
    head_shape = tuple(head.shape) if head is not None else None
    print("seg_head.weight:", head_shape)


if __name__ == "__main__":
    main()