|
|
import torch |
|
|
from abc import ABC, abstractmethod |
|
|
import os |
|
|
from noise_schedulers import NoiseScheduleVE, NoiseScheduleVP |
|
|
|
|
|
import numpy as np |
|
|
from scipy.optimize import minimize |
|
|
from scipy.optimize import LinearConstraint |
|
|
|
|
|
class StepOptim(object):
    """Optimise a sampling time-step schedule in log-SNR (lambda) space.

    Used by ``ODESolver.get_time_steps(skip_type="dmn")``. The interior lambdas
    of an (N + 1)-point schedule running from ``lambda(T)`` to ``lambda(eps)`` are
    chosen by minimising :meth:`sel_lambdas_lof_obj` with SciPy's
    ``trust-constr`` method, under linear constraints that keep them ordered
    between the two fixed end points.
    """

    def __init__(self, ns):
        """
        Args:
            ns: noise schedule object; must expose ``T``, ``marginal_alpha`` and
                ``inverse_lambda``.
        """
        super().__init__()
        self.ns = ns
        self.T = self.ns.T
        # True for VP schedules. The name suggests VP models are treated as
        # latent-space models; every other schedule uses the EDM
        # parameterisation sigma(t) = t (see edm_sigma / edm_lambda_func).
        self.is_latent_space = isinstance(self.ns, NoiseScheduleVP)

    def alpha(self, t):
        """Return ``ns.marginal_alpha(t)`` as a float64 NumPy value."""
        t = torch.as_tensor(t, dtype=torch.float64)
        return self.ns.marginal_alpha(t).numpy()

    def sigma(self, t):
        """Return ``sqrt(1 - alpha(t)**2)`` (VP marginal std)."""
        return np.sqrt(1 - self.alpha(t) * self.alpha(t))

    def lambda_func(self, t):
        """Return the VP log-SNR ``log(alpha(t) / sigma(t))``."""
        return np.log(self.alpha(t) / self.sigma(t))

    def edm_lambda_func(self, t):
        """Return the EDM log-SNR ``log(alpha(t) / edm_sigma(t))``."""
        return np.log(self.alpha(t) / self.edm_sigma(t))

    def H0(self, h):
        """Return ``exp(h) - 1``."""
        return np.exp(h) - 1

    def H1(self, h):
        """Return ``exp(h) * h - H0(h)``."""
        return np.exp(h) * h - self.H0(h)

    def H2(self, h):
        """Return ``exp(h) * h**2 - 2 * H1(h)``."""
        return np.exp(h) * h * h - 2 * self.H1(h)

    def H3(self, h):
        """Return ``exp(h) * h**3 - 3 * H2(h)``."""
        return np.exp(h) * h * h * h - 3 * self.H2(h)

    def inverse_lambda(self, lamb):
        """Map log-SNR values back to times via ``ns.inverse_lambda`` (float64 input)."""
        lamb = torch.as_tensor(lamb, dtype=torch.float64)
        return self.ns.inverse_lambda(lamb)

    def edm_sigma(self, t):
        """EDM noise level: sigma(t) = t."""
        return t

    def edm_inverse_sigma(self, edm_sigma):
        """Map EDM noise levels (a torch tensor) back to times through lambda = -log(sigma)."""
        alpha = 1 / (edm_sigma * edm_sigma + 1).sqrt()
        sigma = alpha * edm_sigma
        lambda_t = np.log(alpha / sigma)
        t = self.inverse_lambda(lambda_t)
        return t

    def sel_lambdas_lof_obj(self, lambda_vec, eps):
        """Objective minimised by :meth:`get_ts_lambdas`.

        Args:
            lambda_vec: the N - 1 interior log-SNR values (1-D array).
            eps: final time t_0; ``lambda(eps)`` and ``lambda(T)`` are appended
                as fixed end points.

        Returns:
            Scalar sum of weighted first-, second- and third-order step terms
            built from the ``H0``/``H1``/``H2`` helpers. Presumably a bound on
            the solver's local error; confirm against the method's derivation.
        """
        lambda_func = self.lambda_func if self.is_latent_space else self.edm_lambda_func
        lambda_eps, lambda_T = lambda_func(eps).item(), lambda_func(self.T).item()
        lambda_vec_ext = np.concatenate((np.array([lambda_T]), lambda_vec, np.array([lambda_eps])))
        N = len(lambda_vec_ext) - 1

        # Step sizes in lambda space: hv[i] = lambda[i + 1] - lambda[i].
        hv = np.diff(lambda_vec_ext)
        elv = np.exp(lambda_vec_ext)
        emlv_sq = np.exp(-2 * lambda_vec_ext)
        alpha_vec = 1. / np.sqrt(1 + emlv_sq)
        sigma_vec = 1. / np.sqrt(1 + np.exp(2 * lambda_vec_ext))
        if self.is_latent_space:
            data_err_vec = (sigma_vec ** 2) / alpha_vec
        else:
            # Non-VP (pixel-space/EDM) schedules use a first-power sigma weighting.
            data_err_vec = (sigma_vec ** 1) / alpha_vec

        # For short schedules the first three local terms are combined as a
        # root-sum-of-squares instead of being accumulated into c_vec.
        if N <= 7:
            truncNum = 3
        else:
            truncNum = 0

        res = 0.
        c_vec = np.zeros(N)
        for s in range(N):
            if s in [0, N - 1]:
                # First-order term at both ends of the schedule.
                n, kp = s, 1
                J_n_kp_0 = elv[n + 1] - elv[n]
                res += abs(J_n_kp_0 * data_err_vec[n])
            elif s in [1, N - 2]:
                # Second-order term next to the ends.
                n, kp = s - 1, 2
                J_n_kp_0 = -elv[n + 1] * self.H1(hv[n + 1]) / hv[n]
                J_n_kp_1 = elv[n + 1] * (self.H1(hv[n + 1]) + hv[n] * self.H0(hv[n + 1])) / hv[n]
                if s >= truncNum:
                    c_vec[n] += data_err_vec[n] * J_n_kp_0
                    c_vec[n + 1] += data_err_vec[n + 1] * J_n_kp_1
                else:
                    res += np.sqrt((data_err_vec[n] * J_n_kp_0) ** 2 + (data_err_vec[n + 1] * J_n_kp_1) ** 2)
            else:
                # Third-order term in the interior.
                n, kp = s - 2, 3
                J_n_kp_0 = elv[n + 2] * (self.H2(hv[n + 2]) + hv[n + 1] * self.H1(hv[n + 2])) / (hv[n] * (hv[n] + hv[n + 1]))
                J_n_kp_1 = -elv[n + 2] * (self.H2(hv[n + 2]) + (hv[n] + hv[n + 1]) * self.H1(hv[n + 2])) / (hv[n] * hv[n + 1])
                J_n_kp_2 = elv[n + 2] * (self.H2(hv[n + 2]) + (2 * hv[n + 1] + hv[n]) * self.H1(hv[n + 2]) + hv[n + 1] * (hv[n] + hv[n + 1]) * self.H0(hv[n + 2])) / (hv[n + 1] * (hv[n] + hv[n + 1]))
                if s >= truncNum:
                    c_vec[n] += data_err_vec[n] * J_n_kp_0
                    c_vec[n + 1] += data_err_vec[n + 1] * J_n_kp_1
                    c_vec[n + 2] += data_err_vec[n + 2] * J_n_kp_2
                else:
                    res += np.sqrt((data_err_vec[n] * J_n_kp_0) ** 2 + (data_err_vec[n + 1] * J_n_kp_1) ** 2 + (data_err_vec[n + 2] * J_n_kp_2) ** 2)
        # Python's left-to-right sum is kept deliberately to reproduce the original rounding.
        res += sum(abs(c_vec))
        return res

    def get_ts_lambdas(self, N, eps, init_type=None):
        """Compute an (N + 1)-point schedule from ``T`` down to ``eps``.

        Args:
            N: number of solver steps (must be >= 1).
            eps: final time t_0.
            init_type: initial schedule. ``None`` (default) keeps the previous
                behaviour: ``"unif_t"`` for VP schedules, ``"unif"`` otherwise.
                Accepted values are ``"unif"``, ``"unif_t"``, ``"edm"`` and
                ``"quad"``. Each may carry an ``"_origin"`` suffix, which returns
                the initial schedule without optimisation.

        Returns:
            ``(t_res, lambda_res)``: times and log-SNR values as torch tensors.

        Raises:
            ValueError: if ``N < 1`` or ``init_type`` is unknown.
        """
        if N < 1:
            raise ValueError("N must be a positive integer, got {}".format(N))
        if init_type is None:
            init_type = "unif_t" if self.is_latent_space else "unif"

        # Same lambda parameterisation as the objective, so the initial grid
        # and the optimised variables live in the same space.
        lambda_func = self.lambda_func if self.is_latent_space else self.edm_lambda_func
        lambda_eps, lambda_T = lambda_func(eps).item(), lambda_func(self.T).item()

        # Initial N + 1 point lambda grid, from lambda_T to lambda_eps.
        if init_type in ('unif', 'unif_origin'):
            lambda_vec_ext = torch.linspace(lambda_T, lambda_eps, N + 1)
        elif init_type in ('unif_t', 'unif_t_origin'):
            t_vec = torch.linspace(self.T, eps, N + 1)
            lambda_vec_ext = lambda_func(t_vec)
        elif init_type in ('edm', 'edm_origin'):
            rho = 7
            # float() accepts Python floats as well as 0-d arrays/tensors.
            edm_sigma_min, edm_sigma_max = float(self.edm_sigma(eps)), float(self.edm_sigma(self.T))
            edm_sigma_vec = torch.linspace(edm_sigma_max ** (1. / rho), edm_sigma_min ** (1. / rho), N + 1).pow(rho)
            t_vec = self.edm_inverse_sigma(edm_sigma_vec)
            lambda_vec_ext = lambda_func(t_vec)
        elif init_type in ('quad', 'quad_origin'):
            t_order = 2
            t_vec = torch.linspace(self.T ** (1. / t_order), eps ** (1. / t_order), N + 1).pow(t_order)
            lambda_vec_ext = lambda_func(t_vec)
        else:
            raise ValueError("InitType not found: {!r}".format(init_type))

        if init_type.endswith('_origin'):
            # Baseline schedule: no optimisation.
            lambda_res = torch.as_tensor(lambda_vec_ext)
            t_res = torch.as_tensor(self.inverse_lambda(lambda_res))
            return t_res, lambda_res

        if N < 2:
            # No interior points to optimise: the schedule is just [T, eps].
            interior = np.zeros(0)
        else:
            # Constraint rows on the N - 1 interior lambdas x:
            #   row 0:          x[0] >= lambda_T
            #   rows 1..N-2:    x[i] - x[i - 1] >= 0   (non-decreasing)
            #   row N-1:       -x[N-2] >= -lambda_eps  (x[N-2] <= lambda_eps)
            constr_mat = np.zeros((N, N - 1))
            cols = np.arange(N - 1)
            constr_mat[cols, cols] = 1.
            constr_mat[cols + 1, cols] = -1.
            lb_vec = np.zeros(N)
            lb_vec[0], lb_vec[-1] = lambda_T, -lambda_eps
            ub_vec = np.full(N, np.inf)
            linear_constraint = LinearConstraint(constr_mat, lb_vec, ub_vec)

            lambda_vec_init = np.array(lambda_vec_ext[1:-1])
            res = minimize(self.sel_lambdas_lof_obj, lambda_vec_init, method='trust-constr',
                           args=(eps,), constraints=[linear_constraint], options={'verbose': 1})
            interior = res.x

        lambda_res = torch.tensor(np.concatenate((np.array([lambda_T]), interior, np.array([lambda_eps]))))
        # as_tensor avoids the copy-construct warning of torch.tensor(<tensor>).
        t_res = torch.as_tensor(self.inverse_lambda(lambda_res))
        return t_res, lambda_res
|
|
|
|
|
|
|
|
|
|
|
def expand_dims(x, dims):
    """Return ``x`` with ``dims`` trailing singleton axes appended.

    Each axis is added with ``unsqueeze(-1)``. When ``dims`` is zero or
    negative, the very same object is returned unchanged.
    """
    added = 0
    while added < dims:
        x = x.unsqueeze(-1)
        added += 1
    return x
|
|
|
|
|
def update_lists(t_list, model_list, t_, model_x, order, first=False):
    """Record a new (time, model output) pair in the multistep history.

    With ``first=True`` the pair is appended to both lists (history still
    filling up). Otherwise, the first ``order - 1`` slots are each shifted one
    position to the left, and the pair overwrites the last slot. Both lists
    are modified in place; nothing is returned.
    """
    if not first:
        for dst in range(order - 1):
            src = dst + 1
            t_list[dst] = t_list[src]
            model_list[dst] = model_list[src]
        t_list[-1] = t_
        model_list[-1] = model_x
        return
    t_list.append(t_)
    model_list.append(model_x)
|
|
|
|
|
class ODESolver(ABC):
    """Base class for the diffusion ODE samplers in this module.

    Stores the noise schedule and the prediction mode, converts the model
    output into noise or data (x0) predictions, and builds time-step schedules.
    Subclasses must implement :meth:`sample_simple`.

    NOTE(review): ``noise_prediction_fn`` calls ``self.model``, which this
    class never sets; subclasses are expected to provide it.
    """

    def __init__(
        self,
        noise_schedule,
        algorithm_type="data_prediction",
        correcting_x0_fn=None,
    ):
        """
        Args:
            noise_schedule: schedule object (``marginal_alpha``, ``marginal_std``,
                ``marginal_lambda``, ``inverse_lambda``, ... are used below).
            algorithm_type: ``"data_prediction"`` or ``"noise_prediction"``;
                selects what :meth:`model_fn` returns.
            correcting_x0_fn: optional callable applied to every x0 prediction.
        """
        self.noise_schedule = noise_schedule
        assert algorithm_type in ["data_prediction", "noise_prediction"]
        self.predict_x0 = algorithm_type == "data_prediction"
        self.correcting_x0_fn = correcting_x0_fn

    def dx_dt_for_blackbox_solvers(self, x, t1, t2):
        '''
        for edm, dx_dt = noise

        Return ``f(t1) * x + g(t1)**2 / (2 * sigma(t1)) * eps(x, t2)``: the
        schedule coefficients are evaluated at ``t1``, the noise prediction at
        ``t2`` (the probability-flow ODE drift when eps approximates
        -sigma * score).
        '''
        ft = self.noise_schedule.ft(t1)
        gt = self.noise_schedule.gt(t1)
        sigma_t = self.noise_schedule.marginal_std(t1)
        noise = self.noise_prediction_fn(x, t2)
        return ft * x + gt ** 2 / (2 * sigma_t) * noise

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction model (with corrector).

        Computes ``x0 = (x - sigma_t * noise) / alpha_t`` and then applies
        ``correcting_x0_fn`` when one was given.
        """
        noise = self.noise_prediction_fn(x, t)
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        x0 = (x - sigma_t * noise) / alpha_t
        if self.correcting_x0_fn is not None:
            x0 = self.correcting_x0_fn(x0)
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.

        Args:
            skip_type: ``'logSNR'`` (uniform in log-SNR), ``'time_uniform'``,
                ``'time_quadratic'``, ``'edm'`` (rho = 7), ``'poly_<rho>'`` or
                ``'dmn'`` (optimised with :class:`StepOptim`).
            t_T: start time.
            t_0: end time.
            N: number of steps; N + 1 times are returned.
            device: torch device for the result.

        Raises:
            ValueError: for an unsupported ``skip_type``.
        """
        if skip_type == 'logSNR':
            # as_tensor avoids copy-constructing (and warning) when t_T / t_0 are tensors.
            lambda_T = self.noise_schedule.marginal_lambda(torch.as_tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.as_tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            t = self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 'time_uniform':
            t = torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            t = self.get_time_step_poly(t_T, t_0, N, device, 2.0)
        elif skip_type == "edm":
            t = self.get_time_step_edm(t_T, t_0, N, device, 7.0)
        elif "poly" in skip_type:
            rho = float(skip_type.split("_")[-1])
            t = self.get_time_step_poly(t_T, t_0, N, device, rho)
        elif skip_type == "dmn":
            optimizer = StepOptim(self.noise_schedule)
            t, _ = optimizer.get_ts_lambdas(N, t_0)
            t = t.to(device).to(torch.float32)
            print(t)
        else:
            raise ValueError(
                "Unsupported skip_type {}, need to be one of 'logSNR', 'time_uniform', "
                "'time_quadratic', 'edm', 'poly_<rho>' or 'dmn'".format(skip_type)
            )
        return t

    def append_zero(self, x):
        """Append a single 0 to the 1-D tensor ``x``."""
        return torch.cat([x, x.new_zeros([1])])

    def get_time_step_poly(self, t_T, t_0, N, device, rho=2.0):
        """Return N + 1 times from ``t_T`` to ``t_0``; entry j is
        ``t_0 + (t_T - t_0) * ((N - j) / N) ** rho``."""
        mono_sequence = torch.arange(0, N + 1).pow(rho).to(device)
        sequence_min = mono_sequence.min()
        sequence_max = mono_sequence.max()
        t_max = t_T
        t_min = t_0
        ts = t_min + (t_max - t_min) * (mono_sequence - sequence_min) / (sequence_max - sequence_min)
        return ts.flip(0)

    def get_time_step_edm_t(self, t_T, t_0, N, device, rho=7.0):
        """EDM rho-schedule applied directly to time: interpolate ``t**(1/rho)``
        linearly from ``t_T`` to ``t_0`` and raise the result to ``rho``."""
        t_min: float = t_0
        t_max: float = t_T
        ramp = torch.linspace(0, 1, N + 1).to(device)
        min_inv_rho = t_min ** (1 / rho)
        max_inv_rho = t_max ** (1 / rho)
        ts = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return ts

    def get_time_step_edm(self, t_T, t_0, N, device, rho=7.0):
        """EDM rho-schedule on noise levels.

        For ``NoiseScheduleVE`` the end points are converted to sigmas with
        ``marginal_std`` and the schedule is mapped back with ``inverse_std``.
        For any other schedule, time is used directly as sigma.
        """
        if isinstance(self.noise_schedule, NoiseScheduleVE):
            sigma_min = self.noise_schedule.marginal_std(t_0).to(device)
            sigma_max = self.noise_schedule.marginal_std(t_T).to(device)
        else:
            sigma_min = t_0
            sigma_max = t_T

        ramp = torch.linspace(0, 1, N + 1).to(device)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        if isinstance(self.noise_schedule, NoiseScheduleVE):
            ts = self.noise_schedule.inverse_std(sigmas)
        else:
            ts = sigmas
        return ts

    def prepare_learn_timesteps(self, load_from, load_rs=False, device=None):
        """Load learned schedules from a ``torch.save`` checkpoint.

        The checkpoint's ``'best_t_steps'`` tensor is split in half: the first
        half is ``timesteps``, the second half ``timesteps2``. With
        ``load_rs=True``, ``'best_rs'`` is split the same way. If that entry is
        missing (or is not a tensor), ``rs = rs2 = [0.5] * length``.

        Returns:
            ``(timesteps, timesteps2)``, or ``(timesteps, timesteps2, rs, rs2)``
            when ``load_rs`` is True.
        """
        # Load once; the previous code deserialised the file a second time for rs.
        checkpoint = torch.load(load_from)
        timesteps = checkpoint['best_t_steps'].to(device)

        length = timesteps.shape[0] // 2
        timesteps2 = timesteps[length:]
        timesteps = timesteps[:length]

        if not load_rs:
            return timesteps, timesteps2

        try:
            rs = checkpoint['best_rs'].to(device)
        except (KeyError, AttributeError):
            # Best-effort fallback when the checkpoint stores no usable 'best_rs'.
            rs = [0.5] * length
            rs2 = rs
        else:
            rs2 = rs[length:]
            rs = rs[:length]
        return timesteps, timesteps2, rs, rs2

    def prepare_timesteps(self, steps=None, t_start=None, t_end=None, skip_type=None, device=None, load_from=None):
        """Return ``(timesteps, timesteps2)``. They are loaded from ``load_from``
        when it is an existing file; otherwise both are the same schedule from
        :meth:`get_time_steps`."""
        if load_from is not None and os.path.isfile(load_from):
            timesteps, timesteps2 = self.prepare_learn_timesteps(load_from=load_from, device=device)
        else:
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_start, t_0=t_end, N=steps, device=device)
            timesteps2 = timesteps
        return timesteps, timesteps2

    def prepare_timesteps_single(self, steps, NFEs, t_start, t_end, flags, device, skip_type='time_uniform'):
        """Return ``(timesteps, timesteps2, rs, rs2)``.

        When ``flags.learn`` is true, the values are loaded from
        ``flags.load_from``. Otherwise a ``skip_type`` schedule is used with
        ``rs = [0.5] * steps``. ``NFEs`` is not used here.
        """
        if flags.learn:
            timesteps, timesteps2, rs, rs2 = self.prepare_learn_timesteps(load_from=flags.load_from, load_rs=True, device=device)
        else:
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_start, t_0=t_end, N=steps, device=device)
            timesteps2 = timesteps
            rs = [0.5] * steps
            rs2 = rs
        return timesteps, timesteps2, rs, rs2

    def sample(self, *args, **kwargs):
        """Sampling entry point; a no-op here, overridden by subclasses."""
        pass

    @abstractmethod
    def sample_simple(self, model_fn, x, timesteps, timesteps2=None, condition=None, unconditional_condition=None, **kwargs):
        """Run the sampler over ``timesteps`` starting from ``x``."""
        pass

    def dynamic_thresholding_fn(self, x0, t):
        """
        The dynamic thresholding method.(not used by anything so far)

        Clamps each sample of ``x0`` to +/- s and rescales by 1/s, where s is
        the ``dynamic_thresholding_ratio`` quantile of |x0| per sample, floored
        at ``thresholding_max_val``. NOTE(review): neither attribute is set in
        this class; callers must provide them.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        # s has shape (batch,): append dims - 1 singleton axes so it has the
        # same rank as x0. Appending `dims` axes made it broadcast into an
        # extra batch dimension.
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims - 1)
        x0 = torch.clamp(x0, -s, s) / s
        return x0
|
|
|
|
|
|
|
|
class GUIDEDSolver(ODESolver):
    """Abstract solver with a forward sampler and a matching backward pass.

    Subclasses implement :meth:`forward_sample_simple` (which can return every
    intermediate state), :meth:`backward_sample_simple` (which propagates a
    gradient back through those states) and :meth:`sample`.
    """

    def __init__(
        self,
        noise_schedule,
        algorithm_type="data_prediction",
        correcting_x0_fn=None,
    ):
        # ODESolver.__init__ already stores noise_schedule, validates
        # algorithm_type and sets predict_x0 / correcting_x0_fn; repeating
        # those assignments here was redundant.
        super().__init__(noise_schedule, algorithm_type, correcting_x0_fn)

    @abstractmethod
    def forward_sample_simple(self, latent, timesteps, timesteps2=None, return_image_list=False, **kwargs):
        """Integrate from ``latent`` over ``timesteps``; optionally return all states."""
        pass

    @abstractmethod
    def backward_sample_simple(self, image_list, grad, timesteps=None, timesteps2=None, dis_model=None, **kwargs):
        """Propagate ``grad`` back through the states in ``image_list``."""
        pass

    @abstractmethod
    def sample(self, x, steps, t_start, t_end, order, skip_type, flags):
        """High-level sampling entry point."""
        pass
|
|
|
|
|
|
|
|
class MultiStepODESolver(GUIDEDSolver):
    """Multistep ODE solver with a lower-order warm-up and an autograd backward pass.

    Subclasses implement :meth:`_one_step` (a single solver update) and
    :meth:`sample_simple`.
    """

    def __init__(self, model_fn, noise_schedule, algorithm_type="data_prediction", correcting_x0_fn=None):
        '''
        algorithm_type needs to be data_prediction

        Args:
            model_fn: callable ``model_fn(x, t)``. It is stored as ``self.model``,
                which ``noise_prediction_fn`` calls.
            noise_schedule: the diffusion noise schedule.
            algorithm_type: ``"data_prediction"`` (see note above).
            correcting_x0_fn: optional corrector applied to x0 predictions.
        '''
        # GUIDEDSolver.__init__ takes (noise_schedule, algorithm_type,
        # correcting_x0_fn). Passing model_fn first shifted every argument and
        # made the algorithm_type assertion fail, and self.model was never set.
        super().__init__(noise_schedule, algorithm_type, correcting_x0_fn)
        self.model = model_fn

    @abstractmethod
    def _one_step(self, t1, t2, t_prev_list, model_prev_list, step, x_next, order=None, update_list=False, first=True):
        """Perform one solver update to time ``t1`` (model queried at ``t2``).

        ``step`` is the order used for this update (the warm-up index or the
        clipped order). ``t_prev_list`` / ``model_prev_list`` hold the history.
        With ``update_list=True`` the implementation is expected to append to
        or roll that history; presumably via ``update_lists``, so confirm with
        subclasses.
        """
        pass

    def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', flags=None):
        """Sample from ``x`` without gradients and return the final state.

        ``flags`` is optional; only its ``load_from`` attribute (if present) is
        read, to load a learned schedule.
        """
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        # Default flags=None previously crashed on flags.load_from.
        load_from = getattr(flags, 'load_from', None)
        timesteps, timesteps2 = self.prepare_timesteps(steps=steps, t_start=t_T, t_end=t_0, skip_type=skip_type, device=device, load_from=load_from)

        with torch.no_grad():
            return self.forward_sample_simple(x, timesteps, timesteps2, order=order, return_image_list=False)

    def forward_sample_simple(self, latent, timesteps, timesteps2=None, return_image_list=False, **kwargs):
        """Integrate from ``latent`` along ``timesteps`` with gradients disabled.

        Args:
            latent: starting state at ``timesteps[0]``.
            timesteps: times passed to ``_one_step`` as ``t1``.
            timesteps2: times at which the model is queried (``t2``); defaults
                to ``timesteps``.
            return_image_list: if True, return all ``len(timesteps)`` states,
                starting with a clone of ``latent``.
            order: required keyword, the maximum multistep order.

        Returns:
            The final state, or the list of states.
        """
        assert 'order' in kwargs
        order = kwargs['order']
        if timesteps2 is None:
            timesteps2 = timesteps
        numsteps = len(timesteps) - 1
        image_list = [] if return_image_list else None

        with torch.no_grad():
            t_prev_list_student = [timesteps[0]]
            x_next_ = latent.clone()
            denoised_T = self.model_fn(x_next_, timesteps2[0])
            model_prev_list_student = [denoised_T]
            if return_image_list:
                image_list.append(x_next_)

            # Warm-up: step k uses order k while the history fills up. Clamped to
            # the schedule length so short schedules do not index past the end.
            for step in range(1, min(order, numsteps + 1)):
                t1 = timesteps[step]
                t2 = timesteps2[step]
                x_next_ = self._one_step(t1, t2, t_prev_list_student, model_prev_list_student, step, x_next_, order, update_list=True, first=True)
                if return_image_list:
                    image_list.append(x_next_)

            # Main phase: full order, lowered near the end of the schedule.
            for step in range(order, numsteps + 1):
                t1 = timesteps[step]
                t2 = timesteps2[step]
                step_order = min(order, numsteps + 1 - step)
                x_next_ = self._one_step(t1, t2, t_prev_list_student, model_prev_list_student, step_order, x_next_, order, update_list=True, first=False)
                if return_image_list:
                    image_list.append(x_next_)

        if return_image_list:
            return image_list
        return x_next_

    def backward_sample_simple(self, image_list, grad, timesteps=None, timesteps2=None, dis_model=None, **kwargs):
        """Backpropagate ``grad`` (w.r.t. the last state) to the first state.

        ``image_list`` is the state list produced by
        ``forward_sample_simple(..., return_image_list=True)``. Each step is
        re-run with autograd enabled, from the last step to the first. Every
        state is marked ``requires_grad`` / ``retain_grad``, so contributions
        from all later steps that use a state (through the model history)
        accumulate in its ``.grad`` before that state's own step is processed.
        If ``dis_model`` is given, it is called before every step to supply
        ``(timesteps, timesteps2)``.

        Returns:
            The gradient w.r.t. ``image_list[0]``.
        """
        assert 'order' in kwargs
        order = kwargs['order']
        assert timesteps is None or len(timesteps) == len(image_list)
        numsteps = len(image_list) - 1

        for ele in image_list:
            ele.requires_grad = True
            ele.retain_grad()

        # Main phase, latest step first (mirrors the forward main loop).
        for step in range(numsteps, order - 1, -1):
            timesteps, timesteps2 = self._resolve_backward_timesteps(timesteps, timesteps2, dis_model)
            step_order = min(order, numsteps + 1 - step)
            grad = self._backward_one_step(step, order, step_order, image_list, grad, timesteps, timesteps2)

        # Warm-up phase (mirrors the forward warm-up, including its clamp).
        for step in range(min(order - 1, numsteps), 0, -1):
            timesteps, timesteps2 = self._resolve_backward_timesteps(timesteps, timesteps2, dis_model)
            grad = self._backward_one_step(step, step, step, image_list, grad, timesteps, timesteps2)

        return grad

    @staticmethod
    def _resolve_backward_timesteps(timesteps, timesteps2, dis_model):
        """Return ``(timesteps, timesteps2)`` for the next backward step."""
        if dis_model is not None:
            return dis_model()
        return timesteps, (timesteps2 if timesteps2 is not None else timesteps)

    def _backward_one_step(self, step, n_prev, step_order, image_list, grad, timesteps, timesteps2):
        """Re-run forward step ``step`` with autograd and backpropagate ``grad``.

        The history is rebuilt from the ``n_prev`` states before ``step``.
        Returns the gradient accumulated on ``image_list[step - 1]``.
        """
        t1 = timesteps[step]
        t2 = timesteps2[step]
        history = range(step - n_prev, step)
        t_prev_list_student = [timesteps[j] for j in history]
        t_prev_list_student2 = [timesteps2[j] for j in history]
        prev_images = [image_list[j] for j in history]
        model_prev_list_student = [self.model_fn(img, t) for img, t in zip(prev_images, t_prev_list_student2)]

        x_next_input = image_list[step - 1]
        x_next_ = self._one_step(t1, t2, t_prev_list_student, model_prev_list_student, step_order, x_next_input, update_list=False)
        x_next_.backward(grad, retain_graph=False)
        return x_next_input.grad.detach()
|
|
|
|
|
|