File size: 1,795 Bytes
5686f5b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | import os
from dataclasses import dataclass
from typing import Optional, Union
try: # pragma: no cover - exercised indirectly via configuration loading
import yaml # type: ignore
except ModuleNotFoundError: # pragma: no cover - fallback for minimal environments
from sim_priors_pk.config_classes import yaml_fallback as yaml
@dataclass
class SourceProcessConfig:
    """
    Configuration for source processes used by flow and diffusion PK models.

    Supported source_type values (case-insensitive):
    - "gaussian_process" / "gp"
    - "ornstein_uhlenbeck" / "ou"
    - "wiener"
    - "normal" / "gaussian"

    NOTE(review): this class stores ``source_type`` verbatim and does not
    validate or lower-case it; case-insensitive matching presumably happens
    in the consumer of this config — confirm.
    """

    # Name of the source process; see the class docstring for accepted values.
    source_type: str = "gaussian_process"
    # Gaussian process hyper-parameters for RBF or OU.
    gp_length_scale: float = 0.1
    gp_variance: float = 1.0
    # presumably a small jitter for numerical stability — confirm at use site
    gp_eps: float = 1e-8
    # Transformation to apply to the sampled noise, e.g. 'softplus', 'exp'.
    gp_transform: str = "softplus"
    # Flow matching additive noise scale (used only in FlowPK).
    flow_sigma: float = 1e-4
    flow_num_steps: int = 100
    use_OT_coupling: bool = False

    @classmethod
    def from_dict(cls, config_dict: object) -> "SourceProcessConfig":
        """Instantiate the source-process configuration from a parsed mapping.

        If ``config_dict`` is a dict containing one of the section keys
        ``"source_process"``, ``"source"`` or ``"noise_model"`` (checked in
        that order, first match wins), only that nested section is used; a
        falsy section value (e.g. ``None``) yields the default configuration.
        A top-level ``None`` is treated as an empty mapping.

        Args:
            config_dict: Parsed configuration, either the options themselves
                or a larger document nesting them under a section key.

        Returns:
            A new ``SourceProcessConfig``.

        Raises:
            TypeError: If the selected configuration is not a dict, or if it
                contains keys that are not fields of this dataclass.
        """
        from dataclasses import fields

        if config_dict is None:
            config_dict = {}
        if isinstance(config_dict, dict):
            for key in ("source_process", "source", "noise_model"):
                if key in config_dict:
                    config_dict = config_dict.get(key) or {}
                    break
        if not isinstance(config_dict, dict):
            raise TypeError("Expected source process configuration to be a mapping.")

        # Report every unknown / misspelled option at once, together with the
        # valid names, instead of the bare "unexpected keyword argument" error
        # that cls(**config_dict) would raise for only the first offender.
        known = {field.name for field in fields(cls)}
        unknown = sorted(str(key) for key in config_dict if key not in known)
        if unknown:
            raise TypeError(
                f"Unknown {cls.__name__} option(s): {', '.join(unknown)}. "
                f"Valid options are: {', '.join(sorted(known))}."
            )
        return cls(**config_dict)

    @classmethod
    def from_yaml(cls, file_path: Union[str, os.PathLike]) -> "SourceProcessConfig":
        """Instantiate the source-process configuration from a YAML file.

        The file is parsed with ``yaml.safe_load``; an empty document (or any
        falsy top-level value) is treated as an empty mapping. Section
        unwrapping and validation are delegated to :meth:`from_dict`.

        Args:
            file_path: Path to a UTF-8 encoded YAML file.

        Returns:
            A new ``SourceProcessConfig``.

        Raises:
            OSError: If the file cannot be opened.
            TypeError: If the configuration is not a mapping or contains
                unknown options.
        """
        with open(file_path, "r", encoding="utf-8") as handle:
            # safe_load never constructs arbitrary Python objects from the file.
            config_dict = yaml.safe_load(handle) or {}
        return cls.from_dict(config_dict)
|