# Source: FLARE repository, flare/utils/config.py
# (last commit 6c3d8a1 by yzhouchen001: "clean up")
"""Load YAML hyperparameters and resolve filesystem paths relative to the repository root."""
from __future__ import annotations

import os
from collections.abc import Iterable
from pathlib import Path
from typing import Any

import yaml

from flare.definitions import REPO_DIR
# Parameter names whose string values are filesystem paths. resolve_repo_paths()
# rewrites them as absolute paths, anchoring relative ones at REPO_DIR.
_PATH_KEYS = frozenset(
    [
        "candidates_pth",
        "checkpoint_pth",
        "dataset_pth",
        "df_test_path",
        "formula_to_smiles_pth",
        "split_pth",
        "subformula_dir_pth",
    ]
)
def resolve_repo_paths(
    params: dict[str, Any],
    root: str | Path | None = None,
    keys: Iterable[str] | None = None,
) -> None:
    """Rewrite path-valued entries of *params* as absolute paths, in place.

    Only entries named in ``_PATH_KEYS`` (or in *keys*, when given) are
    touched, and only when their value is a non-empty string; missing keys,
    ``None``, ``""`` and non-string values are left as they are. A leading
    ``~`` is expanded to the user's home directory. Any path that is still
    relative is then interpreted relative to *root* (``REPO_DIR`` by default).
    Every rewritten value is normalised with ``Path.resolve()`` and stored
    back as a ``str``.

    Args:
        params: Mapping to modify in place (e.g. parameters loaded from YAML).
        root: Base directory for relative paths. Defaults to ``REPO_DIR``.
        keys: Names of the entries to treat as paths. Defaults to ``_PATH_KEYS``.
    """
    base = Path(REPO_DIR if root is None else root)
    for key in _PATH_KEYS if keys is None else keys:
        val = params.get(key)
        # Skip absent/empty entries and non-string values (e.g. None).
        if not val or not isinstance(val, str):
            continue
        # Expand "~" first; otherwise "~/data" would be nested under the root.
        p = Path(val).expanduser()
        if not p.is_absolute():
            p = base / p
        params[key] = str(p.resolve())
def load_param_file(path: str | Path) -> dict[str, Any]:
    """Load a YAML parameter file and resolve its path-valued fields.

    Args:
        path: Location of the YAML file. A leading ``~`` is expanded.

    Returns:
        The top-level mapping from the file, with the entries handled by
        ``resolve_repo_paths`` rewritten to absolute paths. An empty file
        yields ``{}``.

    Raises:
        FileNotFoundError: If *path* does not point to an existing regular file.
        TypeError: If the document's top level is not a mapping.
    """
    # Expand "~" so this accepts the same path forms as default_param_path().
    p = Path(path).expanduser()
    if not p.is_file():
        raise FileNotFoundError(f"Parameter file not found: {p}")
    with p.open(encoding="utf-8") as f:
        # NOTE(review): FullLoader can build Python-specific types (e.g. tuples)
        # beyond plain YAML. If params files may come from untrusted sources,
        # switch to yaml.safe_load once no file relies on python/* tags.
        params = yaml.load(f, Loader=yaml.FullLoader)
    # An empty YAML document parses to None; treat it as "no parameters".
    if params is None:
        params = {}
    if not isinstance(params, dict):
        raise TypeError(f"Expected mapping at top level of {p}, got {type(params)}")
    resolve_repo_paths(params)
    return params
def default_param_path() -> Path:
    """Return where the default parameter file is expected to live.

    Lookup order:
      1. ``$FLARE_PARAMS`` — used verbatim as the path to the file.
      2. ``$FLARE_REPO_ROOT/params.yaml``.
      3. ``REPO_DIR/params.yaml``.
    Environment values get ``~`` expanded; empty values count as unset.
    """
    explicit_file = os.environ.get("FLARE_PARAMS")
    if explicit_file:
        return Path(explicit_file).expanduser()

    repo_root = os.environ.get("FLARE_REPO_ROOT")
    base = Path(repo_root).expanduser() if repo_root else REPO_DIR
    return base / "params.yaml"