| |
| |
| """ |
| 将 20 类语义头的权重预处理为 S3DIS 可用的初始化权重: |
| - 丢弃 seg_head.*(分类头),只保留 backbone 等共享部分 |
| - 可选:安全去掉可能的 "module." 前缀 |
| - 不改动其它 shape 匹配的权重 |
| |
| 用法: |
| python scripts/make_s3dis_init_ckpt.py \ |
| --src exp/default/model/model_last.pth \ |
| --dst exp/default/model/model_last_nohead.pth \ |
| --strip-module |
| """ |
| import argparse, os, sys, torch |
|
|
| def maybe_get_statedict(obj): |
| if isinstance(obj, dict): |
| for k in ("state_dict", "model", "ema_state_dict"): |
| if k in obj and isinstance(obj[k], dict): |
| return obj[k] |
| return obj |
|
|
| def strip_module_prefix(sd): |
| |
| out = {} |
| for k, v in sd.items(): |
| if k.startswith("module."): |
| out[k[len("module."):]] = v |
| else: |
| out[k] = v |
| return out |
|
|
| def main(): |
| ap = argparse.ArgumentParser() |
| ap.add_argument("--src", required=True, help="源权重路径(20类头)") |
| ap.add_argument("--dst", required=True, help="输出路径(去掉 seg_head 后)") |
| ap.add_argument("--strip-module", action="store_true", help="去除 'module.' 前缀") |
| args = ap.parse_args() |
|
|
| |
| ckpt = torch.load(args.src, map_location="cpu") |
| sd = maybe_get_statedict(ckpt) |
| if not isinstance(sd, dict): |
| print("❌ 未能解析到 state_dict;请检查权重文件。", file=sys.stderr); sys.exit(1) |
|
|
| if args.strip_module: |
| sd = strip_module_prefix(sd) |
|
|
| kept, dropped = {}, [] |
| for k, v in sd.items(): |
| if k.startswith("seg_head."): |
| dropped.append((k, tuple(v.shape))) |
| else: |
| kept[k] = v |
|
|
| os.makedirs(os.path.dirname(args.dst) or ".", exist_ok=True) |
| torch.save(kept, args.dst) |
|
|
| print(f"✅ 完成:保留 {len(kept)} 个张量;丢弃 {len(dropped)} 个 seg_head.*") |
| if dropped: |
| print("丢弃清单(名称, 形状)前若干:") |
| for name, shp in dropped[:8]: |
| print(" -", name, shp) |
| print("输出文件:", os.path.abspath(args.dst)) |
|
|
| if __name__ == "__main__": |
| main() |
|
|