| """Load YAML training config and merge with argparse.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| from pathlib import Path |
| from typing import Any, Mapping |
|
|
| import yaml |
|
|
| from r3pm_net.paths import REPO_ROOT |
|
|
|
|
def _resolve_maybe_relative(path_str: str | None) -> str | None:
    """Anchor a relative path string at ``REPO_ROOT``.

    ``None`` and the empty string are handed back untouched. An absolute path
    is returned in its ``str(Path(...))`` form; any other path is joined onto
    ``REPO_ROOT`` and returned as a string.
    """
    if path_str is not None and path_str != "":
        candidate = Path(path_str)
        anchored = candidate if candidate.is_absolute() else REPO_ROOT / candidate
        return str(anchored)
    return path_str
|
|
|
|
def load_yaml_config(path: str | Path) -> dict[str, Any]:
    """Read a YAML file and return its top-level mapping as a plain ``dict``.

    The file is opened as UTF-8 text and parsed with ``yaml.safe_load``. A
    document that parses to ``None`` yields an empty dict.

    Raises:
        ValueError: If the parsed document is neither ``None`` nor a mapping.
    """
    with open(path, "r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    if parsed is None:
        return {}
    if isinstance(parsed, Mapping):
        return dict(parsed)
    raise ValueError(f"Config must be a mapping, got {type(parsed)}")
|
|
|
|
def _extract_config_argv(argv: list[str], default_cfg: str) -> tuple[str, list[str]]:
    """Split ``--config`` out of ``argv``.

    Both ``--config PATH`` and ``--config=PATH`` are recognised; when the
    option occurs several times the last occurrence wins. A trailing
    ``--config`` with no value after it is not consumed and stays in the
    returned token list.

    Returns:
        ``(config_path, remaining_argv)`` where ``config_path`` is
        ``default_cfg`` if the option never appeared.
    """
    prefix = "--config="
    cfg_path = default_cfg
    remaining: list[str] = []
    last_index = len(argv) - 1
    tokens = iter(enumerate(argv))
    for idx, token in tokens:
        if token == "--config" and idx < last_index:
            cfg_path = argv[idx + 1]
            # Drop the value token as well; it exists because idx < last_index.
            next(tokens)
        elif token.startswith(prefix):
            cfg_path = token[len(prefix):]
        else:
            remaining.append(token)
    return cfg_path, remaining
|
|
|
|
def parse_train_args(argv: list[str], build_parser) -> argparse.Namespace:
    """Parse training CLI arguments, using a YAML file as the source of defaults.

    ``--config`` is stripped from ``argv`` first (default:
    ``config/default.yaml`` under ``REPO_ROOT``). If that path is a file, its
    top-level keys that match a ``dest`` registered on the parser are
    installed via ``parser.set_defaults``, so values given explicitly on the
    command line still win. Keys that match no parser ``dest`` are ignored.

    Args:
        argv: Command-line tokens. They are forwarded to ``parse_args`` as
            given (minus ``--config``), so they should not include the
            program name.
        build_parser: Callable invoked as ``build_parser(cfg_path)``; its
            result is used as an ``argparse.ArgumentParser``.

    Returns:
        The namespace produced by ``parser.parse_args`` on the remaining
        tokens.

    Warns:
        UserWarning: If ``--config`` was passed explicitly but does not point
            to an existing file. A missing *default* config stays silent.
    """
    import warnings

    default_cfg = str(REPO_ROOT / "config" / "default.yaml")
    argv_list = list(argv)
    cfg_path, argv_rest = _extract_config_argv(argv_list, default_cfg)

    if Path(cfg_path).is_file():
        cfg = load_yaml_config(cfg_path)
    else:
        cfg = {}
        # Tokens were consumed only if the user supplied --config themselves;
        # silently ignoring a mistyped path would train with unintended defaults.
        if len(argv_rest) != len(argv_list):
            warnings.warn(
                f"Config file not found: {cfg_path!r}; using parser defaults only.",
                stacklevel=2,
            )

    parser = build_parser(cfg_path)
    if cfg:
        # argparse has no public accessor for registered actions, hence the
        # private ``_actions`` list.
        known = {
            a.dest
            for a in parser._actions
            if getattr(a, "dest", None) and a.dest not in ("help", argparse.SUPPRESS)
        }
        filtered = {k: v for k, v in cfg.items() if k in known}
        parser.set_defaults(**filtered)
    return parser.parse_args(argv_rest)
|
|
|
|
def resolve_path_args(ns: Any, path_keys: tuple[str, ...]) -> None:
    """Rewrite the named attributes of ``ns`` in place as resolved paths.

    Only attributes that currently hold a non-empty string are touched; each
    is replaced by the result of ``_resolve_maybe_relative``. Missing
    attributes, empty strings and non-string values are left alone.
    """
    for name in path_keys:
        current = getattr(ns, name, None)
        if not isinstance(current, str) or not current:
            continue
        setattr(ns, name, _resolve_maybe_relative(current))
|
|
|
|
def load_eval_yaml() -> dict[str, Any]:
    """Return the contents of ``config/eval.yaml`` under ``REPO_ROOT``.

    An empty dict is returned when that file does not exist.
    """
    eval_cfg_file = REPO_ROOT / "config" / "eval.yaml"
    return load_yaml_config(eval_cfg_file) if eval_cfg_file.is_file() else {}
|
|
|
|
def get_pretrained_rpmnet_dir() -> str:
    """Directory containing ``clean-trained.pth``, ``best_model_PointNet*.t7``, etc.

    ``R3PM_NET_PRETRAINED_ROOT`` overrides ``pretrained_rpmnet_dir`` in ``config/eval.yaml``.

    Resolution order:
      1. A non-empty ``R3PM_NET_PRETRAINED_ROOT`` environment variable,
         expanded (``~``) and resolved to an absolute path.
      2. ``pretrained_rpmnet_dir`` from ``config/eval.yaml``, resolved against
         ``REPO_ROOT`` when relative.
      3. ``checkpoints`` under ``REPO_ROOT`` when the key is missing, empty or
         whitespace-only.

    Returns:
        The directory as a string; its existence is not checked.
    """
    env = os.environ.get("R3PM_NET_PRETRAINED_ROOT")
    if env:
        return str(Path(env).expanduser().resolve())
    cfg = load_eval_yaml()
    # Coerce with str() (as get_sioux_data_root does) so a non-string YAML
    # scalar such as ``2024`` does not fail on ``.strip()``.
    rel = str(cfg.get("pretrained_rpmnet_dir") or "checkpoints").strip() or "checkpoints"
    out = _resolve_maybe_relative(rel)
    return out if out else str(REPO_ROOT / "checkpoints")
|
|
|
|
def get_sioux_data_root() -> str:
    """Return the base data directory used by the Sioux scripts.

    The first truthy value among ``sioux.base_dir`` and ``data_root`` in
    ``config/eval.yaml`` is used, falling back to ``"data"``. The value is
    stringified, stripped and resolved against ``REPO_ROOT`` when relative;
    if that leaves nothing, ``data`` under ``REPO_ROOT`` is returned.
    """
    eval_cfg = load_eval_yaml()
    sioux_cfg = eval_cfg.get("sioux") or {}

    chosen = sioux_cfg.get("base_dir")
    if not chosen:
        chosen = eval_cfg.get("data_root")
    if not chosen:
        chosen = "data"

    resolved = _resolve_maybe_relative(str(chosen).strip())
    if resolved:
        return resolved
    return str(REPO_ROOT / "data")
|
|
|
|
def get_modelnet40_paths() -> tuple[str, str]:
    """Return ``(dataset_path, cache_dir)`` for ModelNet40 evaluation.

    Both values come from the ``modelnet40`` section of ``config/eval.yaml``
    and default to ``data/ModelNet40`` and ``data/down_sampled_modelnet40``.
    Relative values are resolved against ``REPO_ROOT``; a value that resolves
    to nothing (``None`` or empty) falls back to the default location under
    ``REPO_ROOT``.
    """
    section = load_eval_yaml().get("modelnet40") or {}

    dataset = _resolve_maybe_relative(section.get("dataset_path", "data/ModelNet40"))
    cache = _resolve_maybe_relative(section.get("cache_dir", "data/down_sampled_modelnet40"))

    if not dataset:
        dataset = str(REPO_ROOT / "data" / "ModelNet40")
    if not cache:
        cache = str(REPO_ROOT / "data" / "down_sampled_modelnet40")
    return dataset, cache
|
|
|
|
def get_method_paths() -> dict[str, Any]:
    """Return the ``methods`` section of ``config/eval.yaml`` with paths resolved.

    Each entry whose value is a mapping is copied under ``str(method_name)``;
    entries with non-mapping values are skipped. Inside an entry, every
    non-blank string value is stripped and resolved against ``REPO_ROOT`` when
    relative, while all other values are copied unchanged.
    """

    def _absolutize(value: Any) -> Any:
        # Only non-blank strings are treated as paths.
        if isinstance(value, str) and value.strip():
            return _resolve_maybe_relative(value.strip()) or value
        return value

    methods = load_eval_yaml().get("methods") or {}
    resolved: dict[str, Any] = {}
    for name, settings in methods.items():
        if isinstance(settings, Mapping):
            resolved[str(name)] = {
                key: _absolutize(value) for key, value in settings.items()
            }
    return resolved
|
|
|
|
def get_sioux_paths() -> dict[str, Any]:
    """Return the ``sioux`` section of ``config/eval.yaml`` with paths resolved.

    Non-blank string values are stripped and resolved against ``REPO_ROOT``
    when relative. List values are processed element by element under the
    same rule. Every other value (including blank strings) is copied as-is.
    """

    def _absolutize(value: Any) -> Any:
        # Only non-blank strings are treated as paths.
        if isinstance(value, str) and value.strip():
            return _resolve_maybe_relative(value.strip()) or value
        return value

    section = load_eval_yaml().get("sioux") or {}
    return {
        key: (
            [_absolutize(item) for item in value]
            if isinstance(value, list)
            else _absolutize(value)
        )
        for key, value in section.items()
    }
|
|