File size: 2,601 Bytes
5686f5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""Shared helpers for source-process configuration in flow and diffusion models."""

from __future__ import annotations

from typing import Callable, Tuple

import torch

from sim_priors_pk.config_classes.source_process_config import SourceProcessConfig
from sim_priors_pk.models.diffusion.noise import GaussianProcess, GaussianProcessRegression, Normal, OrnsteinUhlenbeck, Wiener

def normalize_source_type(source_type: str) -> str:
    """Return the canonical spelling of a source-process name.

    Surrounding whitespace is stripped, the text is lower-cased, and every
    hyphen and space is mapped to an underscore, so ``" Gaussian-Process "``
    and ``"gaussian process"`` both become ``"gaussian_process"``.
    """
    canonical = source_type.strip().lower()
    for separator in ("-", " "):
        canonical = canonical.replace(separator, "_")
    return canonical

def build_source_process(
    source_cfg: SourceProcessConfig | None,
    *,
    dim: int = 1,
) -> Tuple[Callable, bool]:
    """
    Construct the noise/source process used by flow and diffusion models.

    Parameters
    ----------
    source_cfg:
        Source-process configuration; ``None`` falls back to a default
        ``SourceProcessConfig()``.
    dim:
        Dimensionality forwarded to the process constructor.

    Returns
    -------
    process:
        Callable noise/source process instance.
    is_time_series:
        Whether the process induces temporal covariance and requires `t`.

    Raises
    ------
    ValueError
        If the normalized ``source_type`` matches none of the known aliases.
    """
    cfg = SourceProcessConfig() if source_cfg is None else source_cfg

    kind = normalize_source_type(str(cfg.source_type))
    # Kernel hyper-parameters are read and float-converted eagerly for every
    # source type, including the ones that never use them.
    kernel_kwargs = {
        "variance": float(cfg.gp_variance),
        "length_scale": float(cfg.gp_length_scale),
        "epsilon": float(cfg.gp_eps),
    }
    transform = cfg.gp_transform

    if kind in ("gaussian_process", "gp"):
        process = GaussianProcess(dim=dim, transform=transform, **kernel_kwargs)
        return process, True

    if kind in ("gaussian_process_regression", "gp_regression"):
        process = GaussianProcessRegression(dim=dim, transform=transform, **kernel_kwargs)
        return process, True

    if kind in ("ornstein_uhlenbeck", "ou"):
        # NOTE(review): gp_transform is not forwarded to the OU process,
        # unlike the two GP variants above — confirm this is intended.
        return OrnsteinUhlenbeck(dim=dim, **kernel_kwargs), True

    if kind in ("wiener", "brownian"):
        # NOTE(review): Wiener paths are temporally correlated, yet this is
        # flagged as not requiring `t`; presumably `Wiener` derives its time
        # axis from the sample shape alone — confirm against the noise module.
        return Wiener(dim=dim), False

    if kind in ("normal", "gaussian", "white_noise"):
        return Normal(dim=dim), False

    raise ValueError(f"Unsupported source_type: {kind!r}")

def sample_source_process(
    source_process: Callable,
    t: torch.Tensor,
    *,
    device: torch.device | None = None,
    is_time_series: bool = True,
) -> torch.Tensor:
    """Sample a source process given the target time grid."""
    if is_time_series:
        sample = source_process(t=t, device=device)
    else:
        sample = source_process(*t.shape[:-1])
    if torch.is_tensor(sample) and device is not None:
        sample = sample.to(device)
    return sample