""" |
|
|
A general framework for various sampling algorithm from a diffusion model. |
|
|
Impl based on |
|
|
* Refined Exponential Solver (RES) in https://arxiv.org/pdf/2308.02157 |
|
|
* also clude other impl, DDIM, DEIS, DPM-Solver, EDM sampler. |
|
|
Most of sampling algorihtm, Runge-Kutta, Multi-step, etc, can be impl in this framework by \ |
|
|
adding new step function in get_runge_kutta_fn or get_multi_step_fn. |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Any, Callable, List, Literal, Optional, Tuple, Union |
|
|
|
|
|
import attrs |
|
|
import torch |
|
|
|
|
|
from cosmos_predict1.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported |
|
|
from cosmos_predict1.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported |
|
|
from cosmos_predict1.utils.config import make_freezable |
|
|
|
|
|
COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"] |
|
|
|
|
|
|
|
|
@make_freezable |
|
|
@attrs.define(slots=False) |
|
|
class SolverConfig: |
|
|
is_multi: bool = False |
|
|
rk: str = "2mid" |
|
|
multistep: str = "2ab" |
|
|
|
|
|
|
|
|
s_churn: float = 0.0 |
|
|
s_t_max: float = float("inf") |
|
|
s_t_min: float = 0.05 |
|
|
s_noise: float = 1.0 |
|
|
|
|
|
|
|
|
@make_freezable |
|
|
@attrs.define(slots=False) |
|
|
class SolverTimestampConfig: |
|
|
nfe: int = 50 |
|
|
t_min: float = 0.002 |
|
|
t_max: float = 80.0 |
|
|
order: float = 7.0 |
|
|
is_forward: bool = False |
|
|
|
|
|
|
|
|
@make_freezable |
|
|
@attrs.define(slots=False) |
|
|
class SamplerConfig: |
|
|
solver: SolverConfig = attrs.field(factory=SolverConfig) |
|
|
timestamps: SolverTimestampConfig = attrs.field(factory=SolverTimestampConfig) |
|
|
sample_clean: bool = True |
|
|
|
|
|
|
|
|
def get_rev_ts( |
|
|
t_min: float, t_max: float, num_steps: int, ts_order: Union[int, float], is_forward: bool = False |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Generate a sequence of reverse time steps. |
|
|
|
|
|
Args: |
|
|
t_min (float): The minimum time value. |
|
|
t_max (float): The maximum time value. |
|
|
num_steps (int): The number of time steps to generate. |
|
|
ts_order (Union[int, float]): The order of the time step progression. |
|
|
is_forward (bool, optional): If True, returns the sequence in forward order. Defaults to False. |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: A tensor containing the generated time steps in reverse or forward order. |
|
|
|
|
|
Raises: |
|
|
ValueError: If `t_min` is not less than `t_max`. |
|
|
TypeError: If `ts_order` is not an integer or float. |
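
    Example:
        # Illustrative: an EDM-style rho-schedule, 10 steps from t_max=80.0 down to t_min=0.002.
        # The result has num_steps + 1 = 11 entries, decreasing from ~80.0 to ~0.002.
        sigmas = get_rev_ts(t_min=0.002, t_max=80.0, num_steps=10, ts_order=7)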
|
|
""" |
|
|
if t_min >= t_max: |
|
|
raise ValueError("t_min must be less than t_max") |
|
|
|
|
|
if not isinstance(ts_order, (int, float)): |
|
|
raise TypeError("ts_order must be an integer or float") |
|
|
|
|
|
step_indices = torch.arange(num_steps + 1, dtype=torch.float64) |
|
|
time_steps = ( |
|
|
t_max ** (1 / ts_order) + step_indices / num_steps * (t_min ** (1 / ts_order) - t_max ** (1 / ts_order)) |
|
|
) ** ts_order |
|
|
|
|
|
if is_forward: |
|
|
return time_steps.flip(dims=(0,)) |
|
|
|
|
|
return time_steps |
|
|
|
|
|
|
|
|
class Sampler(torch.nn.Module): |
|
|
def __init__(self, cfg: Optional[SamplerConfig] = None): |
|
|
super().__init__() |
|
|
if cfg is None: |
|
|
cfg = SamplerConfig() |
|
|
self.cfg = cfg |
|
|
|
|
|
@torch.no_grad() |
|
|
def forward( |
|
|
self, |
|
|
x0_fn: Callable, |
|
|
x_sigma_max: torch.Tensor, |
|
|
num_steps: int = 35, |
|
|
sigma_min: float = 0.002, |
|
|
sigma_max: float = 80, |
|
|
rho: float = 7, |
|
|
S_churn: float = 0, |
|
|
S_min: float = 0, |
|
|
S_max: float = float("inf"), |
|
|
S_noise: float = 1, |
|
|
solver_option: str = "2ab", |
|
|
) -> torch.Tensor: |
|
|
in_dtype = x_sigma_max.dtype |
|
|
|
|
|
def float64_x0_fn(x_B_StateShape: torch.Tensor, t_B: torch.Tensor) -> torch.Tensor: |
|
|
return x0_fn(x_B_StateShape.to(in_dtype), t_B.to(in_dtype)).to(torch.float64) |
|
|
|
|
|
is_multistep = is_multi_step_fn_supported(solver_option) |
|
|
is_rk = is_runge_kutta_fn_supported(solver_option) |
|
|
assert is_multistep or is_rk, f"Only support multistep or Runge-Kutta method, got {solver_option}" |
|
|
|
|
|
solver_cfg = SolverConfig( |
|
|
s_churn=S_churn, |
|
|
s_t_max=S_max, |
|
|
s_t_min=S_min, |
|
|
s_noise=S_noise, |
|
|
is_multi=is_multistep, |
|
|
rk=solver_option, |
|
|
multistep=solver_option, |
|
|
) |
|
|
timestamps_cfg = SolverTimestampConfig(nfe=num_steps, t_min=sigma_min, t_max=sigma_max, order=rho) |
|
|
sampler_cfg = SamplerConfig(solver=solver_cfg, timestamps=timestamps_cfg, sample_clean=True) |
|
|
|
|
|
return self._forward_impl(float64_x0_fn, x_sigma_max, sampler_cfg).to(in_dtype) |
|
|
|
|
|
@torch.no_grad() |
|
|
def _forward_impl( |
|
|
self, |
|
|
denoiser_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], |
|
|
noisy_input_B_StateShape: torch.Tensor, |
|
|
sampler_cfg: Optional[SamplerConfig] = None, |
|
|
callback_fns: Optional[List[Callable]] = None, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Internal implementation of the forward pass. |
|
|
|
|
|
Args: |
|
|
denoiser_fn: Function to denoise the input. |
|
|
noisy_input_B_StateShape: Input tensor with noise. |
|
|
sampler_cfg: Configuration for the sampler. |
|
|
callback_fns: List of callback functions to be called during sampling. |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Denoised output tensor. |
|
|
""" |
|
|
sampler_cfg = self.cfg if sampler_cfg is None else sampler_cfg |
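        # Each order-k Runge-Kutta step consumes k denoiser evaluations (NFE), so the number of
        # solver timesteps is nfe // solver_order; multi-step solvers use one evaluation per step.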
|
|
solver_order = 1 if sampler_cfg.solver.is_multi else int(sampler_cfg.solver.rk[0]) |
|
|
num_timestamps = sampler_cfg.timestamps.nfe // solver_order |
|
|
|
|
|
sigmas_L = get_rev_ts( |
|
|
sampler_cfg.timestamps.t_min, sampler_cfg.timestamps.t_max, num_timestamps, sampler_cfg.timestamps.order |
|
|
).to(noisy_input_B_StateShape.device) |
|
|
|
|
|
denoised_output = differential_equation_solver( |
|
|
denoiser_fn, sigmas_L, sampler_cfg.solver, callback_fns=callback_fns |
|
|
)(noisy_input_B_StateShape) |
|
|
|
|
|
if sampler_cfg.sample_clean: |
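            # Run the denoiser once more at the final (smallest) sigma so the returned sample is fully denoised.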
|
|
|
|
|
ones = torch.ones(denoised_output.size(0), device=denoised_output.device, dtype=denoised_output.dtype) |
|
|
denoised_output = denoiser_fn(denoised_output, sigmas_L[-1] * ones) |
|
|
|
|
|
return denoised_output |
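# Illustrative usage of `Sampler` (the denoiser and tensor shapes below are assumptions, not part of this module):
#
#   sampler = Sampler()
#   x_T = 80.0 * torch.randn(4, 3, 64, 64)  # start from pure noise at sigma_max
#   x_0 = sampler(my_x0_fn, x_T, num_steps=35, solver_option="2ab")
#
# where `my_x0_fn(x, sigma)` is a hypothetical callable returning the model's x0 (clean-sample) prediction.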
|
|
|
|
|
|
|
|
def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any) -> Any: |
|
|
""" |
|
|
    Implements a for loop by repeatedly applying a function to a carried value (analogous to jax.lax.fori_loop).
|
|
|
|
|
Args: |
|
|
lower: Lower bound of the loop (inclusive). |
|
|
upper: Upper bound of the loop (exclusive). |
|
|
body_fun: Function to be applied in each iteration. |
|
|
init_val: Initial value for the loop. |
|
|
|
|
|
Returns: |
|
|
The final result after all iterations. |
|
|
""" |
|
|
val = init_val |
|
|
for i in range(lower, upper): |
|
|
val = body_fun(i, val) |
|
|
return val |
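# Example (illustrative): fori_loop(0, 3, lambda i, acc: acc + i, 0) computes ((0 + 0) + 1) + 2 == 3.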
|
|
|
|
|
|
|
|
def differential_equation_solver( |
|
|
x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], |
|
|
sigmas_L: torch.Tensor, |
|
|
solver_cfg: SolverConfig, |
|
|
callback_fns: Optional[List[Callable]] = None, |
|
|
) -> Callable[[torch.Tensor], torch.Tensor]: |
|
|
""" |
|
|
Creates a differential equation solver function. |
|
|
|
|
|
Args: |
|
|
x0_fn: Function to compute x0 prediction. |
|
|
sigmas_L: Tensor of sigma values with shape [L,]. |
|
|
solver_cfg: Configuration for the solver. |
|
|
callback_fns: Optional list of callback functions. |
|
|
|
|
|
Returns: |
|
|
A function that solves the differential equation. |
|
|
""" |
|
|
num_step = len(sigmas_L) - 1 |
|
|
|
|
|
if solver_cfg.is_multi: |
|
|
update_step_fn = get_multi_step_fn(solver_cfg.multistep) |
|
|
else: |
|
|
update_step_fn = get_runge_kutta_fn(solver_cfg.rk) |
|
|
|
|
|
eta = min(solver_cfg.s_churn / (num_step + 1), math.sqrt(1.2) - 1) |
|
|
|
|
|
def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
        Solves the differential equation starting from the given noisy input to produce a sample.
|
|
|
|
|
Args: |
|
|
input_xT_B_StateShape: Input tensor with shape [B, StateShape]. |
|
|
|
|
|
Returns: |
|
|
Output tensor with shape [B, StateShape]. |
|
|
""" |
|
|
ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float64) |
|
|
|
|
|
def step_fn( |
|
|
i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]] |
|
|
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: |
|
|
input_x_B_StateShape, x0_preds = state |
|
|
sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1] |
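            # Stochastic churn (as in the EDM sampler): when sigma_cur lies in (s_t_min, s_t_max),
            # raise it by a factor of (1 + eta) and add matching Gaussian noise before the solver step.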
|
|
|
|
|
|
|
|
if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max: |
|
|
hat_sigma_cur_0 = sigma_cur_0 + eta * sigma_cur_0 |
|
|
input_x_B_StateShape = input_x_B_StateShape + ( |
|
|
hat_sigma_cur_0**2 - sigma_cur_0**2 |
|
|
).sqrt() * solver_cfg.s_noise * torch.randn_like(input_x_B_StateShape) |
|
|
sigma_cur_0 = hat_sigma_cur_0 |
|
|
|
|
|
if solver_cfg.is_multi: |
|
|
x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B) |
|
|
output_x_B_StateShape, x0_preds = update_step_fn( |
|
|
input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds |
|
|
) |
|
|
else: |
|
|
output_x_B_StateShape, x0_preds = update_step_fn( |
|
|
input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_fn |
|
|
) |
|
|
|
|
|
if callback_fns: |
|
|
for callback_fn in callback_fns: |
|
|
callback_fn(**locals()) |
|
|
|
|
|
return output_x_B_StateShape, x0_preds |
|
|
|
|
|
x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None]) |
|
|
return x_at_eps |
|
|
|
|
|
return sample_fn |
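# Illustrative usage (the denoiser `x0_fn`, the shapes, and the solver choice below are assumptions):
#
#   sigmas = get_rev_ts(0.002, 80.0, num_steps=35, ts_order=7)
#   solve = differential_equation_solver(x0_fn, sigmas, SolverConfig(is_multi=True, multistep="2ab"))
#   x_T = sigmas[0] * torch.randn(4, 3, 64, 64, dtype=torch.float64)
#   x_0 = solve(x_T)
#
# where `x0_fn(x, sigma)` is a hypothetical x0-prediction callable operating on float64 tensors.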
|
|
|