File size: 5,595 Bytes
97aa5af | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | """Load YAML training config and merge with argparse."""
from __future__ import annotations
import argparse
import os
from pathlib import Path
from typing import Any, Mapping
import yaml
from r3pm_net.paths import REPO_ROOT
def _resolve_maybe_relative(path_str: str | None) -> str | None:
if path_str is None or path_str == "":
return path_str
p = Path(path_str)
if p.is_absolute():
return str(p)
return str(REPO_ROOT / p)
def load_yaml_config(path: str | Path) -> dict[str, Any]:
    """Read a UTF-8 YAML file whose top level is a mapping; return it as a dict.

    The file is parsed with ``yaml.safe_load``. An empty document (parsed as
    ``None``) yields ``{}``.

    Raises:
        ValueError: if the top-level YAML node is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    if parsed is None:
        return {}
    if isinstance(parsed, Mapping):
        return dict(parsed)
    raise ValueError(f"Config must be a mapping, got {type(parsed)}")
def _extract_config_argv(argv: list[str], default_cfg: str) -> tuple[str, list[str]]:
"""Return (config path for YAML, argv without --config ...)."""
path = default_cfg
out: list[str] = []
i = 0
while i < len(argv):
if argv[i] == "--config" and i + 1 < len(argv):
path = argv[i + 1]
i += 2
continue
if argv[i].startswith("--config="):
path = argv[i].split("=", 1)[1]
i += 1
continue
out.append(argv[i])
i += 1
return path, out
def parse_train_args(argv: list[str], build_parser) -> argparse.Namespace:
    """Parse training arguments, using a YAML config file for the defaults.

    ``--config PATH`` / ``--config=PATH`` is removed from ``argv``. The config
    defaults to ``REPO_ROOT/config/default.yaml``. If the file exists it is
    loaded, and every top-level key that matches an argparse destination
    becomes that option's default. Flags given on the command line therefore
    still win. YAML keys with no matching option are ignored.

    Args:
        argv: Command-line tokens, without the program name.
        build_parser: Callable that receives the chosen config path and returns
            an ``argparse.ArgumentParser``.

    Returns:
        The parsed namespace.

    Raises:
        SystemExit: via ``parser.error`` if an explicitly given ``--config``
            file does not exist (a missing default config is silently
            allowed), or if argparse rejects the remaining arguments.
        ValueError: if the YAML file's top level is not a mapping.
    """
    default_cfg = str(REPO_ROOT / "config" / "default.yaml")
    cfg_path, argv_rest = _extract_config_argv(list(argv), default_cfg)
    cfg_found = Path(cfg_path).is_file()
    cfg = load_yaml_config(cfg_path) if cfg_found else {}
    parser = build_parser(cfg_path)
    if not cfg_found and cfg_path != default_cfg:
        # An explicitly requested config that does not exist is almost always a
        # typo. Falling back to built-in defaults would silently train with the
        # wrong settings, so report it the way argparse reports bad arguments.
        parser.error(f"config file not found: {cfg_path}")
    if cfg:
        # set_defaults() would otherwise add arbitrary YAML keys as namespace
        # attributes, so keep only keys that name a real argparse destination.
        known = {
            action.dest
            for action in parser._actions
            if getattr(action, "dest", None) and action.dest not in ("help", argparse.SUPPRESS)
        }
        parser.set_defaults(**{k: v for k, v in cfg.items() if k in known})
    return parser.parse_args(argv_rest)
def resolve_path_args(ns: Any, path_keys: tuple[str, ...]) -> None:
    """Rewrite selected attributes of ``ns`` in place as absolute path strings.

    Each attribute named in ``path_keys`` that holds a non-empty string is
    passed through ``_resolve_maybe_relative``, which joins relative paths onto
    ``REPO_ROOT``. Missing attributes, empty strings and non-string values are
    left untouched.
    """
    for name in path_keys:
        current = getattr(ns, name, None)
        if not isinstance(current, str) or not current:
            continue
        setattr(ns, name, _resolve_maybe_relative(current))
def load_eval_yaml() -> dict[str, Any]:
    """Return the contents of ``REPO_ROOT/config/eval.yaml``, or ``{}`` if it is absent."""
    eval_cfg = REPO_ROOT / "config" / "eval.yaml"
    return load_yaml_config(eval_cfg) if eval_cfg.is_file() else {}
def get_pretrained_rpmnet_dir() -> str:
    """Return the directory containing ``clean-trained.pth``, ``best_model_PointNet*.t7``, etc.

    Lookup order:

    1. The ``R3PM_NET_PRETRAINED_ROOT`` environment variable, if set and
       non-empty. ``~`` is expanded and the path is resolved, so relative values
       are taken against the current working directory.
    2. ``pretrained_rpmnet_dir`` in ``config/eval.yaml``, stripped of
       whitespace and joined onto ``REPO_ROOT`` when relative.
    3. ``REPO_ROOT / "checkpoints"``.
    """
    env = os.environ.get("R3PM_NET_PRETRAINED_ROOT")
    if env:
        return str(Path(env).expanduser().resolve())
    configured = load_eval_yaml().get("pretrained_rpmnet_dir")
    # Coerce to str before stripping (as get_sioux_data_root does) so a
    # non-string YAML scalar, e.g. a bare number, does not raise AttributeError.
    rel = str(configured).strip() if configured else ""
    out = _resolve_maybe_relative(rel or "checkpoints")
    return out if out else str(REPO_ROOT / "checkpoints")
def get_sioux_data_root() -> str:
    """Return the absolute base data directory used by the Sioux scripts.

    Candidates, in order of precedence (the first truthy one wins):
    ``sioux.base_dir`` in ``config/eval.yaml``, the top-level ``data_root``
    key, then ``"data"``. The chosen value is converted to ``str``, stripped,
    and joined onto ``REPO_ROOT`` when relative. If it is blank after
    stripping, ``REPO_ROOT / "data"`` is returned instead.
    """
    cfg = load_eval_yaml()
    sioux_section = cfg.get("sioux") or {}
    chosen = next(
        candidate
        for candidate in (sioux_section.get("base_dir"), cfg.get("data_root"), "data")
        if candidate
    )
    resolved = _resolve_maybe_relative(str(chosen).strip())
    if not resolved:
        return str(REPO_ROOT / "data")
    return resolved
def get_modelnet40_paths() -> tuple[str, str]:
    """Return ``(dataset_path, cache_dir)`` for ModelNet40 evaluation.

    Both values come from the ``modelnet40`` section of ``config/eval.yaml``.
    Missing keys default to ``data/ModelNet40`` and
    ``data/down_sampled_modelnet40``. Relative paths are joined onto
    ``REPO_ROOT``. A key explicitly set to null or ``""`` falls back to the
    corresponding default under ``REPO_ROOT``. Values are not whitespace-stripped.
    """
    mn40 = load_eval_yaml().get("modelnet40") or {}
    dataset_dir = _resolve_maybe_relative(mn40.get("dataset_path", "data/ModelNet40"))
    cache_dir = _resolve_maybe_relative(mn40.get("cache_dir", "data/down_sampled_modelnet40"))
    return (
        dataset_dir or str(REPO_ROOT / "data" / "ModelNet40"),
        cache_dir or str(REPO_ROOT / "data" / "down_sampled_modelnet40"),
    )
def get_method_paths() -> dict[str, Any]:
    """Return per-method settings from ``config/eval.yaml`` with paths resolved.

    Reads the ``methods`` section. Entries whose value is not a mapping are
    skipped, and method names are converted to ``str``. Within each method,
    every non-blank string value is stripped and joined onto ``REPO_ROOT``
    when relative. All other values are copied unchanged.
    """
    def _fix(value: Any) -> Any:
        if isinstance(value, str) and value.strip():
            return _resolve_maybe_relative(value.strip()) or value
        return value

    methods = load_eval_yaml().get("methods") or {}
    return {
        str(name): {key: _fix(val) for key, val in settings.items()}
        for name, settings in methods.items()
        if isinstance(settings, Mapping)
    }
def get_sioux_paths() -> dict[str, Any]:
    """Return the ``sioux`` section of ``config/eval.yaml`` with paths resolved.

    Every non-blank string value is stripped and joined onto ``REPO_ROOT``
    when relative. List values get the same treatment element-wise, one level
    deep. Everything else, including blank strings, is returned unchanged.
    """
    def _anchor(value: Any) -> Any:
        if isinstance(value, str) and value.strip():
            return _resolve_maybe_relative(value.strip()) or value
        return value

    sioux_cfg = load_eval_yaml().get("sioux") or {}
    resolved: dict[str, Any] = {}
    for key, value in sioux_cfg.items():
        if isinstance(value, list):
            resolved[key] = [_anchor(item) for item in value]
        else:
            resolved[key] = _anchor(value)
    return resolved
|