"""
Linear interpolation schedule (lerp).
"""

from enum import Enum
from typing import Literal, Tuple, Union

import torch


class LinearSchedule:
    """
    Linear interpolation schedule (lerp) is proposed by flow matching and rectified flow.
    It leads to straighter probability flow theoretically. It is also used by Stable Diffusion 3.

        x_t = (1 - t / T) * x_0 + (t / T) * x_T

    With the default ``T = 1.0`` this reduces to ``x_t = (1 - t) * x_0 + t * x_T``.
    """

    def __init__(self, T: Union[int, float] = 1.0):
        """
        Args:
            T: Value of ``t`` at which ``forward`` returns ``x_T``; every
                timestep is divided by it before interpolating.
        """
        self.T = T

    @staticmethod
    def _expand_t(t: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """
        Append trailing singleton dims to ``t`` until it has as many dims as ``ref``.

        ``t`` is returned unchanged when it already has at least ``ref.ndim`` dims.
        """
        missing = ref.ndim - t.ndim
        if missing <= 0:
            return t
        # Index with (..., None, None, ...) so t broadcasts against ref's trailing dims.
        return t[(...,) + (None,) * missing]

    def forward(self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Diffusion forward function.

        Args:
            x_0: Sample at ``t = 0``.
            x_T: Sample at ``t = T``.
            t: Timesteps; trailing singleton dims are appended when it has
                fewer dims than ``x_0``.

        Returns:
            ``(1 - t / T) * x_0 + (t / T) * x_T``.
        """
        t = self._expand_t(t, x_0)
        ratio = t / self.T
        return (1 - ratio) * x_0 + ratio * x_T

    def convert_from_pred(
        self,
        pred: torch.Tensor,
        pred_type: Literal["velocity"],
        x_t: torch.Tensor,
        t: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert from velocity prediction. Return predicted x_0 and x_T.

        Args:
            pred: Predicted velocity, i.e. an estimate of ``x_T - x_0``.
            pred_type: Kind of prediction. Only ``'velocity'`` is implemented;
                the value is not inspected.
            x_t: Interpolated sample at timestep ``t``.
            t: Timesteps; trailing singleton dims are appended when it has
                fewer dims than ``x_t``.

        Returns:
            Tuple ``(pred_x_0, pred_x_T)``.
        """
        t = self._expand_t(t, x_t)
        B_t = t / self.T  # weight of x_T in x_t
        A_t = 1 - B_t     # weight of x_0 in x_t

        # pred_type = 'velocity'
        # With x_t = A_t * x_0 + B_t * x_T, A_t + B_t = 1 and pred = x_T - x_0:
        pred_x_0 = x_t - B_t * pred
        pred_x_T = x_t + A_t * pred
        return pred_x_0, pred_x_T

    def convert_to_pred(
        self,
        x_0: torch.Tensor,
        x_T: torch.Tensor,
        t: torch.Tensor,
        pred_type: Literal["velocity"],
    ) -> torch.Tensor:
        """
        Convert to velocity prediction target given x_0 and x_T.

        Predict velocity dx/dt based on the lerp schedule (x_T - x_0).
        Proposed by rectified flow (https://arxiv.org/abs/2209.03003)

        The returned target is ``x_T - x_0`` and is not divided by ``T``.

        Args:
            x_0: Sample at ``t = 0``.
            x_T: Sample at ``t = T``.
            t: Timesteps. Unused: the target does not depend on ``t``.
            pred_type: Kind of prediction. Only ``'velocity'`` is implemented;
                the value is not inspected.

        Returns:
            ``x_T - x_0``.
        """
        # pred_type = 'velocity'
        return x_T - x_0