File size: 5,622 Bytes
62bb9d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
# Codebase ref: https://github.com/scxue/SA-Solver

import math
from typing import Union, Callable
import torch


def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
    """Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t with exp((1 + tau^2) * t) factored out, using integration by parts.

    Integral of exp((1 + tau^2) * x) * x^p dx
        = product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
    with base case p=0 where integral equals product_terms[0].

    where
        product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).

    Unrolling the recursion gives, for each p,
        result[p] = sum_{j <= p} (-1)^(p-j) * (p! / j!) / (1 + tau^2)^(p-j) * product_terms_factored[j],
    which is evaluated here as a lower triangular matrix applied to the vector of product terms.
    Return coefficients used by the SA-Solver in data prediction mode.

    Only tau_t ** 2 enters the formula, so the result does not depend on the sign of tau_t.

    Args:
        s: Start time s.
        t: End time t.
        solver_order: Current order of the solver.
        tau_t: Stochastic strength parameter in the SDE.

    Returns:
        Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order-1, shape (solver_order,) for scalar (0-d) s and t.
    """
    tau_mul = 1 + tau_t ** 2
    h = t - s
    p = torch.arange(solver_order, dtype=s.dtype, device=s.device)

    # product_terms after factoring out exp((1 + tau^2) * t)
    # Includes (1 + tau^2) factor from outside the integral
    product_terms_factored = t ** p - s ** p * (-tau_mul * h).exp()

    # Lower triangular recursive coefficient matrix, assembled in log space:
    # entry (i, j) is log(i! / j!) - (i - j) * log(1 + tau^2).
    recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
    log_factorial = (p + 1).lgamma()
    recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
    # Applied unconditionally: log(1) == 0 makes this an exact no-op for tau_t == 0,
    # and the scaling stays consistent with tau_mul above for any sign of tau_t.
    recursive_coeff_mat = recursive_coeff_mat - recursive_depth_mat * math.log(tau_mul)

    # Alternating signs (+1 for even depth, -1 for odd depth), built arithmetically from
    # recursive_depth_mat so they share its dtype and device.
    signs = 1 - 2 * (recursive_depth_mat % 2)
    # tril() discards the entries above the diagonal (negative depth), which are not part of the recursion.
    recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()

    return recursive_coeff_mat @ product_terms_factored


def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
    """Compute simple order-2 b coefficients from SA-Solver paper (Appendix D. Implementation Details).

    In both modes the two coefficients sum to alpha_t * (1 - exp(-(1 + tau^2) * h)),
    where h = lambda_t - lambda_s and alpha_t = sigma_next * exp(lambda_t); the modes
    differ only in how that total is split.

    Args:
        sigma_next: Sigma at end time t.
        curr_lambdas: Lambda time points; only curr_lambdas[-2] is read, and only in predictor mode.
        lambda_s: Lambda at start time s.
        lambda_t: Lambda at end time t.
        tau_t: Stochastic strength parameter in the SDE.
        is_corrector_step: Select the corrector split instead of the predictor split.

    Returns:
        The two coefficients stacked along a new leading dimension, in the order [b_2, b_1].
    """
    noise_scale = 1 + tau_t ** 2
    step = lambda_t - lambda_s
    alpha_next = sigma_next * lambda_t.exp()
    # Total weight shared by both coefficients; expm1 keeps precision for small steps.
    coeff_total = alpha_next * -(-step * noise_scale).expm1()

    if is_corrector_step:
        # Simplified 1-step (order-2) corrector: fix b_1, b_2 takes the remainder.
        newest = alpha_next * (0.5 * noise_scale * step)
        oldest = coeff_total - newest
        return torch.stack([oldest, newest])

    # Simplified 2-step predictor: fix b_2, b_1 takes the remainder.
    lambda_gap = curr_lambdas[-2] - lambda_s
    oldest = alpha_next * (0.5 * noise_scale * step ** 2) / lambda_gap
    newest = coeff_total - oldest
    return torch.stack([oldest, newest])


def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
    """Compute b_i coefficients for the SA-Solver (see eqs. 15 and 18).

    The solver order equals the number of lambdas (half-logSNR points) passed in.
    Unless the simple order-2 scheme applies, the coefficients are obtained by solving
    the linear system sum_i w_i * lambda_i^p = exp_integral_coeffs[p] for p = 0..N-1
    and scaling the solution by alpha_t.

    Args:
        sigma_next: Sigma at end time t.
        curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
        lambda_s: Lambda at start time s.
        lambda_t: Lambda at end time t.
        tau_t: Stochastic strength parameter in the SDE.
        simple_order_2: Whether to enable the simple order-2 scheme (only used when N == 2).
        is_corrector_step: Flag for corrector step in simple order-2 mode.

    Returns:
        b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
    """
    order = curr_lambdas.shape[0]

    use_simple_scheme = simple_order_2 and order == 2
    if use_simple_scheme:
        return compute_simple_stochastic_adams_b_coeffs(
            sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step
        )

    # Right-hand side: exponentially weighted monomial integrals for p = 0..N-1.
    monomial_integrals = compute_exponential_coeffs(lambda_s, lambda_t, order, tau_t)

    # Row p of the system matrix holds lambda_i^p for every time point i
    # (the transposed increasing-power Vandermonde matrix).
    system_matrix = torch.vander(curr_lambdas, order, increasing=True).transpose(0, 1)
    basis_weights = torch.linalg.solve(system_matrix, monomial_integrals)

    # (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
    # = sigma_t * exp(lambda_t) = alpha_t,
    # which restores the exp((1 + tau^2) * lambda_t) factor extracted from the integral.
    alpha_next = sigma_next * lambda_t.exp()
    return alpha_next * basis_weights


def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
    """Build the tau schedule that controls the stochasticity of SA-Solver.

    The returned callable yields eta while sigma lies inside the closed interval
    [end_sigma, start_sigma] and 0.0 elsewhere. With eta <= 0 it always yields 0.0,
    i.e. SA-Solver runs as an ODE. The official approach uses time t to determine
    the SDE interval, while here we use sigma instead.

    See:
        https://github.com/scxue/SA-Solver/blob/main/README.md

    Args:
        start_sigma: Upper bound (inclusive) of the stochastic sigma interval.
        end_sigma: Lower bound (inclusive) of the stochastic sigma interval.
        eta: Value returned as tau inside the interval.

    Returns:
        A function mapping a sigma (tensor or float) to its tau value.
    """

    def tau_func(sigma: Union[torch.Tensor, float]) -> float:
        # Non-positive eta disables the stochastic term entirely (ODE mode).
        if eta <= 0:
            return 0.0

        # Unwrap single-element tensors so the comparison below yields a plain bool.
        sigma_value = sigma.item() if isinstance(sigma, torch.Tensor) else sigma
        in_sde_interval = end_sigma <= sigma_value <= start_sigma
        if in_sde_interval:
            return eta
        return 0.0

    return tau_func