Spaces: Running on Zero
File size: 1,736 Bytes
709cfd2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
"""
Linear interpolation (lerp) schedule for flow-matching / rectified-flow
diffusion: x_t = (1 - t / T) * x_0 + (t / T) * x_T.
"""
from typing import Tuple, Union
import torch
from enum import Enum
class LinearSchedule:
    """
    Linear interpolation schedule (lerp), proposed by flow matching and
    rectified flow. It theoretically leads to straighter probability flow and
    is also used by Stable Diffusion 3.

    With normalised time ``s = t / T`` the forward process is

        x_t = (1 - s) * x_0 + s * x_T

    so ``x_t`` equals ``x_0`` at ``t = 0`` and ``x_T`` at ``t = T``.
    """

    def __init__(self, T: Union[int, float] = 1.0):
        """
        Args:
            T: Terminal time used to normalise timesteps (``s = t / T``), e.g.
                ``1.0`` for continuous time or ``1000`` for integer timesteps.

        Raises:
            ValueError: If ``T`` is not positive. Every method divides by
                ``T``, so ``T == 0`` would otherwise yield inf/nan silently.
        """
        if T <= 0:
            raise ValueError(f"T must be positive, got {T!r}.")
        self.T = T

    @staticmethod
    def _expand_t(t: Union[torch.Tensor, float], ndim: int) -> Union[torch.Tensor, float]:
        """Append trailing singleton dims to tensor ``t`` so it broadcasts
        against an ``ndim``-dimensional tensor (e.g. ``(B,)`` -> ``(B, 1, 1, 1)``).
        Python scalars are returned unchanged; they broadcast natively."""
        if isinstance(t, torch.Tensor) and t.ndim < ndim:
            # Ellipsis keeps t's existing dims; each None appends a size-1 dim.
            return t[(...,) + (None,) * (ndim - t.ndim)]
        return t

    def forward(
        self,
        x_0: torch.Tensor,
        x_T: torch.Tensor,
        t: Union[torch.Tensor, float],
    ) -> torch.Tensor:
        """
        Diffusion forward function: interpolate between ``x_0`` and ``x_T``.

        Args:
            x_0: Sample at ``t = 0``.
            x_T: Sample at ``t = T``.
            t: Timestep(s). Either a Python scalar, or a tensor whose dims align
                with the leading dims of ``x_0`` (e.g. shape ``(B,)``); trailing
                singleton dims are appended so it broadcasts per sample.

        Returns:
            ``(1 - t / T) * x_0 + (t / T) * x_T``.
        """
        s = self._expand_t(t, x_0.ndim) / self.T
        return (1 - s) * x_0 + s * x_T

    def convert_from_pred(
        self,
        pred: torch.Tensor,
        pred_type: Union[str, Enum],
        x_t: torch.Tensor,
        t: Union[torch.Tensor, float],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert a velocity prediction into predicted ``x_0`` and ``x_T``.

        Only velocity prediction (``v = x_T - x_0``) is implemented;
        ``pred_type`` is accepted for interface compatibility but is not
        inspected.

        Because ``x_t = A_t * x_0 + B_t * x_T`` with ``A_t = 1 - t / T``,
        ``B_t = t / T`` and ``A_t + B_t = 1``:

            x_0 = x_t - B_t * v
            x_T = x_t + A_t * v

        Args:
            pred: Predicted velocity ``v``.
            pred_type: Prediction type; velocity is assumed.
            x_t: Noisy sample at timestep ``t``.
            t: Timestep(s), broadcast as in :meth:`forward`.

        Returns:
            ``(pred_x_0, pred_x_T)``.
        """
        B_t = self._expand_t(t, x_t.ndim) / self.T
        A_t = 1 - B_t
        pred_x_0 = x_t - B_t * pred
        pred_x_T = x_t + A_t * pred
        return pred_x_0, pred_x_T

    def convert_to_pred(
        self,
        x_0: torch.Tensor,
        x_T: torch.Tensor,
        t: Union[torch.Tensor, float],
        pred_type: Union[str, Enum] = "velocity",
    ) -> torch.Tensor:
        """
        Build the velocity prediction target ``x_T - x_0`` for the lerp schedule.

        This is the derivative of ``x_t`` with respect to normalised time
        ``s = t / T`` (equivalently ``dx_t/dt = (x_T - x_0) / T``), as proposed
        by rectified flow (https://arxiv.org/abs/2209.03003). The target does
        not depend on ``t``; ``t`` and ``pred_type`` are accepted for interface
        compatibility and are not inspected (velocity is assumed).

        Returns:
            ``x_T - x_0``.
        """
        return x_T - x_0
|