# R3PM-Net / r3pm_net/config_loader.py
# Source: commit 97aa5af by YasiiKB ("initial commit", verified)
"""Load YAML training config and merge with argparse."""
from __future__ import annotations
import argparse
import os
from pathlib import Path
from typing import Any, Mapping
import yaml
from r3pm_net.paths import REPO_ROOT
def _resolve_maybe_relative(path_str: str | None) -> str | None:
if path_str is None or path_str == "":
return path_str
p = Path(path_str)
if p.is_absolute():
return str(p)
return str(REPO_ROOT / p)
def load_yaml_config(path: str | Path) -> dict[str, Any]:
    """Read a YAML file with ``yaml.safe_load`` and return its top level as a dict.

    An empty document (``safe_load`` returns ``None``) yields ``{}``.

    Raises:
        ValueError: if the top-level YAML value is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)
    if parsed is None:
        return {}
    if isinstance(parsed, Mapping):
        return dict(parsed)
    raise ValueError(f"Config must be a mapping, got {type(parsed)}")
def _extract_config_argv(argv: list[str], default_cfg: str) -> tuple[str, list[str]]:
"""Return (config path for YAML, argv without --config ...)."""
path = default_cfg
out: list[str] = []
i = 0
while i < len(argv):
if argv[i] == "--config" and i + 1 < len(argv):
path = argv[i + 1]
i += 2
continue
if argv[i].startswith("--config="):
path = argv[i].split("=", 1)[1]
i += 1
continue
out.append(argv[i])
i += 1
return path, out
def parse_train_args(argv: list[str], build_parser) -> argparse.Namespace:
    """Parse training arguments, using a YAML config file as argparse defaults.

    ``--config PATH`` / ``--config=PATH`` is removed from *argv* and selects the
    YAML file (default: ``REPO_ROOT/config/default.yaml``). Top-level YAML keys
    that match an action ``dest`` of the parser are installed with
    ``parser.set_defaults``; other keys are ignored. The remaining CLI arguments
    are then parsed, so explicit flags override YAML values.

    Args:
        argv: Command-line arguments without the program name.
        build_parser: Callable that receives the selected config path and
            returns an ``argparse.ArgumentParser``.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: via ``parser.error`` when a config path given explicitly on
            the command line is not an existing file (a missing built-in
            default is tolerated), or when argparse rejects the arguments.
        ValueError: from ``load_yaml_config`` when the YAML top level is not a
            mapping.
    """
    default_cfg = str(REPO_ROOT / "config" / "default.yaml")
    cfg_path, argv_rest = _extract_config_argv(list(argv), default_cfg)
    cfg_exists = Path(cfg_path).is_file()
    cfg = load_yaml_config(cfg_path) if cfg_exists else {}
    parser = build_parser(cfg_path)
    # Only the built-in default may be absent. An explicit --config pointing at a
    # missing file is almost always a typo, and silently training without it is
    # worse than failing fast. An empty ``--config=`` still means "no YAML defaults".
    if not cfg_exists and cfg_path and cfg_path != default_cfg:
        parser.error(f"config file not found: {cfg_path}")
    if cfg:
        # argparse exposes no public accessor for registered actions.
        known = {
            action.dest
            for action in parser._actions
            if getattr(action, "dest", None) and action.dest not in ("help", argparse.SUPPRESS)
        }
        parser.set_defaults(**{k: v for k, v in cfg.items() if k in known})
    return parser.parse_args(argv_rest)
def resolve_path_args(ns: Any, path_keys: tuple[str, ...]) -> None:
    """Rewrite the listed attributes of *ns* in place as absolute paths.

    Only attributes holding a non-empty ``str`` are touched; each is passed
    through ``_resolve_maybe_relative`` (relative paths are anchored at
    REPO_ROOT). Missing attributes, ``None``, ``""`` and non-string values are
    left alone.
    """
    for name in path_keys:
        current = getattr(ns, name, None)
        if not isinstance(current, str) or not current:
            continue
        setattr(ns, name, _resolve_maybe_relative(current))
def load_eval_yaml() -> dict[str, Any]:
    """Return the contents of ``REPO_ROOT/config/eval.yaml``, or ``{}`` if that file is absent."""
    eval_cfg = REPO_ROOT / "config" / "eval.yaml"
    return load_yaml_config(eval_cfg) if eval_cfg.is_file() else {}
def get_pretrained_rpmnet_dir() -> str:
    """Return the directory holding pretrained weights.

    Expected contents include ``clean-trained.pth`` and
    ``best_model_PointNet*.t7``.

    Lookup order:
      1. ``R3PM_NET_PRETRAINED_ROOT`` environment variable (surrounding
         whitespace ignored, ``~`` expanded, made absolute).
      2. ``pretrained_rpmnet_dir`` in ``config/eval.yaml``; relative values are
         resolved under REPO_ROOT.
      3. ``REPO_ROOT/checkpoints``.
    """
    env = os.environ.get("R3PM_NET_PRETRAINED_ROOT", "").strip()
    if env:
        return str(Path(env).expanduser().resolve())
    raw = load_eval_yaml().get("pretrained_rpmnet_dir")
    # YAML may produce non-string scalars (e.g. a numeric directory name) or
    # null. Coerce before stripping instead of crashing on ``.strip()``.
    rel = str(raw).strip() if raw else ""
    if not rel:
        rel = "checkpoints"
    out = _resolve_maybe_relative(rel)
    return out if out else str(REPO_ROOT / "checkpoints")
def get_sioux_data_root() -> str:
    """Return the absolute base data directory for the Sioux scripts.

    The first truthy value among ``sioux.base_dir`` and the top-level
    ``data_root`` in ``config/eval.yaml`` is used, otherwise ``"data"``. The
    value is stringified, stripped and resolved under REPO_ROOT if relative.
    If the stripped value is empty, ``REPO_ROOT/data`` is returned.
    """
    eval_cfg = load_eval_yaml()
    section = eval_cfg.get("sioux") or {}
    candidates = (section.get("base_dir"), eval_cfg.get("data_root"))
    base = next((c for c in candidates if c), "data")
    resolved = _resolve_maybe_relative(str(base).strip())
    if resolved:
        return resolved
    return str(REPO_ROOT / "data")
def get_modelnet40_paths() -> tuple[str, str]:
    """Return absolute ``(dataset_path, cache_dir)`` for ModelNet40 evaluation.

    Both come from the ``modelnet40`` section of ``config/eval.yaml``, defaulting
    to ``data/ModelNet40`` and ``data/down_sampled_modelnet40``. Relative
    values are resolved under REPO_ROOT. A null or empty value falls back to
    the default location under REPO_ROOT.
    """
    section = load_eval_yaml().get("modelnet40") or {}
    defaults = (
        ("dataset_path", "data/ModelNet40"),
        ("cache_dir", "data/down_sampled_modelnet40"),
    )
    resolved = []
    for key, rel_default in defaults:
        value = _resolve_maybe_relative(section.get(key, rel_default))
        resolved.append(value if value else str(REPO_ROOT / rel_default))
    return resolved[0], resolved[1]
def get_method_paths() -> dict[str, Any]:
    """Return per-method settings from ``methods`` in ``config/eval.yaml`` with paths resolved.

    Non-mapping method entries are skipped, and method names are stringified.
    Within each entry, every non-blank string value is stripped and resolved
    under REPO_ROOT if relative. Other values are copied unchanged.
    """
    methods = load_eval_yaml().get("methods") or {}

    def _resolve_entry(value: Any) -> Any:
        if not (isinstance(value, str) and value.strip()):
            return value
        return _resolve_maybe_relative(value.strip()) or value

    return {
        str(name): {key: _resolve_entry(val) for key, val in entry.items()}
        for name, entry in methods.items()
        if isinstance(entry, Mapping)
    }
def get_sioux_paths() -> dict[str, Any]:
    """Return the ``sioux`` section of ``config/eval.yaml`` with path strings made absolute.

    Non-blank string values, and non-blank strings inside list values, are
    stripped and resolved under REPO_ROOT if relative. All other values
    (numbers, nested mappings, blank strings, non-string list items) are
    copied unchanged.
    """
    section = load_eval_yaml().get("sioux") or {}

    def _as_path(item: Any) -> Any:
        if isinstance(item, str) and item.strip():
            return _resolve_maybe_relative(item.strip()) or item
        return item

    resolved: dict[str, Any] = {}
    for key, value in section.items():
        if isinstance(value, list):
            resolved[key] = [_as_path(x) for x in value]
        else:
            resolved[key] = _as_path(value)
    return resolved