#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Normalize weight files of various shapes into the format Pointcept's test.py expects.

Accepted inputs: a bare state_dict, a training checkpoint (with a "state_dict",
"model" or "ema_state_dict" entry), or a state_dict whose keys carry a
"module." prefix. The output is always written as
``{"state_dict": {name: Tensor, ...}, "epoch": -1}``.

Additionally:
- If ``seg_head.weight`` exists and its first dimension differs from
  ``--num-classes``, ``seg_head.weight`` and ``seg_head.bias`` are dropped
  (common when transferring to S3DIS).
- With ``--strip-module``, the leading "module." prefix is removed from keys.
"""
import argparse
import os

import torch

_MODULE_PREFIX = "module."


def pick_statedict(obj):
    """Return the name -> tensor mapping contained in a loaded checkpoint object.

    Lookup order:
      1. the first of ``obj["state_dict"]``, ``obj["model"]``,
         ``obj["ema_state_dict"]`` that exists and is a dict;
      2. ``obj`` itself, when it is a non-empty dict whose values all expose
         a ``.shape`` attribute (i.e. it already is a bare state_dict).

    Raises:
        KeyError: if ``obj`` is not a dict, or none of the above applies.
    """
    if isinstance(obj, dict):
        # 1) checkpoint-style container
        for key in ("state_dict", "model", "ema_state_dict"):
            candidate = obj.get(key)
            if isinstance(candidate, dict):
                return candidate
        # 2) bare name -> tensor dict. An empty dict would vacuously satisfy
        #    all(), which would silently produce an empty checkpoint.
        if obj and all(hasattr(v, "shape") for v in obj.values()):
            return obj
    raise KeyError("找不到可用的 state_dict(既无 'state_dict'/'model'/'ema_state_dict',也不是纯字典)")


def strip_module(sd):
    """Return a new dict with a leading "module." removed from each key.

    Keys without the prefix are copied unchanged. If both "module.x" and "x"
    are present, the entry iterated later wins.
    """
    cut = len(_MODULE_PREFIX)
    return {
        (k[cut:] if k.startswith(_MODULE_PREFIX) else k): v
        for k, v in sd.items()
    }


def drop_seg_head_if_mismatch(sd, expected_num_classes=13, head_name="seg_head"):
    """Remove ``<head_name>.weight`` / ``<head_name>.bias`` when the class count differs.

    The class count is read from ``<head_name>.weight.shape[0]``. If the weight
    is absent or has no ``.shape``, *sd* is left untouched. Note: keys that
    still carry a "module." prefix are not matched.

    Args:
        sd: state_dict to inspect; modified in place.
        expected_num_classes: number of output classes the target model uses.
        head_name: key prefix of the segmentation head.

    Returns:
        The same *sd* object.
    """
    weight_key = f"{head_name}.weight"
    bias_key = f"{head_name}.bias"
    w = sd.get(weight_key, None)
    if w is not None and hasattr(w, "shape") and w.shape[0] != expected_num_classes:
        sd.pop(weight_key, None)
        sd.pop(bias_key, None)
    return sd


def main():
    """CLI entry point: load --src, normalize it, and write --dst."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--src", required=True, help="源权重路径(任意格式)")
    ap.add_argument("--dst", required=True, help="输出checkpoint路径(含 state_dict 键)")
    ap.add_argument("--num-classes", type=int, default=13, help="期望的类别数(S3DIS=13)")
    ap.add_argument("--strip-module", action="store_true", help="去除 'module.' 前缀")
    args = ap.parse_args()

    # NOTE(security): torch.load unpickles the file (unless weights_only=True);
    # only run this on checkpoints from trusted sources.
    ckpt = torch.load(args.src, map_location="cpu")
    sd = pick_statedict(ckpt)
    if args.strip_module:
        sd = strip_module(sd)
    sd = drop_seg_head_if_mismatch(sd, expected_num_classes=args.num_classes)

    # Always save with a top-level 'state_dict' key, the layout test.py loads.
    os.makedirs(os.path.dirname(args.dst) or ".", exist_ok=True)
    torch.save({"state_dict": sd, "epoch": -1}, args.dst)
    print("✅ Wrote unified checkpoint to:", os.path.abspath(args.dst))

    # Brief echo so the user can confirm whether the head was kept.
    w = sd.get("seg_head.weight", None)
    print("seg_head.weight:", None if w is None else tuple(w.shape))


if __name__ == "__main__":
    main()