"""Load YAML hyperparameters and resolve filesystem paths relative to the repository root."""

from __future__ import annotations

import os
from pathlib import Path
from typing import Any

import yaml

from flare.definitions import REPO_DIR

# Keys that may hold filesystem paths. Relative paths stored under these keys
# are resolved against REPO_DIR; absolute and ``~``-prefixed paths are kept
# (after home expansion) and normalised.
_PATH_KEYS = frozenset(
    {
        "dataset_pth",
        "candidates_pth",
        "subformula_dir_pth",
        "split_pth",
        "checkpoint_pth",
        "df_test_path",
        "formula_to_smiles_pth",
    }
)


def resolve_repo_paths(params: dict[str, Any]) -> None:
    """Rewrite known path entries of *params* in place as absolute path strings.

    Only keys listed in ``_PATH_KEYS`` are touched, and only when their value
    is a non-empty ``str``; any other value (missing key, ``None``, ``""``,
    non-string) is left unchanged.

    A leading ``~`` is expanded to the user's home directory first, so
    ``"~/data"`` is not mistaken for a repo-relative path. Remaining relative
    paths are joined onto ``REPO_DIR``. Every rewritten value is normalised
    with ``Path.resolve()`` (non-strict, so the path need not exist).

    Args:
        params: Parameter mapping to update in place.
    """
    for key in _PATH_KEYS:
        val = params.get(key)
        if not val or not isinstance(val, str):
            continue
        # Expand "~" before the absolute check; otherwise Path("~/x") is
        # treated as relative and silently joined onto REPO_DIR.
        path = Path(val).expanduser()
        if not path.is_absolute():
            path = REPO_DIR / path
        params[key] = str(path.resolve())


def load_param_file(path: str | Path) -> dict[str, Any]:
    """Load a YAML parameter file and resolve its path-valued entries.

    Args:
        path: Location of the YAML file. A leading ``~`` is expanded.

    Returns:
        The top-level mapping from the file (``{}`` for an empty document),
        with entries named in ``_PATH_KEYS`` made absolute by
        ``resolve_repo_paths``.

    Raises:
        FileNotFoundError: If *path* is not an existing regular file.
        TypeError: If the file's top-level YAML value is not a mapping.
    """
    p = Path(path).expanduser()
    if not p.is_file():
        raise FileNotFoundError(f"Parameter file not found: {p}")
    with open(p, encoding="utf-8") as f:
        # NOTE(review): FullLoader can construct some Python-specific tags
        # (e.g. !!python/tuple) and older PyYAML releases had security
        # advisories against it. Acceptable for trusted, repo-local config;
        # switch to yaml.safe_load if params files may come from untrusted
        # sources (check existing files for python/* tags first).
        params = yaml.load(f, Loader=yaml.FullLoader)
    if params is None:  # empty document
        params = {}
    if not isinstance(params, dict):
        raise TypeError(f"Expected mapping at top level of {p}, got {type(params)}")
    resolve_repo_paths(params)
    return params


def default_param_path() -> Path:
    """Return the location of the default params file.

    Lookup order:
      1. ``$FLARE_PARAMS`` -- full path to a params file (``~`` expanded).
      2. ``$FLARE_REPO_ROOT/params.yaml`` (``~`` expanded).
      3. ``REPO_DIR / "params.yaml"``.

    Environment variables set to an empty string are treated as unset. The
    returned path is not checked for existence.
    """
    override = os.environ.get("FLARE_PARAMS")
    if override:
        return Path(override).expanduser()
    # NOTE(review): FLARE_REPO_ROOT only changes where params.yaml is looked
    # up here; resolve_repo_paths still resolves relative entries against
    # REPO_DIR. Confirm that split is intended.
    env_root = os.environ.get("FLARE_REPO_ROOT")
    if env_root:
        return Path(env_root).expanduser() / "params.yaml"
    return REPO_DIR / "params.yaml"