LD3 / samplers /dpm_solver.py
vinhtong97's picture
Upload folder using huggingface_hub
d382778 verified
import torch
from samplers.general_solver import ODESolver
class DPM_Solver(ODESolver):
def __init__(
self,
noise_schedule,
algorithm_type="noise_prediction", # need to be noise prediction!
):
super().__init__(noise_schedule, algorithm_type)
self.noise_schedule = noise_schedule
def compute_K_and_order(self, steps, order):
assert order in [1, 2]
if order == 1:
K = steps
orders = [1,] * steps
elif order == 2:
if steps % 2 == 0:
K = steps // 2
orders = [2,] * K
else:
K = steps // 2 + 1
orders = [2,] * (K - 1) + [1]
return K, orders
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, flags, device):
'''
steps: NFEs
'''
# order == 1 means DDIM (DPM-Solver-1)
# order == 2 means DPM-Solver-2
K, orders = self.compute_K_and_order(steps, order)
timesteps_outer, timesteps_outer2, rs, rs2 = self.prepare_timesteps_single(steps=K, NFEs=steps, t_start=t_T, t_end=t_0, flags=flags, device=device, skip_type=skip_type)
return timesteps_outer, timesteps_outer2, rs, rs2, orders
def dpm_solver_first_update(self, x, s1, s2, t1, model_s=None):
ns = self.noise_schedule
lambda_s, lambda_t = ns.marginal_lambda(s1), ns.marginal_lambda(t1)
h = lambda_t - lambda_s
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t1)
sigma_t = ns.marginal_std(t1)
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s2) # noise prediction!
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- (sigma_t * phi_1) * model_s
)
return x_t
def dpm_solver_second_update(self, x, s1, s2, t1, r1=0.5, r2=0.5, model_s=None):
ns = self.noise_schedule
lambda_s, lambda_t = ns.marginal_lambda(s1), ns.marginal_lambda(t1)
h = lambda_t - lambda_s
lambda_s_inter1 = lambda_s + r1 * h
lambda_s_inter2 = lambda_s + r2 * h
s_inter1 = ns.inverse_lambda(lambda_s_inter1)
s_inter2 = ns.inverse_lambda(lambda_s_inter2)
log_alpha_s, log_alpha_s_inter, log_alpha_t = ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s_inter1), ns.marginal_log_mean_coeff(t1)
sigma_s_inter, sigma_t = ns.marginal_std(s_inter1), ns.marginal_std(t1)
phi_1_inter = torch.expm1(r1 * h)
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s2)
x_s_inter = (
torch.exp(log_alpha_s_inter - log_alpha_s) * x
- (sigma_s_inter * phi_1_inter) * model_s
)
model_s_inter = self.model_fn(x_s_inter, s_inter2)
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- (sigma_t * phi_1) * model_s
- (0.5 / r1) * (sigma_t * phi_1) * (model_s_inter - model_s)
)
return x_t
def singlestep_dpm_solver_update(self, x, s1, s2, t1, order, r1=0.5, r2=0.5, model_s=None):
if order == 1:
x_t = self.dpm_solver_first_update(x, s1, s2, t1, model_s=model_s)
elif order == 2:
x_t = self.dpm_solver_second_update(x, s1, s2, t1, r1, r2, model_s=model_s)
else:
raise ValueError("Order must be 1 or 2.")
return x_t
def sample(
self,
model_fn,
x,
steps=20,
t_start=None,
t_end=None,
order=2,
skip_type="time_uniform",
flags=None,
):
# check if order is 2
assert order == 2
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
timesteps, timesteps2, rs, rs2, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps, order, skip_type, t_T, t_0, flags, device)
with torch.no_grad():
return self.sample_simple(model_fn, x, orders, timesteps, timesteps2, rs, rs2)
def sample_simple(self, model_fn, x, timesteps, timesteps2, order=2, rs=None, rs2=None, condition=None, unconditional_condition=None, **kwargs):
'''
order is a list of order
'''
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])), condition, unconditional_condition)
if rs is None:
rs = [0.5,] * len(timesteps)
if rs2 is None:
rs2 = [0.5,] * len(timesteps)
orders = order
for step, od in enumerate(orders):
s1, t1 = timesteps[step], timesteps[step + 1]
s2 = timesteps2[step]
r1, r2 = rs[step], rs2[step]
x = self.singlestep_dpm_solver_update(x, s1, s2, t1, od, r1, r2)
return x