File size: 2,326 Bytes
3499c27 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | #!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
将 20 类语义头的权重预处理为 S3DIS 可用的初始化权重:
- 丢弃 seg_head.*(分类头),只保留 backbone 等共享部分
- 可选:安全去掉可能的 "module." 前缀
- 不改动其它 shape 匹配的权重
用法:
python scripts/make_s3dis_init_ckpt.py \
--src exp/default/model/model_last.pth \
--dst exp/default/model/model_last_nohead.pth \
--strip-module
"""
import argparse, os, sys, torch
def maybe_get_statedict(obj):
if isinstance(obj, dict):
for k in ("state_dict", "model", "ema_state_dict"):
if k in obj and isinstance(obj[k], dict):
return obj[k]
return obj
def strip_module_prefix(sd):
# 把 "module." 前缀去掉,避免 DDP 保存的前缀影响加载
out = {}
for k, v in sd.items():
if k.startswith("module."):
out[k[len("module."):]] = v
else:
out[k] = v
return out
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--src", required=True, help="源权重路径(20类头)")
ap.add_argument("--dst", required=True, help="输出路径(去掉 seg_head 后)")
ap.add_argument("--strip-module", action="store_true", help="去除 'module.' 前缀")
args = ap.parse_args()
# 新版 PyTorch 建议 weights_only=True;但为兼容历史文件,先保持默认
ckpt = torch.load(args.src, map_location="cpu")
sd = maybe_get_statedict(ckpt)
if not isinstance(sd, dict):
print("❌ 未能解析到 state_dict;请检查权重文件。", file=sys.stderr); sys.exit(1)
if args.strip_module:
sd = strip_module_prefix(sd)
kept, dropped = {}, []
for k, v in sd.items():
if k.startswith("seg_head."):
dropped.append((k, tuple(v.shape)))
else:
kept[k] = v
os.makedirs(os.path.dirname(args.dst) or ".", exist_ok=True)
torch.save(kept, args.dst)
print(f"✅ 完成:保留 {len(kept)} 个张量;丢弃 {len(dropped)} 个 seg_head.*")
if dropped:
print("丢弃清单(名称, 形状)前若干:")
for name, shp in dropped[:8]:
print(" -", name, shp)
print("输出文件:", os.path.abspath(args.dst))
if __name__ == "__main__":
main()
|