import torch

from samplers.general_solver import ODESolver
from samplers.general_solver import update_lists


class DPM_SolverPP(ODESolver):
    """Multistep DPM-Solver++ sampler supporting update orders 1, 2 and 3.

    All update rules below use the data-prediction form: the new state is
    ``sigma_t / sigma_prev * x`` minus ``alpha_t``-scaled combinations of
    ``expm1(-h)``-based coefficients and model outputs, where ``h`` is the
    step size measured in the ``lambda`` returned by
    ``noise_schedule.marginal_lambda``.
    """

    def __init__(
        self,
        noise_schedule,
        algorithm_type="data_prediction",
    ):
        super().__init__(noise_schedule, algorithm_type)
        self.noise_schedule = noise_schedule

    def dpm_solver_first_update(self, x, s, t, model_s=None):
        """First-order update from time ``s`` to time ``t``.

        Args:
            x: State at time ``s``.
            s: Start time.
            t: End time.
            model_s: Model output at ``(x, s)``; evaluated with
                ``self.model_fn`` when not supplied.

        Returns:
            The state at time ``t``.
        """
        ns = self.noise_schedule
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
        alpha_t = torch.exp(ns.marginal_log_mean_coeff(t))

        # expm1 keeps precision when h is small (exp(-h) - 1 would cancel).
        phi_1 = torch.expm1(-h)
        if model_s is None:
            model_s = self.model_fn(x, s)
        return sigma_t / sigma_s * x - alpha_t * phi_1 * model_s

    def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t):
        """Second-order multistep update to time ``t``.

        Uses the last two entries of the histories.

        Args:
            x: State at time ``t_prev_list[-1]``.
            model_prev_list: History of model outputs; the last two are used.
            t_prev_list: Times matching ``model_prev_list``.
            t: End time.

        Returns:
            The state at time ``t``.
        """
        ns = self.noise_schedule
        model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
        t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
        lambda_prev_1 = ns.marginal_lambda(t_prev_1)
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(ns.marginal_log_mean_coeff(t))

        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0 = h_0 / h
        # Finite-difference estimate of the first derivative, rescaled to step h.
        D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
        phi_1 = torch.expm1(-h)
        return (
            (sigma_t / sigma_prev_0) * x
            - (alpha_t * phi_1) * model_prev_0
            - 0.5 * (alpha_t * phi_1) * D1_0
        )

    def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t):
        """Third-order multistep update to time ``t``.

        Uses the last three entries of the histories.

        Args:
            x: State at time ``t_prev_list[-1]``.
            model_prev_list: History of model outputs; the last three are used.
            t_prev_list: Times matching ``model_prev_list``.
            t: End time.

        Returns:
            The state at time ``t``.
        """
        ns = self.noise_schedule
        # Index from the end (like the second-order update) so longer histories also work.
        model_prev_2, model_prev_1, model_prev_0 = (
            model_prev_list[-3],
            model_prev_list[-2],
            model_prev_list[-1],
        )
        t_prev_2, t_prev_1, t_prev_0 = t_prev_list[-3], t_prev_list[-2], t_prev_list[-1]
        lambda_prev_2 = ns.marginal_lambda(t_prev_2)
        lambda_prev_1 = ns.marginal_lambda(t_prev_1)
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(ns.marginal_log_mean_coeff(t))

        h_1 = lambda_prev_1 - lambda_prev_2
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0, r1 = h_0 / h, h_1 / h
        # First- and second-derivative estimates from the three stored outputs.
        D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
        D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
        phi_1 = torch.expm1(-h)
        phi_2 = phi_1 / h + 1.0
        phi_3 = phi_2 / h - 0.5
        return (
            (sigma_t / sigma_prev_0) * x
            - (alpha_t * phi_1) * model_prev_0
            + (alpha_t * phi_2) * D1
            - (alpha_t * phi_3) * D2
        )

    def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order):
        """Dispatch to the update rule of the requested ``order`` (1, 2 or 3).

        Raises:
            ValueError: If ``order`` is not 1, 2 or 3.
        """
        if order == 1:
            return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
        elif order == 2:
            return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t)
        elif order == 3:
            return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t)
        else:
            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))

    def one_step(self, t1, t2, t_prev_list, model_prev_list, step, x_next, order, first=True):
        """Advance the state to ``t1`` and record the new model evaluation.

        Args:
            t1: Target time of the solver update; recorded in the history.
            t2: Time at which the model is evaluated on the new state
                (taken from the second timestep schedule).
            t_prev_list: Time history. Presumably mutated in place by
                ``update_lists``; its return value is ignored. TODO confirm.
            model_prev_list: Model-output history, handled like ``t_prev_list``.
            step: Order used for THIS update (passed as ``order`` to
                ``multistep_dpm_solver_update``).
            x_next: Current state.
            order: Maximum solver order, forwarded to ``update_lists``.
            first: Forwarded to ``update_lists``.

        Returns:
            The updated state.
        """
        x_next = self.multistep_dpm_solver_update(x_next, model_prev_list, t_prev_list, t1, step)
        model_x_next = self.model_fn(x_next, t2)
        update_lists(t_prev_list, model_prev_list, t1, model_x_next, order, first=first)
        return x_next

    def sample(
        self,
        model_fn,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=2,
        skip_type="time_uniform",
        lower_order_final=True,
        flags=None,
    ):
        """Sample from ``x`` over ``steps`` steps without tracking gradients.

        Args:
            model_fn: Model callable; it is wrapped by ``sample_simple``.
            x: Initial state.
            steps: Number of steps passed to ``prepare_timesteps``.
            t_start: Start time; defaults to ``noise_schedule.T``.
            t_end: End time; defaults to ``noise_schedule.eps``.
            order: Maximum solver order.
            skip_type: Timestep spacing passed to ``prepare_timesteps``.
            lower_order_final: Whether to lower the order on the final steps.
            flags: Optional object whose ``load_from`` attribute is forwarded
                to ``prepare_timesteps``.

        Returns:
            The final state.
        """
        # NOTE(review): sample_simple rebinds self.model before any model call
        # visible here; kept in case the base class uses it in the meantime.
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
        t_0 = self.noise_schedule.eps if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        # flags defaults to None. Only forward load_from when flags is supplied,
        # so prepare_timesteps uses its own default otherwise.
        # (Assumes prepare_timesteps defines a default for load_from — TODO confirm.)
        extra_kwargs = {} if flags is None else {"load_from": flags.load_from}
        timesteps, timesteps2 = self.prepare_timesteps(
            steps=steps,
            t_start=t_T,
            t_end=t_0,
            skip_type=skip_type,
            device=device,
            **extra_kwargs,
        )
        with torch.no_grad():
            # Pass order/lower_order_final by keyword. sample_simple expects
            # (model_fn, x, timesteps, timesteps2, order, lower_order_final, ...).
            return self.sample_simple(
                model_fn,
                x,
                timesteps,
                timesteps2,
                order=order,
                lower_order_final=lower_order_final,
            )

    def sample_simple(
        self,
        model_fn,
        x,
        timesteps,
        timesteps2,
        order=2,
        lower_order_final=True,
        condition=None,
        unconditional_condition=None,
        **kwargs,
    ):
        """Run the multistep solver over precomputed timestep schedules.

        Args:
            model_fn: Called as ``model_fn(x, t_batch, condition, unconditional_condition)``.
            x: Initial state.
            timesteps: Times the solver steps to. ``len(timesteps) - 1`` steps
                are taken.
            timesteps2: Times at which the model is evaluated, indexed in
                parallel with ``timesteps``.
            order: Maximum solver order.
            lower_order_final: If true, the order is capped by the number of
                remaining steps.
            condition: Forwarded to ``model_fn``.
            unconditional_condition: Forwarded to ``model_fn``.
            **kwargs: Accepted and ignored.

        Returns:
            The final state.
        """
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])), condition, unconditional_condition)
        steps = len(timesteps) - 1
        t_prev_list = [timesteps[0]]
        model_prev_list = [self.model_fn(x, timesteps2[0])]

        # Warm-up: not enough history yet, so step k uses order k.
        for step in range(1, order):
            x = self.one_step(
                timesteps[step], timesteps2[step], t_prev_list, model_prev_list, step, x, order, first=True
            )

        for step in range(order, steps + 1):
            step_order = min(order, steps + 1 - step) if lower_order_final else order
            x = self.one_step(
                timesteps[step], timesteps2[step], t_prev_list, model_prev_list, step_order, x, order, first=False
            )
        return x