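# DPM-Solver++ style multistep ODE sampler in data-prediction mode, built on
# the repo's ODESolver base class. Each update rescales x by sigma_t / sigma_s
# and corrects it with phi-weighted finite differences of cached data
# predictions, with step size h measured in lambda = log(alpha / sigma).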
import torch
from samplers.general_solver import ODESolver
from samplers.general_solver import update_lists
class DPM_SolverPP(ODESolver):
def __init__(
self,
noise_schedule,
algorithm_type="data_prediction",
):
super().__init__(noise_schedule, algorithm_type)
self.noise_schedule = noise_schedule
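    # First-order update from time s to t (DPM-Solver++-1, equivalent to DDIM):
    #   x_t = (sigma_t / sigma_s) * x_s - alpha_t * (exp(-h) - 1) * x0_pred(x_s, s),
    # with h = lambda_t - lambda_s.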
def dpm_solver_first_update(self, x, s, t, model_s=None):
ns = self.noise_schedule
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
        alpha_t = torch.exp(ns.marginal_log_mean_coeff(t))
        phi_1 = torch.expm1(-h)  # phi_1(h) = exp(-h) - 1
if model_s is None:
model_s = self.model_fn(x, s)
x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s
return x_t
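    # Second-order multistep update (2M): correct the first-order step with the
    # finite difference of the two most recent data predictions,
    #   D1_0 = (x0_pred_0 - x0_pred_1) / r0,  r0 = h_0 / h,
    # weighted by 0.5 * alpha_t * phi_1.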
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t):
ns = self.noise_schedule
model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
lambda_prev_1, lambda_prev_0, lambda_t = (
ns.marginal_lambda(t_prev_1),
ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t),
)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(ns.marginal_log_mean_coeff(t))
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0 = h_0 / h
D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
phi_1 = torch.expm1(-h)
x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0
return x_t
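    # Third-order multistep update: form first differences D1_0, D1_1 from the
    # three cached predictions, extrapolate them to D1 and the second
    # difference D2, and integrate with the weights phi_2 = phi_1 / h + 1 and
    # phi_3 = phi_2 / h - 0.5.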
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t):
ns = self.noise_schedule
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
ns.marginal_lambda(t_prev_2),
ns.marginal_lambda(t_prev_1),
ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t),
)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(ns.marginal_log_mean_coeff(t))
h_1 = lambda_prev_1 - lambda_prev_2
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0, r1 = h_0 / h, h_1 / h
D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
phi_1 = torch.expm1(-h)
phi_2 = phi_1 / h + 1.0
phi_3 = phi_2 / h - 0.5
x_t = (
(sigma_t / sigma_prev_0) * x
- (alpha_t * phi_1) * model_prev_0
+ (alpha_t * phi_2) * D1
- (alpha_t * phi_3) * D2
)
return x_t
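    # Dispatch a single multistep update at the requested solver order.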
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order):
if order == 1:
return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
elif order == 2:
return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t)
elif order == 3:
return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t)
else:
            raise ValueError(f"Solver order must be 1, 2, or 3, got {order}")
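    # Advance one step: update x to time t1 using the cached history at the
    # given order, then evaluate the model at t2 and push the result into the
    # history lists.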
    def one_step(self, t1, t2, t_prev_list, model_prev_list, step_order, x_next, order, first=True):
        x_next = self.multistep_dpm_solver_update(x_next, model_prev_list, t_prev_list, t1, step_order)
        model_x_next = self.model_fn(x_next, t2)
        update_lists(t_prev_list, model_prev_list, t1, model_x_next, order, first=first)
        return x_next
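    # Public entry point: wrap the model, build the timestep schedule between
    # t_start (defaults to T) and t_end (defaults to eps), and run the
    # multistep loop under torch.no_grad().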
def sample(
self,
model_fn,
x,
steps=20,
t_start=None,
t_end=None,
order=2,
skip_type="time_uniform",
lower_order_final=True,
flags=None,
):
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
t_0 = self.noise_schedule.eps if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
        load_from = flags.load_from if flags is not None else None
        timesteps, timesteps2 = self.prepare_timesteps(
            steps=steps, t_start=t_T, t_end=t_0, skip_type=skip_type, device=device, load_from=load_from
        )
with torch.no_grad():
            return self.sample_simple(model_fn, x, timesteps, timesteps2, order=order, lower_order_final=lower_order_final)
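    # Core loop: warm up at increasing orders until `order` model outputs are
    # cached, then step at the target order; with lower_order_final, the last
    # steps fall back to lower orders, which helps stability at small step
    # counts.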
def sample_simple(self, model_fn, x, timesteps, timesteps2, order=2, lower_order_final=True, condition=None, unconditional_condition=None, **kwargs):
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])), condition, unconditional_condition)
step = 0
t1 = timesteps[step]
t2 = timesteps2[step]
steps = len(timesteps) - 1
t_prev_list = [t1]
model_prev_list = [self.model_fn(x, t2)]
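        # Warm-up: the k-th warm-up step runs at order k, so each update only
        # uses model outputs that are already cached.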
for step in range(1, order):
t1 = timesteps[step]
t2 = timesteps2[step]
x = self.one_step(t1, t2, t_prev_list, model_prev_list, step, x, order, first=True)
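        # Main loop at the target order.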
for step in range(order, steps + 1):
t1 = timesteps[step]
t2 = timesteps2[step]
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
x = self.one_step(t1, t2, t_prev_list, model_prev_list, step_order, x, order, first=False)
return x
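# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original API surface). It assumes:
#   * `NoiseSchedule` lives in samplers.general_solver and exposes the members
#     used above (marginal_lambda, marginal_std, marginal_log_mean_coeff,
#     eps, T) -- the import path and default constructor are assumptions;
#   * `ODESolver.prepare_timesteps` accepts load_from=None;
#   * the model callable matches what sample_simple wraps:
#     model_fn(x, t, condition, unconditional_condition) -> data prediction.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from samplers.general_solver import NoiseSchedule  # hypothetical import

    def toy_model(x, t, condition=None, unconditional_condition=None):
        # Stand-in denoiser: a real network would predict x_0 from (x, t).
        return torch.zeros_like(x)

    solver = DPM_SolverPP(NoiseSchedule())  # assumed default constructor
    x_T = torch.randn(4, 3, 32, 32)
    x_0 = solver.sample(toy_model, x_T, steps=20, order=2)
    print(x_0.shape)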