File size: 5,595 Bytes
97aa5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""Load YAML training config and merge with argparse."""

from __future__ import annotations

import argparse
import os
from pathlib import Path
from typing import Any, Mapping

import yaml

from r3pm_net.paths import REPO_ROOT


def _resolve_maybe_relative(path_str: str | None) -> str | None:
    if path_str is None or path_str == "":
        return path_str
    p = Path(path_str)
    if p.is_absolute():
        return str(p)
    return str(REPO_ROOT / p)


def load_yaml_config(path: str | Path) -> dict[str, Any]:
    """Read *path* with ``yaml.safe_load`` and return its top level as a dict.

    An empty document (``safe_load`` returning ``None``) yields ``{}``.

    Raises:
        ValueError: if the document's top level is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    if parsed is None:
        return {}
    if isinstance(parsed, Mapping):
        return dict(parsed)
    raise ValueError(f"Config must be a mapping, got {type(parsed)}")


def _extract_config_argv(argv: list[str], default_cfg: str) -> tuple[str, list[str]]:
    """Return (config path for YAML, argv without --config ...)."""
    path = default_cfg
    out: list[str] = []
    i = 0
    while i < len(argv):
        if argv[i] == "--config" and i + 1 < len(argv):
            path = argv[i + 1]
            i += 2
            continue
        if argv[i].startswith("--config="):
            path = argv[i].split("=", 1)[1]
            i += 1
            continue
        out.append(argv[i])
        i += 1
    return path, out


def parse_train_args(argv: list[str], build_parser) -> argparse.Namespace:
    """Parse training CLI arguments, using a YAML file as argparse defaults.

    Precedence, lowest to highest:

    1. defaults declared by the parser returned from ``build_parser``;
    2. YAML values from ``--config`` (default: ``REPO_ROOT/config/default.yaml``)
       whose keys match a parser ``dest``; other YAML keys are ignored;
    3. flags given explicitly on the command line.

    Args:
        argv: Command-line tokens (``--config`` is stripped before parsing).
        build_parser: Callable receiving the selected config path and returning
            an ``argparse.ArgumentParser``.

    Returns:
        The parsed namespace.

    A missing *default* config file (or an empty ``--config=``) means no YAML
    defaults are applied. An explicitly requested config file that does not
    exist is reported through ``parser.error`` (which exits) instead of being
    silently ignored.
    """
    default_cfg = str(REPO_ROOT / "config" / "default.yaml")
    cfg_path, argv_rest = _extract_config_argv(list(argv), default_cfg)
    parser = build_parser(cfg_path)

    cfg: dict[str, Any] = {}
    if Path(cfg_path).is_file():
        cfg = load_yaml_config(cfg_path)
    elif cfg_path and cfg_path != default_cfg:
        # A user-supplied path that doesn't exist is almost certainly a typo;
        # silently training on built-in defaults would hide the mistake.
        parser.error(f"config file not found: {cfg_path}")

    if cfg:
        # argparse has no public API listing registered actions, hence _actions.
        known = {
            action.dest
            for action in parser._actions
            if getattr(action, "dest", None) and action.dest not in ("help", argparse.SUPPRESS)
        }
        parser.set_defaults(**{k: v for k, v in cfg.items() if k in known})
    return parser.parse_args(argv_rest)


def resolve_path_args(ns: Any, path_keys: tuple[str, ...]) -> None:
    """Rewrite path attributes of *ns* in place, anchoring relative ones at ``REPO_ROOT``.

    Only attributes that exist and hold a non-empty string are touched;
    missing attributes, ``None``, ``""`` and non-string values are left as-is.
    """
    for name in path_keys:
        current = getattr(ns, name, None)
        if not isinstance(current, str) or not current:
            continue
        setattr(ns, name, _resolve_maybe_relative(current))


def load_eval_yaml() -> dict[str, Any]:
    """Return the contents of ``REPO_ROOT/config/eval.yaml``, or ``{}`` if it is not a file."""
    eval_cfg = REPO_ROOT / "config" / "eval.yaml"
    return load_yaml_config(eval_cfg) if eval_cfg.is_file() else {}


def get_pretrained_rpmnet_dir() -> str:
    """Directory containing ``clean-trained.pth``, ``best_model_PointNet*.t7``, etc.

    Lookup order:

    1. ``R3PM_NET_PRETRAINED_ROOT`` environment variable, with ``~`` expanded
       and resolved to an absolute path;
    2. ``pretrained_rpmnet_dir`` in ``config/eval.yaml``, with ``~`` expanded
       and relative paths anchored at ``REPO_ROOT``;
    3. ``REPO_ROOT/checkpoints`` when the YAML value is missing, null or blank.
    """
    env = os.environ.get("R3PM_NET_PRETRAINED_ROOT")
    if env:
        return str(Path(env).expanduser().resolve())
    raw = load_eval_yaml().get("pretrained_rpmnet_dir")
    # str() tolerates non-string YAML scalars (the old .strip() crashed on them);
    # a blank/null value falls back to the default directory name.
    rel = str(raw or "").strip() or "checkpoints"
    # Expand "~" so the YAML value behaves like the env var instead of becoming
    # REPO_ROOT/"~"/...; the result is non-empty, so the resolver returns a str.
    return _resolve_maybe_relative(str(Path(rel).expanduser()))


def get_sioux_data_root() -> str:
    """Base data directory for Sioux scripts (``data`` / ``sioux_cranfield``, etc.).

    Uses the first truthy value among ``sioux.base_dir`` and the top-level
    ``data_root`` from ``config/eval.yaml``, otherwise ``"data"``. The value is
    stringified, stripped and anchored at ``REPO_ROOT`` when relative; a blank
    result falls back to ``REPO_ROOT/data``.
    """
    cfg = load_eval_yaml()
    sioux_section = cfg.get("sioux") or {}
    candidates = (sioux_section.get("base_dir"), cfg.get("data_root"), "data")
    base = next(candidate for candidate in candidates if candidate)
    resolved = _resolve_maybe_relative(str(base).strip())
    return resolved or str(REPO_ROOT / "data")


def get_modelnet40_paths() -> tuple[str, str]:
    """Return ``(dataset_path, cache_dir)`` for ModelNet40 evaluation.

    Values come from the ``modelnet40`` section of ``config/eval.yaml``
    (keys ``dataset_path`` and ``cache_dir``). Each value is stringified and
    stripped, like the other getters in this module. Missing, null or blank
    values fall back to ``data/ModelNet40`` and ``data/down_sampled_modelnet40``.
    Relative paths are anchored at ``REPO_ROOT``.

    Raises:
        ValueError: if ``modelnet40`` is present but is not a mapping.
    """
    section = load_eval_yaml().get("modelnet40") or {}
    if not isinstance(section, Mapping):
        raise ValueError(
            f"'modelnet40' in config/eval.yaml must be a mapping, got {type(section)}"
        )

    def _path_or_default(key: str, default_rel: str) -> str:
        # Whitespace-only values previously produced a bogus REPO_ROOT/"  " path,
        # and non-string scalars crashed inside Path(); normalise both here.
        text = str(section.get(key) or "").strip()
        return _resolve_maybe_relative(text or default_rel)

    return (
        _path_or_default("dataset_path", "data/ModelNet40"),
        _path_or_default("cache_dir", "data/down_sampled_modelnet40"),
    )


def get_method_paths() -> dict[str, Any]:
    """Return the ``methods`` section of ``config/eval.yaml`` with paths resolved.

    Each method entry that is a mapping is copied under ``str(method_name)``.
    Every non-blank string value in it is stripped and anchored at
    ``REPO_ROOT`` when relative; all other values are copied unchanged.
    Non-mapping method entries are skipped.
    """

    def _as_path(value: Any) -> Any:
        # Only non-blank strings are treated as paths.
        if not isinstance(value, str):
            return value
        trimmed = value.strip()
        if not trimmed:
            return value
        return _resolve_maybe_relative(trimmed) or value

    methods = load_eval_yaml().get("methods") or {}
    resolved: dict[str, Any] = {}
    for name, settings in methods.items():
        if isinstance(settings, Mapping):
            resolved[str(name)] = {key: _as_path(val) for key, val in settings.items()}
    return resolved


def get_sioux_paths() -> dict[str, Any]:
    """Return the ``sioux`` section of ``config/eval.yaml`` with paths made absolute.

    Non-blank string values, including non-blank strings inside list values,
    are stripped and anchored at ``REPO_ROOT`` when relative. Every other
    value (blank strings, numbers, mappings, ...) is copied unchanged.
    """

    def _absolutize(value: Any) -> Any:
        # Only non-blank strings are treated as paths; everything else passes through.
        if isinstance(value, str) and value.strip():
            return _resolve_maybe_relative(value.strip()) or value
        return value

    section = load_eval_yaml().get("sioux") or {}
    resolved_paths: dict[str, Any] = {}
    for key, value in section.items():
        if isinstance(value, list):
            resolved_paths[key] = [_absolutize(item) for item in value]
        else:
            resolved_paths[key] = _absolutize(value)
    return resolved_paths