Spaces:
Running
Running
| """Load YAML hyperparameters and resolve filesystem paths relative to the repository root.""" | |
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from typing import Any | |
| import yaml | |
| from flare.definitions import REPO_DIR | |
| # Keys that may hold filesystem paths (relative paths are resolved against REPO_DIR). | |
| _PATH_KEYS = frozenset( | |
| { | |
| "dataset_pth", | |
| "candidates_pth", | |
| "subformula_dir_pth", | |
| "split_pth", | |
| "checkpoint_pth", | |
| "df_test_path", | |
| "formula_to_smiles_pth", | |
| } | |
| ) | |
| def resolve_repo_paths(params: dict[str, Any]) -> None: | |
| """In-place: turn repo-relative path strings into absolute paths.""" | |
| root = REPO_DIR | |
| for key in _PATH_KEYS: | |
| val = params.get(key) | |
| if not val or not isinstance(val, str): | |
| continue | |
| p = Path(val) | |
| if not p.is_absolute(): | |
| params[key] = str((root / p).resolve()) | |
| else: | |
| params[key] = str(p.resolve()) | |
| def load_param_file(path: str | Path) -> dict[str, Any]: | |
| """Load a YAML parameter file and resolve path fields.""" | |
| p = Path(path) | |
| if not p.is_file(): | |
| raise FileNotFoundError(f"Parameter file not found: {p}") | |
| with open(p, encoding="utf-8") as f: | |
| params = yaml.load(f, Loader=yaml.FullLoader) | |
| if params is None: | |
| params = {} | |
| if not isinstance(params, dict): | |
| raise TypeError(f"Expected mapping at top level of {p}, got {type(params)}") | |
| resolve_repo_paths(params) | |
| return params | |
def default_param_path() -> Path:
    """Locate the params file to use when no explicit path is given.

    Lookup order (empty environment values count as unset):

    1. ``FLARE_PARAMS``: full path to a params file; ``~`` is expanded.
    2. ``FLARE_REPO_ROOT``: a directory; ``params.yaml`` inside it is used.
    3. ``REPO_DIR / "params.yaml"``.
    """
    explicit_file = os.environ.get("FLARE_PARAMS")
    if explicit_file:
        return Path(explicit_file).expanduser()
    root_override = os.environ.get("FLARE_REPO_ROOT")
    base_dir = Path(root_override).expanduser() if root_override else REPO_DIR
    return base_dir / "params.yaml"