File size: 2,601 Bytes
5686f5b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 | """Shared helpers for source-process configuration in flow and diffusion models."""
from __future__ import annotations
from typing import Callable, Tuple
import torch
from sim_priors_pk.config_classes.source_process_config import SourceProcessConfig
from sim_priors_pk.models.diffusion.noise import GaussianProcess, GaussianProcessRegression, Normal, OrnsteinUhlenbeck, Wiener
def normalize_source_type(source_type: str) -> str:
    """Normalize a source-process name to a canonical lowercase form.

    The name is lowercased, and every run of whitespace, hyphens and
    underscores becomes a single ``_``. Leading and trailing separators are
    dropped. Spelling variants from config files therefore map to the same key.

    >>> normalize_source_type("  Gaussian-Process ")
    'gaussian_process'
    >>> normalize_source_type("GP  regression")
    'gp_regression'
    """
    # Treat "-" and "_" like whitespace so that str.split() (no argument)
    # collapses mixed or repeated separators, including tabs and newlines.
    words = source_type.lower().replace("-", " ").replace("_", " ").split()
    return "_".join(words)
def build_source_process(
    source_cfg: SourceProcessConfig | None,
    *,
    dim: int = 1,
) -> Tuple[Callable, bool]:
    """
    Build the source process for flow/diffusion models.

    Parameters
    ----------
    source_cfg:
        Source-process configuration. ``None`` falls back to a default
        ``SourceProcessConfig()``.
    dim:
        Value passed as ``dim`` to the selected process constructor.

    Returns
    -------
    process:
        Callable noise/source process instance.
    is_time_series:
        Whether the process induces temporal covariance and requires `t`.
        True for the Gaussian-process, GP-regression and Ornstein-Uhlenbeck
        processes. False for the Wiener and Normal processes.

    Raises
    ------
    ValueError
        If the normalized ``source_type`` matches none of the known aliases.
    """
    cfg = source_cfg if source_cfg is not None else SourceProcessConfig()
    kind = normalize_source_type(str(cfg.source_type))

    # The kernel settings are coerced for every source type. A non-numeric
    # gp_* value therefore fails here even if the chosen process ignores it.
    kernel = dict(
        variance=float(cfg.gp_variance),
        length_scale=float(cfg.gp_length_scale),
        epsilon=float(cfg.gp_eps),
    )
    transform = cfg.gp_transform

    # (aliases, factory) pairs. Each factory returns (process, is_time_series).
    # Factories are lazy, so only the selected process is constructed.
    # Only the GP and GP-regression processes receive the transform.
    registry = (
        (
            ("gaussian_process", "gp"),
            lambda: (GaussianProcess(dim=dim, **kernel, transform=transform), True),
        ),
        (
            ("gaussian_process_regression", "gp_regression"),
            lambda: (
                GaussianProcessRegression(dim=dim, **kernel, transform=transform),
                True,
            ),
        ),
        (
            ("ornstein_uhlenbeck", "ou"),
            lambda: (OrnsteinUhlenbeck(dim=dim, **kernel), True),
        ),
        (
            ("wiener", "brownian"),
            lambda: (Wiener(dim=dim), False),
        ),
        (
            ("normal", "gaussian", "white_noise"),
            lambda: (Normal(dim=dim), False),
        ),
    )
    for aliases, make in registry:
        if kind in aliases:
            return make()
    raise ValueError(f"Unsupported source_type: {kind!r}")
def sample_source_process(
    source_process: Callable,
    t: torch.Tensor,
    *,
    device: torch.device | None = None,
    is_time_series: bool = True,
) -> torch.Tensor:
    """
    Sample a source process given the target time grid.

    Parameters
    ----------
    source_process:
        Process to sample from, e.g. one returned by ``build_source_process``.
    t:
        Target time grid.
    device:
        Optional device. It is forwarded to time-series processes, and a tensor
        result is moved onto it.
    is_time_series:
        If True, the process is called as ``source_process(t=t, device=device)``.
        Otherwise it is called positionally with every dimension of ``t``
        except the last.

    Returns
    -------
    torch.Tensor
        The drawn sample. It is moved to ``device`` when the result is a tensor
        and ``device`` is given; any other result is returned unchanged.
    """
    if not is_time_series:
        # Non-temporal processes only need the output shape, not the grid itself.
        leading_dims = tuple(t.shape[:-1])
        drawn = source_process(*leading_dims)
    else:
        drawn = source_process(t=t, device=device)

    needs_move = device is not None and torch.is_tensor(drawn)
    return drawn.to(device) if needs_move else drawn
|