File size: 5,123 Bytes
128cb34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction."""

from __future__ import annotations

import math

import torch
from torch import Tensor


def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
    """Map logSNR values to the VP (alpha, sigma) pair in float32.

    Uses alpha^2 = sigmoid(logsnr) and sigma^2 = sigmoid(-logsnr), so the
    VP constraint alpha^2 + sigma^2 = 1 holds by construction.
    """
    logsnr32 = lmb.to(dtype=torch.float32)
    alpha = torch.sigmoid(logsnr32).sqrt()
    sigma = torch.sigmoid(logsnr32.neg()).sqrt()
    return alpha, sigma


def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
    """Reshape a per-sample [B] coefficient so it broadcasts against x.

    Appends singleton dims for every non-batch dimension of x, e.g. a [B]
    coefficient becomes [B, 1, 1, 1] for a 4-D x.
    """
    trailing = [1] * (x.dim() - 1)
    return coeff.view(int(x.shape[0]), *trailing)


def _cosine_interpolated_params(
    logsnr_min: float, logsnr_max: float
) -> tuple[float, float]:
    """Compute (a, b) for cosine-interpolated logSNR schedule.

    logsnr(t) = -2 * log(tan(a*t + b))
    logsnr(0) = logsnr_max, logsnr(1) = logsnr_min
    """
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return a, b


def cosine_interpolated_logsnr_from_t(
    t: Tensor, *, logsnr_min: float, logsnr_max: float
) -> Tensor:
    """Map t in [0, 1] to logSNR via the cosine-interpolated schedule.

    Computes logsnr(t) = -2 * log(tan(a*t + b)) in float32 regardless of
    the input dtype, with (a, b) chosen so logsnr(0) == logsnr_max and
    logsnr(1) == logsnr_min.

    Args:
        t: Times in [0, 1] (any shape).
        logsnr_min: logSNR at t == 1.
        logsnr_max: logSNR at t == 0.

    Returns:
        Float32 tensor of logSNR values, same shape and device as ``t``.
    """
    a, b = _cosine_interpolated_params(logsnr_min, logsnr_max)
    t32 = t.to(dtype=torch.float32)
    # Python scalars broadcast with tensors and are cast to the tensor's
    # dtype, so materializing scalar tensors (as the original did) is
    # unnecessary allocation; the result is numerically identical.
    return -2.0 * torch.log(torch.tan(a * t32 + b))


def shifted_cosine_interpolated_logsnr_from_t(
    t: Tensor,
    *,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """SiD2 "shifted cosine" schedule: base logSNR plus resolution shifts.

    Linearly blends two shifted copies of the base cosine schedule:
    lambda(t) = (1-t) * (base(t) + log_change_high)
              + t     * (base(t) + log_change_low)
    """
    base = cosine_interpolated_logsnr_from_t(
        t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
    )
    t32 = t.to(dtype=torch.float32)
    shifted_high = base + float(log_change_high)
    shifted_low = base + float(log_change_low)
    # Weight the high-end shift near t=0 and the low-end shift near t=1.
    return (1.0 - t32) * shifted_high + t32 * shifted_low


def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
    """Build a descending t-schedule in [0, 1] for VP diffusion sampling.

    ``num_steps`` counts function evaluations (NFE = decoder forward
    passes); the returned schedule therefore contains ``num_steps + 1``
    time points including both endpoints.

    Args:
        schedule_type: "linear" or "cosine".
        num_steps: Number of decoder forward passes (NFE), >= 1.

    Returns:
        Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.

    Raises:
        ValueError: If ``schedule_type`` is not recognized.
    """
    # NOTE: upstream training code (src/ode/time_schedules.py) counts
    # schedule *points* in num_steps (NFE = num_steps - 1); this export
    # package uses num_steps == NFE directly, hence the +1 below.
    # TODO: align the upstream convention.
    n_points = max(int(num_steps) + 1, 2)
    if schedule_type == "linear":
        ascending = torch.linspace(0.0, 1.0, n_points)
    elif schedule_type == "cosine":
        idx = torch.arange(n_points, dtype=torch.float32)
        ascending = 0.5 * (1.0 - torch.cos(math.pi * (idx / (n_points - 1))))
    else:
        raise ValueError(
            f"Unsupported schedule type: {schedule_type!r}. Use 'linear' or 'cosine'."
        )
    # Sampling runs from high t (noisy) down to low t (clean).
    return torch.flip(ascending, dims=[0])


def make_initial_state(
    *,
    noise: Tensor,
    t_start: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """Build the VP initial state x_t0 = sigma(t_start) * noise.

    The clean signal x0 is taken to be zero at the start of sampling, so
    only the sigma term survives. All arithmetic runs in float32.
    """
    batch_size = int(noise.shape[0])
    t_batch = t_start.expand(batch_size).to(dtype=torch.float32)
    logsnr_start = shifted_cosine_interpolated_logsnr_from_t(
        t_batch,
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    _unused_alpha, sigma_start = alpha_sigma_from_logsnr(logsnr_start)
    sigma_b = broadcast_time_like(sigma_start, noise)
    return sigma_b * noise.to(dtype=torch.float32)


def sample_noise(
    shape: tuple[int, ...],
    *,
    noise_std: float = 1.0,
    seed: int | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> Tensor:
    """Sample Gaussian noise with optional seeding. CPU-seeded for reproducibility."""
    if seed is None:
        noise = torch.randn(
            shape, device=device or torch.device("cpu"), dtype=torch.float32
        )
    else:
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed))
        noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32)
    noise = noise.mul(float(noise_std))
    target_device = device if device is not None else torch.device("cpu")
    return noise.to(device=target_device, dtype=dtype)