| """Shared helpers for source-process configuration in flow and diffusion models.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Callable, Tuple |
|
|
| import torch |
|
|
| from sim_priors_pk.config_classes.source_process_config import SourceProcessConfig |
| from sim_priors_pk.models.diffusion.noise import GaussianProcess, GaussianProcessRegression, Normal, OrnsteinUhlenbeck, Wiener |
|
|
def normalize_source_type(source_type: str) -> str:
    """Return the canonical spelling of a source-process name.

    Surrounding whitespace is trimmed, the text is lowercased, and every
    hyphen or space becomes an underscore. For example, both
    ``" Gaussian-Process "`` and ``"gaussian process"`` map to
    ``"gaussian_process"``.
    """
    canonical = source_type.strip().lower()
    for separator in ("-", " "):
        canonical = canonical.replace(separator, "_")
    return canonical
|
|
def build_source_process(
    source_cfg: SourceProcessConfig | None,
    *,
    dim: int = 1,
) -> Tuple[Callable, bool]:
    """
    Build the source process for flow/diffusion models.

    Parameters
    ----------
    source_cfg:
        Source-process configuration. ``None`` falls back to a
        default-constructed ``SourceProcessConfig()``.
    dim:
        Dimensionality forwarded to the process constructor.

    Returns
    -------
    process:
        Callable noise/source process instance.
    is_time_series:
        Whether the process induces temporal covariance and requires `t`.
        True for the GP, GP-regression and Ornstein-Uhlenbeck processes;
        False for Wiener and Normal.

    Raises
    ------
    ValueError
        If the normalized ``source_cfg.source_type`` matches none of the
        supported aliases.
    """
    if source_cfg is None:
        source_cfg = SourceProcessConfig()

    # Accepted spellings (after normalization) for each process family.
    gp_aliases = ("gaussian_process", "gp")
    gp_regression_aliases = ("gaussian_process_regression", "gp_regression")
    ou_aliases = ("ornstein_uhlenbeck", "ou")
    wiener_aliases = ("wiener", "brownian")
    normal_aliases = ("normal", "gaussian", "white_noise")

    source_type = normalize_source_type(str(source_cfg.source_type))
    gp_variance = float(source_cfg.gp_variance)
    gp_length_scale = float(source_cfg.gp_length_scale)
    gp_eps = float(source_cfg.gp_eps)
    gp_transform = source_cfg.gp_transform

    # Kernel hyper-parameters shared by the GP, GP-regression and OU constructors.
    kernel_kwargs = dict(dim=dim, variance=gp_variance, length_scale=gp_length_scale, epsilon=gp_eps)

    if source_type in gp_aliases:
        return GaussianProcess(**kernel_kwargs, transform=gp_transform), True

    if source_type in gp_regression_aliases:
        return GaussianProcessRegression(**kernel_kwargs, transform=gp_transform), True

    if source_type in ou_aliases:
        # NOTE(review): gp_transform is not forwarded here, unlike the two GP
        # branches above; confirm OrnsteinUhlenbeck does not need/accept it.
        return OrnsteinUhlenbeck(**kernel_kwargs), True

    if source_type in wiener_aliases:
        # NOTE(review): Wiener is flagged as not a time series, so callers will
        # sample it from a shape rather than from `t`; confirm this is intended.
        return Wiener(dim=dim), False

    if source_type in normal_aliases:
        return Normal(dim=dim), False

    # Keep the original message prefix, but tell the user what would have worked.
    supported = ", ".join(
        gp_aliases + gp_regression_aliases + ou_aliases + wiener_aliases + normal_aliases
    )
    raise ValueError(f"Unsupported source_type: {source_type!r}; expected one of: {supported}")
|
|
def sample_source_process(
    source_process: Callable,
    t: torch.Tensor,
    *,
    device: torch.device | None = None,
    is_time_series: bool = True,
) -> torch.Tensor:
    """Draw one sample from ``source_process`` matching the time grid ``t``.

    Time-series processes receive the full grid as ``t=`` together with
    ``device=``. Every other process is called positionally with the leading
    dimensions of ``t`` (all but the last axis). When ``device`` is given and
    the draw is a tensor, the draw is moved to that device before returning.
    """
    if not is_time_series:
        leading_dims = t.shape[:-1]
        draw = source_process(*leading_dims)
    else:
        draw = source_process(t=t, device=device)

    should_move = device is not None and torch.is_tensor(draw)
    return draw.to(device) if should_move else draw
|
|