cesarali's picture
manual runtime bundle push from load_and_push.ipynb
5686f5b verified
"""Shared helpers for source-process configuration in flow and diffusion models."""
from __future__ import annotations
from typing import Callable, Tuple
import torch
from sim_priors_pk.config_classes.source_process_config import SourceProcessConfig
from sim_priors_pk.models.diffusion.noise import GaussianProcess, GaussianProcessRegression, Normal, OrnsteinUhlenbeck, Wiener
def normalize_source_type(source_type: str) -> str:
    """Return *source_type* in canonical form for name matching.

    Surrounding whitespace is removed, the text is lowercased, and every
    hyphen or space is turned into an underscore, so ``" Gaussian-Process "``
    becomes ``"gaussian_process"``.
    """
    # One translate pass maps both separator characters to "_".
    separators_to_underscore = str.maketrans("- ", "__")
    trimmed = source_type.strip()
    return trimmed.lower().translate(separators_to_underscore)
def build_source_process(
    source_cfg: SourceProcessConfig | None,
    *,
    dim: int = 1,
) -> Tuple[Callable, bool]:
    """
    Construct the source (noise) process used by flow/diffusion models.

    Parameters
    ----------
    source_cfg:
        Source-process configuration. ``None`` falls back to a default
        ``SourceProcessConfig()``.
    dim:
        Dimensionality passed to the chosen process constructor.

    Returns
    -------
    process:
        Callable noise/source process instance.
    is_time_series:
        Whether the process induces temporal covariance and requires `t`.

    Raises
    ------
    ValueError
        If the normalized ``source_type`` matches none of the known aliases.
    """
    cfg = SourceProcessConfig() if source_cfg is None else source_cfg
    kind = normalize_source_type(str(cfg.source_type))

    # Kernel hyper-parameters are read and converted up front for every
    # kind, matching the order of attribute access in the config.
    kernel_kwargs = {
        "dim": dim,
        "variance": float(cfg.gp_variance),
        "length_scale": float(cfg.gp_length_scale),
        "epsilon": float(cfg.gp_eps),
    }
    transform = cfg.gp_transform

    if kind in {"gaussian_process", "gp"}:
        return GaussianProcess(**kernel_kwargs, transform=transform), True
    if kind in {"gaussian_process_regression", "gp_regression"}:
        return GaussianProcessRegression(**kernel_kwargs, transform=transform), True
    if kind in {"ornstein_uhlenbeck", "ou"}:
        # The OU constructor is not given `transform` here.
        return OrnsteinUhlenbeck(**kernel_kwargs), True
    if kind in {"wiener", "brownian"}:
        return Wiener(dim=dim), False
    if kind in {"normal", "gaussian", "white_noise"}:
        return Normal(dim=dim), False

    raise ValueError(f"Unsupported source_type: {kind!r}")
def sample_source_process(
source_process: Callable,
t: torch.Tensor,
*,
device: torch.device | None = None,
is_time_series: bool = True,
) -> torch.Tensor:
"""Sample a source process given the target time grid."""
if is_time_series:
sample = source_process(t=t, device=device)
else:
sample = source_process(*t.shape[:-1])
if torch.is_tensor(sample) and device is not None:
sample = sample.to(device)
return sample