#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 将 20 类语义头的权重预处理为 S3DIS 可用的初始化权重: - 丢弃 seg_head.*(分类头),只保留 backbone 等共享部分 - 可选:安全去掉可能的 "module." 前缀 - 不改动其它 shape 匹配的权重 用法: python scripts/make_s3dis_init_ckpt.py \ --src exp/default/model/model_last.pth \ --dst exp/default/model/model_last_nohead.pth \ --strip-module """ import argparse, os, sys, torch def maybe_get_statedict(obj): if isinstance(obj, dict): for k in ("state_dict", "model", "ema_state_dict"): if k in obj and isinstance(obj[k], dict): return obj[k] return obj def strip_module_prefix(sd): # 把 "module." 前缀去掉,避免 DDP 保存的前缀影响加载 out = {} for k, v in sd.items(): if k.startswith("module."): out[k[len("module."):]] = v else: out[k] = v return out def main(): ap = argparse.ArgumentParser() ap.add_argument("--src", required=True, help="源权重路径(20类头)") ap.add_argument("--dst", required=True, help="输出路径(去掉 seg_head 后)") ap.add_argument("--strip-module", action="store_true", help="去除 'module.' 前缀") args = ap.parse_args() # 新版 PyTorch 建议 weights_only=True;但为兼容历史文件,先保持默认 ckpt = torch.load(args.src, map_location="cpu") sd = maybe_get_statedict(ckpt) if not isinstance(sd, dict): print("❌ 未能解析到 state_dict;请检查权重文件。", file=sys.stderr); sys.exit(1) if args.strip_module: sd = strip_module_prefix(sd) kept, dropped = {}, [] for k, v in sd.items(): if k.startswith("seg_head."): dropped.append((k, tuple(v.shape))) else: kept[k] = v os.makedirs(os.path.dirname(args.dst) or ".", exist_ok=True) torch.save(kept, args.dst) print(f"✅ 完成:保留 {len(kept)} 个张量;丢弃 {len(dropped)} 个 seg_head.*") if dropped: print("丢弃清单(名称, 形状)前若干:") for name, shp in dropped[:8]: print(" -", name, shp) print("输出文件:", os.path.abspath(args.dst)) if __name__ == "__main__": main()