#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
把各种形状的权重(纯 state_dict / checkpoint / 带 "module." 前缀 / 含 ema_state_dict 等)
统一转换成 Pointcept test.py 期望的格式:{"state_dict": <name->Tensor>}
顺带:
- 如果存在 seg_head.* 且其输出类别数与 --num-classes 不一致,直接丢弃(迁移到 S3DIS 时常见)
- 指定 --strip-module 时去掉参数名中的 "module." 前缀
"""
import os, argparse, torch
def pick_statedict(obj):
    """Locate the parameter mapping (name -> tensor) inside a loaded object.

    Lookup order:
      1. checkpoint-style dict: the first non-empty dict stored under
         "state_dict", "model" or "ema_state_dict" (checked in that order);
      2. a dict that is itself a non-empty name -> tensor mapping, i.e. every
         value has a ``shape`` attribute;
      3. any other object exposing a callable ``state_dict()`` (e.g. a whole
         model saved with ``torch.save(model)``); its result is returned.

    Raises:
        KeyError: if none of the above applies (including non-dict inputs
            without ``state_dict()`` and empty dicts).
    """
    if isinstance(obj, dict):
        # 1) checkpoint style
        for key in ("state_dict", "model", "ema_state_dict"):
            candidate = obj.get(key)
            # Skip empty wrappers: returning {} would silently produce an
            # empty output checkpoint.
            if isinstance(candidate, dict) and candidate:
                return candidate
        # 2) plain name->tensor dict; the emptiness check matters because
        #    all() over no values is True.
        if obj and all(hasattr(v, "shape") for v in obj.values()):
            return obj
    elif callable(getattr(obj, "state_dict", None)):
        # 3) a whole model object rather than a dict
        return obj.state_dict()
    raise KeyError("找不到可用的 state_dict(既无 'state_dict'/'model'/'ema_state_dict',也不是纯字典)")
def strip_module(sd):
    """Return a copy of *sd* with one leading ``"module."`` removed from each key.

    Keys without that prefix are kept unchanged; values are not copied.
    Insertion order is preserved. If both ``"module.x"`` and ``"x"`` exist,
    the one that comes later in *sd* wins.
    """
    prefix = "module."
    # The conditional must choose the KEY: the old form only ever sliced the
    # key, so un-prefixed names lost their first 7 characters.
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in sd.items()
    }
def drop_seg_head_if_mismatch(sd, expected_num_classes=13, head_name="seg_head"):
    """Leave out the classification head's parameters when its class count differs.

    The class count is read from ``shape[0]`` of ``<head_name>.weight``, or of
    ``<head_name>.bias`` when the weight entry is absent. If it differs from
    ``expected_num_classes``, both entries are omitted from the result.

    Args:
        sd: mapping of parameter name -> tensor-like value (has ``shape``).
        expected_num_classes: number of output classes the target model uses.
        head_name: parameter-name prefix of the head (default "seg_head").

    Returns:
        ``sd`` itself when nothing is dropped; otherwise a new dict without the
        head entries. The input mapping is never modified.
    """
    weight_key = f"{head_name}.weight"
    bias_key = f"{head_name}.bias"
    ref = sd.get(weight_key)
    if ref is None:
        # Fall back to the bias; for a Linear/Conv-style head its first
        # dimension is also the output size (assumed head type -- confirm).
        ref = sd.get(bias_key)
    shape = getattr(ref, "shape", None)
    # No head entry, no shape, or a 0-dim value: nothing to compare.
    if not shape or shape[0] == expected_num_classes:
        return sd
    return {k: v for k, v in sd.items() if k not in (weight_key, bias_key)}
def main():
    """Command-line entry point: load --src, normalize it, write --dst.

    The output file is ``{"state_dict": <name -> tensor>, "epoch": -1}``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--src", required=True, help="源权重路径(任意格式)")
    ap.add_argument("--dst", required=True, help="输出checkpoint路径(含 state_dict 键)")
    ap.add_argument("--num-classes", type=int, default=13, help="期望的类别数(S3DIS=13)")
    ap.add_argument("--strip-module", action="store_true", help="去除 'module.' 前缀")
    ap.add_argument(
        "--head-name",
        default="seg_head",
        help="parameter-name prefix of the classification head (default: seg_head)",
    )
    args = ap.parse_args()

    # NOTE(review): torch.load is pickle-based and can execute code embedded in
    # the file unless weights_only=True is in effect; only load trusted checkpoints.
    ckpt = torch.load(args.src, map_location="cpu")
    sd = pick_statedict(ckpt)
    if args.strip_module:
        sd = strip_module(sd)
    elif sd and all(k.startswith("module.") for k in sd):
        # Prefixed keys would also hide the head from the mismatch check below.
        print("⚠️  all keys start with 'module.'; consider --strip-module")

    keys_before = set(sd)
    sd = drop_seg_head_if_mismatch(
        sd, expected_num_classes=args.num_classes, head_name=args.head_name
    )
    dropped = sorted(keys_before - set(sd))
    if dropped:
        print(f"Dropped {', '.join(dropped)} (class count != {args.num_classes})")

    # Final layout: top level contains 'state_dict'.
    os.makedirs(os.path.dirname(args.dst) or ".", exist_ok=True)
    torch.save({"state_dict": sd, "epoch": -1}, args.dst)
    print("✅ Wrote unified checkpoint to:", os.path.abspath(args.dst))
    # Brief echo of the head shape that ended up in the output.
    head_key = f"{args.head_name}.weight"
    w = sd.get(head_key, None)
    print(f"{head_key}:", None if w is None else tuple(w.shape))


if __name__ == "__main__":
    main()
|