File size: 21,979 Bytes
d382778 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 |
import torch
from abc import ABC, abstractmethod
import os
from noise_schedulers import NoiseScheduleVE, NoiseScheduleVP
import numpy as np
from scipy.optimize import minimize
from scipy.optimize import LinearConstraint
class StepOptim(object):
    """Optimize the time-step discretization used for diffusion ODE sampling.

    Given a noise schedule ``ns`` (which must expose ``T``, ``marginal_alpha`` and
    ``inverse_lambda``), :meth:`get_ts_lambdas` picks ``N + 1`` log-SNR values
    between lambda(T) and lambda(eps). By default it minimizes
    :meth:`sel_lambdas_lof_obj` with scipy's ``trust-constr`` method, then maps
    the values back to times with the schedule's ``inverse_lambda``.
    """

    def __init__(self, ns):
        super().__init__()
        self.ns = ns
        # t_T of diffusion sampling: for VP models T=1.0; for EDM models T=80.0.
        self.T = self.ns.T
        # VP schedules take the latent-space path. Any other schedule takes the
        # pixel-space (EDM) path, where lambda(t) = log(alpha(t) / t).
        self.is_latent_space = isinstance(self.ns, NoiseScheduleVP)

    def alpha(self, t):
        """Return ``ns.marginal_alpha(t)`` as a NumPy array (input cast to float64)."""
        t = torch.as_tensor(t, dtype=torch.float64)
        return self.ns.marginal_alpha(t).numpy()

    def sigma(self, t):
        """Return sqrt(1 - alpha(t)**2), the VP noise level at ``t``."""
        return np.sqrt(1 - self.alpha(t) * self.alpha(t))

    def lambda_func(self, t):
        """VP log-SNR: log(alpha(t) / sigma(t))."""
        return np.log(self.alpha(t) / self.sigma(t))

    def edm_lambda_func(self, t):
        """EDM-style log-SNR: log(alpha(t) / edm_sigma(t)) with edm_sigma(t) = t."""
        return np.log(self.alpha(t) / self.edm_sigma(t))

    # H_k(h) = e^h * h^k - k * H_{k-1}(h), with H_0(h) = e^h - 1.
    # Equivalently, H_k(h) is the integral of e^s * s^k over s in [0, h].
    def H0(self, h):
        return np.exp(h) - 1

    def H1(self, h):
        return np.exp(h) * h - self.H0(h)

    def H2(self, h):
        return np.exp(h) * h * h - 2 * self.H1(h)

    def H3(self, h):
        return np.exp(h) * h * h * h - 3 * self.H2(h)

    def inverse_lambda(self, lamb):
        """Map log-SNR values back to times via ``ns.inverse_lambda`` (float64 input)."""
        lamb = torch.as_tensor(lamb, dtype=torch.float64)
        return self.ns.inverse_lambda(lamb)

    def edm_sigma(self, t):
        """EDM noise level; in the EDM parameterization sigma(t) = t."""
        return t

    def edm_inverse_sigma(self, edm_sigma):
        """Map an EDM sigma tensor to times through the schedule's log-SNR.

        With alpha = 1 / sqrt(sigma^2 + 1), the resulting log-SNR is -log(sigma).
        ``edm_sigma`` must be a tensor, because ``.sqrt()`` is called on it.
        """
        alpha = 1 / (edm_sigma * edm_sigma + 1).sqrt()
        sigma = alpha * edm_sigma
        lambda_t = np.log(alpha / sigma)
        t = self.inverse_lambda(lambda_t)
        return t

    def sel_lambdas_lof_obj(self, lambda_vec, eps):
        """Objective minimized by :meth:`get_ts_lambdas`.

        Args:
            lambda_vec: the interior log-SNR values being optimized (endpoints excluded).
            eps: final sampling time t_0. Together with ``self.T`` it fixes the two
                endpoint log-SNR values.

        Returns:
            A non-negative scalar. It is a weighted sum of the magnitudes of
            per-step combination coefficients built from ``H0``-``H2``. The two end
            steps use first-order coefficients, the steps next to them second-order,
            and all others third-order. Each coefficient is weighted by a per-point
            data-error proxy.
        """
        lambda_func = self.lambda_func if self.is_latent_space else self.edm_lambda_func
        lambda_eps, lambda_T = lambda_func(eps).item(), lambda_func(self.T).item()
        lambda_vec_ext = np.concatenate((np.array([lambda_T]), lambda_vec, np.array([lambda_eps])))
        N = len(lambda_vec_ext) - 1
        # Step sizes in log-SNR: hv[i] = lambda_{i+1} - lambda_i.
        hv = np.diff(lambda_vec_ext)
        elv = np.exp(lambda_vec_ext)
        emlv_sq = np.exp(-2 * lambda_vec_ext)
        alpha_vec = 1. / np.sqrt(1 + emlv_sq)
        sigma_vec = 1. / np.sqrt(1 + np.exp(2 * lambda_vec_ext))
        # Per-point data-error proxy. For pixel-space diffusion models,
        # (sigma_vec**1)/alpha_vec was empirically found to work better.
        if self.is_latent_space:
            data_err_vec = (sigma_vec ** 2) / alpha_vec
        else:
            data_err_vec = (sigma_vec ** 1) / alpha_vec
        # For NFEs <= 7, higher-order steps with index s < truncNum contribute a
        # root-sum-square term instead of the signed per-point accumulation in
        # c_vec. This avoids numerical instability; for NFEs > 7 truncNum = 0.
        truncNum = 3 if N <= 7 else 0
        res = 0.
        c_vec = np.zeros(N)
        for s in range(N):
            if s in [0, N - 1]:
                # First-order coefficient.
                n = s
                J_n_kp_0 = elv[n + 1] - elv[n]
                res += abs(J_n_kp_0 * data_err_vec[n])
            elif s in [1, N - 2]:
                # Second-order coefficients.
                n = s - 1
                J_n_kp_0 = -elv[n + 1] * self.H1(hv[n + 1]) / hv[n]
                J_n_kp_1 = elv[n + 1] * (self.H1(hv[n + 1]) + hv[n] * self.H0(hv[n + 1])) / hv[n]
                if s >= truncNum:
                    c_vec[n] += data_err_vec[n] * J_n_kp_0
                    c_vec[n + 1] += data_err_vec[n + 1] * J_n_kp_1
                else:
                    res += np.sqrt((data_err_vec[n] * J_n_kp_0) ** 2 + (data_err_vec[n + 1] * J_n_kp_1) ** 2)
            else:
                # Third-order coefficients.
                n = s - 2
                J_n_kp_0 = elv[n + 2] * (self.H2(hv[n + 2]) + hv[n + 1] * self.H1(hv[n + 2])) / (hv[n] * (hv[n] + hv[n + 1]))
                J_n_kp_1 = -elv[n + 2] * (self.H2(hv[n + 2]) + (hv[n] + hv[n + 1]) * self.H1(hv[n + 2])) / (hv[n] * hv[n + 1])
                J_n_kp_2 = elv[n + 2] * (self.H2(hv[n + 2]) + (2 * hv[n + 1] + hv[n]) * self.H1(hv[n + 2]) + hv[n + 1] * (hv[n] + hv[n + 1]) * self.H0(hv[n + 2])) / (hv[n + 1] * (hv[n] + hv[n + 1]))
                if s >= truncNum:
                    c_vec[n] += data_err_vec[n] * J_n_kp_0
                    c_vec[n + 1] += data_err_vec[n + 1] * J_n_kp_1
                    c_vec[n + 2] += data_err_vec[n + 2] * J_n_kp_2
                else:
                    res += np.sqrt((data_err_vec[n] * J_n_kp_0) ** 2 + (data_err_vec[n + 1] * J_n_kp_1) ** 2 + (data_err_vec[n + 2] * J_n_kp_2) ** 2)
        res += sum(abs(c_vec))
        return res

    def get_ts_lambdas(self, N, eps, init_type=None):
        """Compute an ``N``-step time discretization from ``self.T`` down to ``eps``.

        Args:
            N: number of sampling steps; ``N + 1`` time points are returned.
            eps: t_0 of diffusion sampling, e.g. 1e-3 for VP models.
            init_type: optional initialization override. ``None`` keeps the
                default: ``'unif_t'`` (uniform in t) for latent-space models, and
                ``'unif'`` (uniform in logSNR) for pixel-space models. Types ending
                in ``'_origin'`` return the baseline discretization unoptimized.
                The others (``'unif'``, ``'unif_t'``, ``'edm'``, ``'quad'``) use that
                baseline as the initialization for the optimizer.

        Returns:
            ``(t_res, lambda_res)``: tensors of ``N + 1`` times and log-SNR values,
            ordered from ``self.T`` down to ``eps``.

        Raises:
            ValueError: if ``init_type`` is not recognised.
        """
        if init_type is None:
            init_type = "unif_t" if self.is_latent_space else "unif"
        lambda_func = self.lambda_func if self.is_latent_space else self.edm_lambda_func
        lambda_eps, lambda_T = lambda_func(eps).item(), lambda_func(self.T).item()

        # Linear constraints on the N-1 interior values x:
        #   row 0:         x[0] >= lambda_T
        #   rows 1..N-2:   x[i] - x[i-1] >= 0   (monotone non-decreasing)
        #   row N-1:      -x[N-2] >= -lambda_eps, i.e. x[N-2] <= lambda_eps
        constr_mat = np.zeros((N, N - 1))
        for i in range(N - 1):
            constr_mat[i][i] = 1.
            constr_mat[i + 1][i] = -1
        lb_vec = np.zeros(N)
        lb_vec[0], lb_vec[-1] = lambda_T, -lambda_eps
        ub_vec = np.full(N, np.inf)
        linear_constraint = LinearConstraint(constr_mat, lb_vec, ub_vec)

        # Initial (or baseline) vector of N+1 log-SNR values.
        # NOTE(review): the 'unif_t', 'edm' and 'quad' inits always use the VP
        # lambda_func, even for pixel-space models whose objective uses
        # edm_lambda_func. Confirm before using them on pixel-space models.
        if init_type in ('unif', 'unif_origin'):
            lambda_vec_ext = torch.linspace(lambda_T, lambda_eps, N + 1)
        elif init_type in ('unif_t', 'unif_t_origin'):
            t_vec = torch.linspace(self.T, eps, N + 1)
            lambda_vec_ext = self.lambda_func(t_vec)
        elif init_type in ('edm', 'edm_origin'):
            rho = 7
            # float() accepts Python floats as well as NumPy/torch scalars
            # (a Python float has no .item()).
            edm_sigma_min, edm_sigma_max = float(self.edm_sigma(eps)), float(self.edm_sigma(self.T))
            edm_sigma_vec = torch.linspace(edm_sigma_max ** (1. / rho), edm_sigma_min ** (1. / rho), N + 1).pow(rho)
            t_vec = self.edm_inverse_sigma(edm_sigma_vec)
            lambda_vec_ext = self.lambda_func(t_vec)
        elif init_type in ('quad', 'quad_origin'):
            t_order = 2
            t_vec = torch.linspace(self.T ** (1. / t_order), eps ** (1. / t_order), N + 1).pow(t_order)
            lambda_vec_ext = self.lambda_func(t_vec)
        else:
            raise ValueError('InitType not found: {!r}'.format(init_type))

        if init_type in ('unif_origin', 'unif_t_origin', 'edm_origin', 'quad_origin'):
            lambda_res = torch.as_tensor(lambda_vec_ext)
        else:
            lambda_vec_init = np.array(lambda_vec_ext[1:-1])
            res = minimize(self.sel_lambdas_lof_obj, lambda_vec_init, method='trust-constr', args=(eps,), constraints=[linear_constraint], options={'verbose': 1})
            lambda_res = torch.tensor(np.concatenate((np.array([lambda_T]), res.x, np.array([lambda_eps]))))
        # as_tensor avoids torch.tensor's copy-construct warning when
        # inverse_lambda already returns a tensor.
        t_res = torch.as_tensor(self.inverse_lambda(lambda_res))
        return t_res, lambda_res
def expand_dims(x, dims):
    """Return ``x`` with ``dims`` trailing singleton axes appended.

    ``x[..., None]`` is the indexing form of ``unsqueeze(-1)``. With ``dims <= 0``
    the input object itself is returned.
    """
    out = x
    for _ in range(dims):
        out = out[..., None]
    return out
def update_lists(t_list, model_list, t_, model_x, order, first=False):
    """Push ``(t_, model_x)`` into the solver's history lists in place.

    While ``first`` is true, the lists are still being filled, so the new entry is
    appended. Otherwise the first ``order - 1`` entries are shifted one slot
    left, and the newest entry overwrites the last slot. The shift assumes the
    lists hold ``order`` entries. Returns ``None``.
    """
    if not first:
        for idx in range(1, order):
            t_list[idx - 1] = t_list[idx]
            model_list[idx - 1] = model_list[idx]
        t_list[-1] = t_
        model_list[-1] = model_x
    else:
        t_list.append(t_)
        model_list.append(model_x)
class ODESolver(ABC):
    """Base class for diffusion-model ODE samplers.

    It stores the noise schedule and the prediction mode, and converts the raw
    model output (``self.model``) into noise or data predictions. It also builds
    the time-step schedules that concrete solvers iterate over.
    """

    def __init__(
        self,
        noise_schedule,
        algorithm_type="data_prediction",
        correcting_x0_fn=None,
    ):
        """
        Args:
            noise_schedule: schedule object (e.g. a VP or VE schedule). It must
                provide ``marginal_alpha`` / ``marginal_std``; some skip types also
                need ``marginal_lambda`` / ``inverse_lambda``.
            algorithm_type: ``"data_prediction"`` (``model_fn`` returns x0) or
                ``"noise_prediction"`` (``model_fn`` returns noise).
            correcting_x0_fn: optional callable applied to every x0 prediction.
        """
        self.noise_schedule = noise_schedule
        assert algorithm_type in ["data_prediction", "noise_prediction"]
        self.predict_x0 = algorithm_type == "data_prediction"
        self.correcting_x0_fn = correcting_x0_fn

    def dx_dt_for_blackbox_solvers(self, x, t1, t2):
        '''
        Return ``ft(t1) * x + gt(t1)**2 / (2 * sigma(t1)) * noise(x, t2)``.

        The schedule coefficients are evaluated at ``t1`` and the noise model at
        ``t2``. Original note: for edm, dx_dt = noise.
        '''
        ft = self.noise_schedule.ft(t1)  # should be 0. (original note, EDM case)
        gt = self.noise_schedule.gt(t1)  # should be 1. (original note, EDM case)
        sigma_t = self.noise_schedule.marginal_std(t1)
        noise = self.noise_prediction_fn(x, t2)
        return ft * x + gt ** 2 / (2 * sigma_t) * noise

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model output ``self.model(x, t)``.

        ``self.model`` is not assigned in this class; subclasses must set it.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction x0 = (x - sigma_t * noise) / alpha_t.

        ``correcting_x0_fn`` is applied to the result when one was given.
        """
        noise = self.noise_prediction_fn(x, t)
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        x0 = (x - sigma_t * noise) / alpha_t
        if self.correcting_x0_fn is not None:
            x0 = self.correcting_x0_fn(x0)
        return x0

    def model_fn(self, x, t):
        """
        Return the data prediction when ``predict_x0`` is set, else the noise prediction.
        """
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the N + 1 time steps for sampling, going from ``t_T`` to ``t_0``.

        Supported ``skip_type`` values:
            'logSNR': uniform in the schedule's log-SNR.
            'time_uniform': uniform in t.
            'time_quadratic': polynomial spacing with rho = 2.
            'edm': Karras et al. (EDM) spacing with rho = 7.
            'poly_<rho>': polynomial spacing with the given rho.
            'dmn': optimized steps from :class:`StepOptim`, returned as float32.

        Raises:
            ValueError: for any other ``skip_type``.
        """
        if skip_type == 'logSNR':
            lambda_T = self.noise_schedule.marginal_lambda(torch.as_tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.as_tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            t = self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 'time_uniform':
            t = torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            rho = 2.0
            t = self.get_time_step_poly(t_T, t_0, N, device, rho)
        elif skip_type == "edm":
            rho = 7.0  # 7.0 is the value used in the paper
            t = self.get_time_step_edm(t_T, t_0, N, device, rho)
        elif "poly" in skip_type:
            rho = float(skip_type.split("_")[-1])
            t = self.get_time_step_poly(t_T, t_0, N, device, rho)
        elif skip_type == "dmn":
            optimizer = StepOptim(self.noise_schedule)
            t, _ = optimizer.get_ts_lambdas(N, t_0)
            t = t.to(device).to(torch.float32)
        else:
            raise ValueError(
                "Unsupported skip_type {}, need to be one of 'logSNR', 'time_uniform', "
                "'time_quadratic', 'edm', 'poly_<rho>' or 'dmn'".format(skip_type)
            )
        return t

    def append_zero(self, x):
        """Return ``x`` with a single trailing zero appended along dim 0."""
        return torch.cat([x, x.new_zeros([1])])

    def get_time_step_poly(self, t_T, t_0, N, device, rho=2.0):
        """Polynomial schedule: offsets from ``t_0`` grow like i**rho, returned from t_T to t_0.

        Steps are therefore densest near ``t_0`` when ``rho > 1``.
        """
        mono_sequence = torch.arange(0, N + 1).pow(rho).to(device)
        sequence_min = mono_sequence.min()
        sequence_max = mono_sequence.max()
        t_max = t_T
        t_min = t_0
        ts = t_min + (t_max - t_min) * (mono_sequence - sequence_min) / (sequence_max - sequence_min)
        return ts.flip(0)

    def get_time_step_edm_t(self, t_T, t_0, N, device, rho=7.0):
        """Karras et al. (2022) rho-spacing applied directly to t (from t_T to t_0)."""
        t_min: float = t_0
        t_max: float = t_T
        ramp = torch.linspace(0, 1, N + 1).to(device)
        min_inv_rho = t_min ** (1 / rho)
        max_inv_rho = t_max ** (1 / rho)
        ts = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return ts

    def get_time_step_edm(self, t_T, t_0, N, device, rho=7.0):
        """Karras et al. (2022) rho-spacing in sigma, mapped back to t.

        For VE schedules, sigma comes from ``marginal_std`` and is inverted with
        ``inverse_std``. For other schedules, t is used as sigma directly.
        """
        if isinstance(self.noise_schedule, NoiseScheduleVE):
            sigma_min = self.noise_schedule.marginal_std(t_0).to(device)
            sigma_max = self.noise_schedule.marginal_std(t_T).to(device)
        else:
            sigma_min = t_0
            sigma_max = t_T
        ramp = torch.linspace(0, 1, N + 1).to(device)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        if isinstance(self.noise_schedule, NoiseScheduleVE):
            ts = self.noise_schedule.inverse_std(sigmas)
        else:
            ts = sigmas
        return ts

    def prepare_learn_timesteps(self, load_from, load_rs=False, device=None):
        """Load learned time steps (and optionally ``rs``) from a checkpoint.

        The checkpoint's ``'best_t_steps'`` tensor is split in half: the first half
        is ``timesteps`` and the second half is ``timesteps2``. With ``load_rs``,
        ``'best_rs'`` is split the same way. If it is missing (or not a tensor),
        both halves fall back to ``[0.5] * length``.
        """
        # NOTE(review): torch.load unpickles; only load trusted checkpoint files.
        # The file is loaded once and reused for both keys.
        ckpt = torch.load(load_from)
        timesteps = ckpt['best_t_steps'].to(device)
        length = timesteps.shape[0] // 2
        timesteps2 = timesteps[length:]
        timesteps = timesteps[:length]
        if not load_rs:
            return timesteps, timesteps2
        try:
            rs = ckpt['best_rs'].to(device)
        except (KeyError, AttributeError):
            rs = [0.5] * length
            rs2 = rs
        else:
            rs2 = rs[length:]
            rs = rs[:length]
        return timesteps, timesteps2, rs, rs2

    def prepare_timesteps(self, steps=None, t_start=None, t_end=None, skip_type=None, device=None, load_from=None):
        """Return ``(timesteps, timesteps2)``.

        Learned steps are loaded when ``load_from`` is an existing file. Otherwise
        ``get_time_steps`` is used for both.
        """
        if load_from is not None and os.path.isfile(load_from):
            timesteps, timesteps2 = self.prepare_learn_timesteps(load_from=load_from, device=device)
        else:
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_start, t_0=t_end, N=steps, device=device)
            timesteps2 = timesteps
        return timesteps, timesteps2

    def prepare_timesteps_single(self, steps, NFEs, t_start, t_end, flags, device, skip_type='time_uniform'):
        """Return ``(timesteps, timesteps2, rs, rs2)``.

        When ``flags.learn`` is set, the values are loaded from
        ``flags.load_from``. Otherwise the schedule is computed and
        ``rs = rs2 = [0.5] * steps``. ``NFEs`` is currently unused.
        """
        if flags.learn:
            timesteps, timesteps2, rs, rs2 = self.prepare_learn_timesteps(load_from=flags.load_from, load_rs=True, device=device)
        else:
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_start, t_0=t_end, N=steps, device=device)
            timesteps2 = timesteps
            rs = [0.5] * steps
            rs2 = rs
        return timesteps, timesteps2, rs, rs2

    def sample(self, *args, **kwargs):
        pass

    @abstractmethod
    def sample_simple(self, model_fn, x, timesteps, timesteps2=None, condition=None, unconditional_condition=None, **kwargs):
        pass

    def dynamic_thresholding_fn(self, x0, t):
        """
        The dynamic thresholding method (not used by anything so far).

        Reads ``self.dynamic_thresholding_ratio`` and ``self.thresholding_max_val``,
        which this class never sets; they must be assigned before calling it.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
        return x0
class GUIDEDSolver(ODESolver):
    """Abstract ODE solver interface with forward and backward trajectory passes.

    Subclasses implement ``forward_sample_simple``, which can return the list of
    intermediate states. They also implement ``backward_sample_simple``, which
    takes that list plus an upstream ``grad``, and a full ``sample`` entry point.
    """

    def __init__(
        self,
        noise_schedule,
        algorithm_type="data_prediction",
        correcting_x0_fn=None,
    ):
        # ODESolver.__init__ already validates algorithm_type and stores
        # noise_schedule, predict_x0 and correcting_x0_fn; repeating that here
        # only duplicated the logic.
        super().__init__(noise_schedule, algorithm_type, correcting_x0_fn)

    @abstractmethod
    def forward_sample_simple(self, latent, timesteps, timesteps2=None, return_image_list=False, **kwargs):
        """Run the sampler from ``latent``; optionally return every intermediate state."""
        pass

    @abstractmethod
    def backward_sample_simple(self, image_list, grad, timesteps=None, timesteps2=None, dis_model=None, **kwargs):
        """Propagate ``grad`` backwards through the recorded states in ``image_list``."""
        pass

    @abstractmethod
    def sample(self, x, steps, t_start, t_end, order, skip_type, flags):
        pass
class MultiStepODESolver(GUIDEDSolver):
    """Multistep sampler built on a subclass-provided ``_one_step`` update.

    ``forward_sample_simple`` runs the trajectory and can optionally record every
    state. ``backward_sample_simple`` walks that recorded trajectory in reverse,
    re-running each step with autograd to push an upstream gradient back toward
    the first recorded state.
    """

    def __init__(self, model_fn, noise_schedule, algorithm_type="data_prediction"):
        '''
        algorithm_type needs to be data_prediction
        '''
        # GUIDEDSolver.__init__ takes (noise_schedule, algorithm_type,
        # correcting_x0_fn). Forwarding model_fn first shifted every argument by
        # one position and tripped the algorithm_type assertion.
        super().__init__(noise_schedule, algorithm_type)
        # ODESolver.noise_prediction_fn evaluates self.model(x, t).
        self.model = model_fn

    @abstractmethod
    def _one_step(self, t1, t2, t_prev_list, model_prev_list, step, x_next, order=None, update_list=False, first=True):
        """Return the state after one solver step from ``x_next``.

        Callers pass ``timesteps[i]`` as ``t1`` and ``timesteps2[i]`` as ``t2``.
        ``t_prev_list`` / ``model_prev_list`` hold the history, oldest first.
        ``step`` is presumably the effective order for this step (callers pass a
        value <= order). With ``update_list=True`` the history lists are expected
        to be updated in place.
        """
        pass

    def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', flags=None):
        """Sample from ``x`` without gradients and return the final state.

        ``t_end`` defaults to ``1 / noise_schedule.total_N`` and ``t_start`` to
        ``noise_schedule.T``. Learned time steps are loaded from
        ``flags.load_from`` when ``flags`` is given and has that attribute.
        """
        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        # flags defaults to None, so flags.load_from would raise AttributeError.
        load_from = getattr(flags, 'load_from', None)
        timesteps, timesteps2 = self.prepare_timesteps(steps=steps, t_start=t_T, t_end=t_0, skip_type=skip_type, device=device, load_from=load_from)
        with torch.no_grad():
            return self.forward_sample_simple(x, timesteps, timesteps2, order=order, return_image_list=False)

    def forward_sample_simple(self, latent, timesteps, timesteps2=None, return_image_list=False, **kwargs):
        """Run the multistep sampler along ``timesteps`` under ``torch.no_grad``.

        ``order`` must be passed as a keyword argument. The model is evaluated at
        ``timesteps2`` times (default: ``timesteps``). The first ``order - 1``
        steps warm up the history with ``first=True``.

        Returns:
            The final state, or the list of all states (starting with the clone
            of ``latent``) when ``return_image_list`` is true.
        """
        assert 'order' in kwargs
        order = kwargs['order']
        if timesteps2 is None:
            timesteps2 = timesteps
        step = 0
        numsteps = len(timesteps) - 1
        with torch.no_grad():
            t_student1 = timesteps[step]
            t_student2 = timesteps2[step]
            t_prev_list_student = [t_student1]
            x_next_ = latent.clone()
            denoised_T = self.model_fn(x_next_, t_student2)
            model_prev_list_student = [denoised_T]
            if return_image_list:
                image_list = []
                image_list.append(x_next_)
            # Warm-up: fewer than `order` previous model outputs are available.
            for step in range(1, order):
                t1 = timesteps[step]
                t2 = timesteps2[step]
                x_next_ = self._one_step(t1, t2, t_prev_list_student, model_prev_list_student, step, x_next_, order, update_list=True, first=True)
                if return_image_list:
                    image_list.append(x_next_)
            # Steady state: cap the order near the end of the trajectory.
            for step in range(order, numsteps + 1):
                t1 = timesteps[step]
                t2 = timesteps2[step]
                step_order = min(order, numsteps + 1 - step)
                x_next_ = self._one_step(t1, t2, t_prev_list_student, model_prev_list_student, step_order, x_next_, order, update_list=True, first=False)
                if return_image_list:
                    image_list.append(x_next_)
        if return_image_list:
            return image_list
        return x_next_

    def backward_sample_simple(self, image_list, grad, timesteps=None, timesteps2=None, dis_model=None, **kwargs):
        """Back-propagate ``grad`` (w.r.t. the last state) through ``image_list``.

        Each step is recomputed with autograd enabled and ``backward`` is called
        with the current upstream gradient. ``order`` must be passed as a keyword
        argument. When ``dis_model`` is given, it is called every step to supply
        ``(timesteps, timesteps2)``.

        Returns:
            The gradient accumulated on ``image_list[0]``.

        NOTE(review): ``.grad`` on the ``image_list`` tensors is never zeroed. An
        entry that is part of several steps' history therefore accumulates
        contributions from each ``backward`` call before it is read as ``grad``.
        Confirm this is intended.
        """
        assert 'order' in kwargs
        order = kwargs['order']
        assert timesteps is None or len(timesteps) == len(image_list)
        numsteps = len(image_list) - 1
        for ele in image_list:
            ele.requires_grad = True
            ele.retain_grad()
        # Steps with a full `order`-long history, from the last step backwards.
        for step in range(numsteps, order - 1, -1):
            if dis_model is not None:
                timesteps, timesteps2 = dis_model()
            else:
                timesteps2 = timesteps2 if timesteps2 is not None else timesteps
            t1 = timesteps[step]
            t2 = timesteps2[step]
            # History lists, oldest first (same order as the forward pass).
            t_prev_list_student = [timesteps[step - i - 1] for i in range(order)][::-1]
            t_prev_list_student2 = [timesteps2[step - i - 1] for i in range(order)][::-1]
            this_image_list = [image_list[step - i - 1] for i in range(order)][::-1]
            model_prev_list_student = [self.model_fn(this_image_list[i], t_prev_list_student2[i]) for i in range(len(t_prev_list_student2))]
            x_next_input = image_list[step - 1]  # use x_1 to predict x_0; use x_2 to predict x_1, ...
            step_order = min(order, numsteps + 1 - step)
            x_next_ = self._one_step(t1, t2, t_prev_list_student, model_prev_list_student, step_order, x_next_input, update_list=False)
            x_next_.backward(grad, retain_graph=False)
            grad = x_next_input.grad.detach()  # dL / d(input state of this step)
        # Warm-up steps: only `step` previous states exist.
        for step in range(order - 1, 0, -1):
            if dis_model is not None:
                timesteps, timesteps2 = dis_model()
            else:
                timesteps2 = timesteps2 if timesteps2 is not None else timesteps
            t1 = timesteps[step]
            t2 = timesteps2[step]
            t_prev_list_student = [timesteps[step - i - 1] for i in range(step)][::-1]
            t_prev_list_student2 = [timesteps2[step - i - 1] for i in range(step)][::-1]
            this_image_list = [image_list[step - i - 1] for i in range(step)][::-1]
            model_prev_list_student = [self.model_fn(this_image_list[i], t_prev_list_student2[i]) for i in range(len(t_prev_list_student2))]
            x_next_input = image_list[step - 1]
            x_next_ = self._one_step(t1, t2, t_prev_list_student, model_prev_list_student, step, x_next_input, update_list=False)
            x_next_.backward(grad, retain_graph=False)
            grad = x_next_input.grad.detach()
        return grad
|