| |
| |
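"""
Convert weights in various layouts (a bare state_dict, a full checkpoint,
keys carrying a "module." prefix, a dict holding "ema_state_dict", etc.)
into the format Pointcept's test.py expects: {"state_dict": <name -> Tensor>}.

Additionally:
- If seg_head.weight / seg_head.bias exist and their class count differs from
  --num-classes, they are dropped (common when transferring to S3DIS).
- The "module." key prefix is removed when --strip-module is passed.
"""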
|
|
| import os, argparse, torch |
|
|
def pick_statedict(obj):
    """Extract the parameter-name -> tensor mapping from a loaded checkpoint.

    Accepted inputs, checked in this order:
    - a dict with one of "state_dict", "model" or "ema_state_dict" (first
      match wins) whose value is itself a dict -> that inner dict;
    - a bare state_dict, i.e. a non-empty dict whose values all have a
      ``shape`` attribute -> the dict itself;
    - a non-dict object with a callable ``state_dict()`` method (for example
      a whole pickled model) -> the result of calling it.

    Raises:
        KeyError: if none of the above applies.
    """
    if isinstance(obj, dict):
        for key in ("state_dict", "model", "ema_state_dict"):
            candidate = obj.get(key)
            if isinstance(candidate, dict):
                return candidate
        # Require a non-empty dict: all() over an empty sequence is True,
        # which would otherwise accept {} as a "valid" state_dict.
        if obj and all(hasattr(v, "shape") for v in obj.values()):
            return obj
    elif callable(getattr(obj, "state_dict", None)):
        return obj.state_dict()
    raise KeyError("找不到可用的 state_dict(既无 'state_dict'/'model'/'ema_state_dict',也不是纯字典)")
|
|
def strip_module(sd):
    """Return a new dict with a single leading "module." removed from each key.

    Keys without that prefix are kept unchanged, so the function is safe to
    apply to state_dicts that only partly (or never) carry the prefix. This
    is the prefix typically added by PyTorch's (Distributed)DataParallel
    wrappers.

    Args:
        sd: mapping of parameter name -> value.

    Returns:
        A new dict; the input is not modified.
    """
    prefix = "module."
    cut = len(prefix)
    # Strip only when the key really starts with the prefix. The old
    # `out[k[7:]] = v if ... else v` applied the conditional to the value,
    # which truncated every key unconditionally.
    return {(key[cut:] if key.startswith(prefix) else key): value for key, value in sd.items()}
|
|
def drop_seg_head_if_mismatch(sd, expected_num_classes=13, head_name="seg_head"):
    """Remove the head's weight/bias when its output size differs from the target.

    Looks at ``<head_name>.weight`` and also ``module.<head_name>.weight``, so
    the check still works when the "module." prefix has not been stripped.
    If the weight's first dimension differs from ``expected_num_classes``,
    the weight and the bias with the same prefix are removed.

    Args:
        sd: parameter-name -> tensor mapping; modified in place.
        expected_num_classes: desired size of the weight's first dimension.
        head_name: key prefix of the head parameters.

    Returns:
        The same ``sd`` object, for call chaining.
    """
    for prefix in ("", "module."):
        weight_key = f"{prefix}{head_name}.weight"
        weight = sd.get(weight_key)
        # Skip missing entries, non-tensors, and 0-dim values that have no
        # first dimension to compare.
        if weight is None or not hasattr(weight, "shape") or len(weight.shape) == 0:
            continue
        if weight.shape[0] != expected_num_classes:
            sd.pop(weight_key, None)
            sd.pop(f"{prefix}{head_name}.bias", None)
    return sd
|
|
def main():
    """CLI entry point: read --src, normalize it, and write --dst.

    Pipeline: load the source file onto the CPU, extract the parameter
    mapping, optionally strip the "module." prefix, drop a seg_head whose
    class count does not match --num-classes, and save
    ``{"state_dict": ..., "epoch": -1}``. Finally, print the output path and
    the remaining seg_head.weight shape (or None).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", required=True, help="源权重路径(任意格式)")
    parser.add_argument("--dst", required=True, help="输出checkpoint路径(含 state_dict 键)")
    parser.add_argument("--num-classes", type=int, default=13, help="期望的类别数(S3DIS=13)")
    parser.add_argument("--strip-module", action="store_true", help="去除 'module.' 前缀")
    opts = parser.parse_args()

    # NOTE(review): torch.load unpickles the file; only run this on trusted checkpoints.
    loaded = torch.load(opts.src, map_location="cpu")
    state = pick_statedict(loaded)
    if opts.strip_module:
        state = strip_module(state)
    state = drop_seg_head_if_mismatch(state, expected_num_classes=opts.num_classes)

    target_dir = os.path.dirname(opts.dst) or "."
    os.makedirs(target_dir, exist_ok=True)
    payload = {"state_dict": state, "epoch": -1}
    torch.save(payload, opts.dst)
    print("✅ Wrote unified checkpoint to:", os.path.abspath(opts.dst))

    head_weight = state.get("seg_head.weight", None)
    if head_weight is None:
        head_shape = None
    else:
        head_shape = tuple(head_weight.shape)
    print("seg_head.weight:", head_shape)
|
|
if __name__ == "__main__":
    # Run the command-line conversion only when executed as a script, not on import.
    main()
|
|