# Repository page header (not code), kept for provenance:
#   ChristophSchuhmann's picture
#   Add model code, inference script, and examples
#   dfd1909 verified
from typing import Literal, Optional
from tqdm import tqdm
import numpy as np
import torch
from TorchJaekwon.Util.UtilTorch import UtilTorch
from TorchJaekwon.Model.Diffusion.DDPM.DDPM import DDPM
from TorchJaekwon.Model.Diffusion.DDPM.DiffusionUtil import DiffusionUtil
class DDIM(object):
    """DDIM sampler that drives an already constructed DDPM model.

    The sampler keeps a reference to the DDPM model, copies/derives the
    schedule tensors it needs in :meth:`make_schedule`, and exposes
    :meth:`infer` as the main entry point for sampling.  Schedule values are
    stored as plain attributes on the sampler via :meth:`register_buffer`.
    """

    def __init__(self, ddpm_model: "DDPM"):
        """Store the wrapped model, its number of timesteps and its device."""
        self.ddpm_model: DDPM = ddpm_model
        self.ddpm_num_timesteps: int = ddpm_model.timesteps
        self.device: torch.device = UtilTorch.get_model_device(self.ddpm_model)

    def register_buffer(self, name, attr):
        """Set ``attr`` as attribute ``name``, moving tensors to ``self.device``.

        Non-tensor values (e.g. NumPy arrays) are stored unchanged.  When the
        target device is MPS, float64 tensors are cast to float32 while being
        moved.
        """
        if isinstance(attr, torch.Tensor) and attr.device != self.device:
            is_mps = self.device == "mps" or self.device == torch.device("mps")
            if is_mps and attr.dtype == torch.float64:
                attr = attr.to(self.device, dtype=torch.float32)
            else:
                attr = attr.to(self.device)
        setattr(self, name, attr)

    def make_schedule(
        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
    ):
        """Select the DDIM timesteps and register all schedule buffers.

        Args:
            ddim_num_steps: number of DDIM steps to select.
            ddim_discretize: discretization method, ``"uniform"`` or ``"quad"``.
            ddim_eta: eta used to scale the DDIM sigmas.
            verbose: forwarded to the static helpers, which print the
                selected timesteps / alphas / sigmas when True.

        Raises:
            AssertionError: if the model's ``alphas_cumprod`` does not have one
                entry per DDPM timestep.
        """
        self.ddim_timesteps = DDIM.make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )
        alphas_cumprod = self.ddpm_model.alphas_cumprod
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
        ), "alphas have to be defined for each timestep"

        def to_torch(x):
            # Detached float32 copy placed on the sampler's device.
            return x.clone().detach().to(torch.float32).to(self.device)

        # Derived quantities are computed on a CPU copy with torch ops, so no
        # NumPy ufunc is applied to a torch tensor.
        alphas_cumprod_cpu = alphas_cumprod.cpu()

        self.register_buffer("betas", to_torch(self.ddpm_model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
            "alphas_cumprod_prev", to_torch(self.ddpm_model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            "sqrt_alphas_cumprod", to_torch(torch.sqrt(alphas_cumprod_cpu))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            to_torch(torch.sqrt(1.0 - alphas_cumprod_cpu)),
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod",
            to_torch(torch.log(1.0 - alphas_cumprod_cpu)),
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod",
            to_torch(torch.sqrt(1.0 / alphas_cumprod_cpu)),
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(torch.sqrt(1.0 / alphas_cumprod_cpu - 1)),
        )

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = DDIM.make_ddim_sampling_parameters(
            alphacums=alphas_cumprod_cpu,
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose,
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        # ddim_alphas comes from indexing the CPU tensor above, so it is a tensor.
        self.register_buffer(
            "ddim_sqrt_one_minus_alphas", torch.sqrt(1.0 - ddim_alphas)
        )
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
        )

    @staticmethod
    def make_ddim_timesteps(
        ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
    ):
        """Return the array of DDPM timesteps visited by the DDIM sampler.

        Args:
            ddim_discr_method: ``"uniform"`` (constant stride) or ``"quad"``
                (squared linspace up to ``0.8 * num_ddpm_timesteps``).
            num_ddim_timesteps: requested number of DDIM steps.
            num_ddpm_timesteps: number of timesteps of the underlying DDPM.
            verbose: print the selected timesteps when True.

        Returns:
            NumPy integer array of the selected timesteps, each shifted by +1.

        Raises:
            ValueError: for ``"uniform"`` when ``num_ddim_timesteps`` is not in
                ``[1, num_ddpm_timesteps]`` (the stride would be zero/invalid).
            NotImplementedError: for an unknown discretization method.
        """
        if ddim_discr_method == "uniform":
            if not 0 < num_ddim_timesteps <= num_ddpm_timesteps:
                raise ValueError(
                    f"num_ddim_timesteps must be in [1, {num_ddpm_timesteps}], "
                    f"got {num_ddim_timesteps}"
                )
            stride = num_ddpm_timesteps // num_ddim_timesteps
            ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
        elif ddim_discr_method == "quad":
            ddim_timesteps = (
                (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
            ).astype(int)
        else:
            raise NotImplementedError(
                f'There is no ddim discretization method called "{ddim_discr_method}"'
            )

        # add one to get the final alpha values right (the ones from first scale to data during sampling)
        steps_out = ddim_timesteps + 1
        if verbose:
            print(f"Selected timesteps for ddim sampler: {steps_out}")
        return steps_out

    @staticmethod
    def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
        """Compute ``(sigmas, alphas, alphas_prev)`` for the selected timesteps.

        Args:
            alphacums: cumulative alpha products, indexable by ``ddim_timesteps``.
            ddim_timesteps: integer array of selected timesteps.
            eta: multiplier applied to the sigma schedule (0 gives all-zero sigmas).
            verbose: print the selected alphas and the resulting sigmas when True.

        Returns:
            Tuple ``(sigmas, alphas, alphas_prev)``; ``alphas_prev`` is always a
            NumPy array built from ``alphacums[0]`` followed by the alphas of
            all but the last selected timestep.
        """
        # select alphas for computing the variance schedule
        alphas = alphacums[ddim_timesteps]
        alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

        # according to the formula provided in https://arxiv.org/abs/2010.02502
        sigmas = eta * np.sqrt(
            (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
        )
        if verbose:
            print(
                f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
            )
            print(
                f"For the chosen value of eta, which is {eta}, "
                f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
            )
        return sigmas, alphas, alphas_prev

    @staticmethod
    def _extract_into_tensor(a, t, x_shape):
        """Gather ``a[t]`` per batch element and reshape it to broadcast over ``x_shape``.

        The result has shape ``(t.shape[0], 1, ..., 1)`` with as many
        dimensions as ``x_shape``.
        """
        batch = t.shape[0]
        gathered = a.gather(-1, t)
        return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))

    @torch.no_grad()
    def infer(
        self,
        x_shape:tuple = None,
        cond:Optional[dict] = None,
        is_cond_unpack:bool = False,
        num_steps:int = 50,
        batch_size:int = None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=1.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=False,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        dynamic_threshold=None,
        ucg_schedule=None,
    ):
        """Run the full DDIM sampling pipeline and return the post-processed result.

        The condition is passed through ``ddpm_model.preprocess``, the schedule
        is rebuilt for ``num_steps``/``eta``, :meth:`ddim_sampling` is run, and
        the samples go through ``ddpm_model.postprocess``.

        Args:
            x_shape: shape of the tensor to sample; when None it is obtained
                from ``ddpm_model.get_x_shape(cond=cond)``.
            batch_size: when given, replaces the first entry of ``x_shape``.
            normals_sequence: accepted but not used by this method.
            Remaining arguments are forwarded to :meth:`ddim_sampling`.

        Returns:
            Whatever ``ddpm_model.postprocess(samples, additional_data_dict)`` returns.
        """
        _, cond, additional_data_dict = self.ddpm_model.preprocess(x_start = None, cond=cond)
        if x_shape is None:
            x_shape = self.ddpm_model.get_x_shape(cond=cond)
        if batch_size is not None:
            # Build a new list: works for tuples / torch.Size (item assignment
            # would raise) and does not mutate a list owned by the caller.
            x_shape = [batch_size, *x_shape[1:]]

        self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta, verbose=verbose)
        samples, intermediates = self.ddim_sampling(
            x_shape,
            cond,
            is_cond_unpack,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            dynamic_threshold=dynamic_threshold,
            ucg_schedule=ucg_schedule,
        )
        return self.ddpm_model.postprocess(samples, additional_data_dict)

    @torch.no_grad()
    def ddim_sampling(
        self,
        x_shape:tuple = None,
        cond:Optional[dict] = None,
        is_cond_unpack:bool = False,
        x_T=None,
        ddim_use_original_steps=False,
        callback=None,
        timesteps=None,
        quantize_denoised=False,
        mask=None,
        x0=None,
        img_callback=None,
        log_every_t=100,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        dynamic_threshold=None,
        ucg_schedule=None,
    ):
        """Iterate the reverse process from noise (or ``x_T``) down to a sample.

        Args:
            x_shape: shape of the tensor to sample; ``x_shape[0]`` is the batch size.
            x_T: optional starting tensor; random normal noise is drawn when None.
            ddim_use_original_steps: iterate over all DDPM timesteps instead of
                the DDIM subset.
            timesteps: optional limit; with the DDIM subset it truncates
                ``self.ddim_timesteps``, with original steps it is the step count.
            mask, x0: when ``mask`` is given, ``x0`` is required and the masked
                region is replaced each step by ``ddpm_model.q_sample(x0, ts)``.
            ucg_schedule: optional per-step guidance scales; must have one
                entry per sampling step.
            callback: called with the loop index after every step.
            img_callback: called with ``(pred_x0, loop_index)`` after every step.
            log_every_t: interval (in step index) at which intermediates are stored.

        Returns:
            ``(img, intermediates)`` where ``intermediates`` is a dict with the
            lists ``"x_inter"`` and ``"pred_x0"``.
        """
        device = self.ddpm_model.betas.device
        b = x_shape[0]
        img = torch.randn(x_shape, device=device) if x_T is None else x_T

        if timesteps is None:
            timesteps = (
                self.ddpm_num_timesteps
                if ddim_use_original_steps
                else self.ddim_timesteps
            )
        elif not ddim_use_original_steps:
            subset_end = (
                int(
                    min(timesteps / self.ddim_timesteps.shape[0], 1)
                    * self.ddim_timesteps.shape[0]
                )
                - 1
            )
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {"x_inter": [img], "pred_x0": [img]}
        time_range = (
            reversed(range(0, timesteps))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.ddpm_model.q_sample(
                    x0, ts
                )  # TODO: deterministic forward pass?
                img = img_orig * mask + (1.0 - mask) * img

            if ucg_schedule is not None:
                # total_steps is used because the reversed(range) object of the
                # original-steps path does not support len().
                assert len(ucg_schedule) == total_steps
                unconditional_guidance_scale = ucg_schedule[i]

            img, pred_x0 = self.p_sample_ddim(
                img,
                ts,
                index=index,
                cond = cond,
                is_cond_unpack = is_cond_unpack,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                dynamic_threshold=dynamic_threshold,
            )

            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates["x_inter"].append(img)
                intermediates["pred_x0"].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(
        self,
        x,
        t,
        index,
        cond:Optional[dict] = None,
        is_cond_unpack:bool = False,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        dynamic_threshold=None,
    ):
        """Perform one DDIM update ``x_t -> x_{t-1}``.

        Args:
            x: current sample tensor; its first dimension is the batch.
            t: per-sample timestep tensor handed to the model.
            index: position inside the (DDIM or original) schedule arrays.
            cond, is_cond_unpack: forwarded to ``ddpm_model.apply_model``.
            score_corrector: optional object whose ``modify_score`` adjusts the
                predicted noise; ``corrector_kwargs`` (may be None) are its kwargs.
            unconditional_guidance_scale, unconditional_conditioning: accepted
                but not read here; the model is called with
                ``cfg_scale=self.ddpm_model.cfg_scale``.
            dynamic_threshold: not supported.

        Returns:
            ``(x_prev, pred_x0)``.

        Raises:
            NotImplementedError: if ``dynamic_threshold`` is not None.
        """
        b = x.shape[0]
        device = x.device

        model_output = self.ddpm_model.apply_model(x, t, cond, is_cond_unpack, cfg_scale = self.ddpm_model.cfg_scale)
        is_v_prediction = self.ddpm_model.model_output_type == "v_prediction"
        if is_v_prediction:
            e_t = self.ddpm_model.predict_noise_from_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            # NOTE(review): this reads `parameterization`, while the rest of the
            # method reads `model_output_type` -- confirm the model defines both.
            assert self.ddpm_model.parameterization == "eps", "not implemented"
            e_t = score_corrector.modify_score(
                self.ddpm_model, e_t, x, t, cond, **(corrector_kwargs or {})
            )

        if use_original_steps:
            alphas = self.ddpm_model.alphas_cumprod
            alphas_prev = self.ddpm_model.alphas_cumprod_prev
            sqrt_one_minus_alphas = self.ddpm_model.sqrt_one_minus_alphas_cumprod
            # Keep the model's attribute when it exists; otherwise use the
            # buffer that make_schedule registers on this sampler.
            sigmas = getattr(
                self.ddpm_model, "ddim_sigmas_for_original_num_steps", None
            )
            if sigmas is None:
                sigmas = self.ddim_sigmas_for_original_num_steps
        else:
            alphas = self.ddim_alphas
            alphas_prev = self.ddim_alphas_prev
            sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
            sigmas = self.ddim_sigmas

        # select parameters corresponding to the currently considered timestep;
        # one singleton dim per non-batch dim of x so they broadcast for any rank
        param_shape = (b,) + (1,) * (x.dim() - 1)
        a_t = torch.full(param_shape, alphas[index], device=device)
        a_prev = torch.full(param_shape, alphas_prev[index], device=device)
        sigma_t = torch.full(param_shape, sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(
            param_shape, sqrt_one_minus_alphas[index], device=device
        )

        # current prediction for x_0
        if not is_v_prediction:
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.ddpm_model.predict_x_start_from_v(x, t, model_output)

        if quantize_denoised:
            pred_x0, _, *_ = self.ddpm_model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            raise NotImplementedError()

        # direction pointing to x_t
        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * DiffusionUtil.noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(
        self,
        x0,
        c,
        t_enc,
        use_original_steps=False,
        return_intermediates=None,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        callback=None,
    ):
        """Deterministically push ``x0`` forward through ``t_enc`` schedule steps.

        Args:
            x0: starting tensor; ``x0.shape[0]`` is the batch size.
            c: conditioning passed to ``ddpm_model.apply_model``.
            t_enc: number of steps to run; must not exceed the schedule length.
            use_original_steps: use the full DDPM schedule instead of the DDIM subset.
            return_intermediates: when truthy, roughly that many intermediate
                tensors (plus the last two steps) are collected.
            unconditional_guidance_scale: when not 1.0, the model is run on the
                concatenated unconditional/conditional batch and the two
                predictions are combined; requires ``unconditional_conditioning``.
            callback: called with the step index after every step.

        Returns:
            ``(x_next, out)`` where ``out`` holds ``"x_encoded"``,
            ``"intermediate_steps"`` and, if requested, ``"intermediates"``.
        """
        num_reference_steps = (
            self.ddpm_num_timesteps
            if use_original_steps
            else self.ddim_timesteps.shape[0]
        )
        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        # Clamp to 1 so requesting more intermediates than steps cannot cause
        # a modulo-by-zero below.
        interval = (
            max(num_steps // return_intermediates, 1) if return_intermediates else None
        )

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc="Encoding Image"):
            # NOTE(review): relies on the model exposing a `.device` attribute -- confirm.
            t = torch.full(
                (x0.shape[0],), i, device=self.ddpm_model.device, dtype=torch.long
            )
            if unconditional_guidance_scale == 1.0:
                noise_pred = self.ddpm_model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.ddpm_model.apply_model(
                        torch.cat((x_next, x_next)),
                        torch.cat((t, t)),
                        torch.cat((unconditional_conditioning, c)),
                    ),
                    2,
                )
                noise_pred = e_t_uncond + unconditional_guidance_scale * (
                    noise_pred - e_t_uncond
                )

            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            weighted_noise_pred = (
                alphas_next[i].sqrt()
                * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
                * noise_pred
            )
            x_next = xt_weighted + weighted_noise_pred

            if return_intermediates:
                on_interval = i % interval == 0 and i < num_steps - 1
                if on_interval or i >= num_steps - 2:
                    intermediates.append(x_next)
                    inter_steps.append(i)
            if callback:
                callback(i)

        out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
        if return_intermediates:
            out.update({"intermediates": intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        """Noise ``x0`` to schedule index ``t`` in a single step.

        Fast, but does not allow for exact reconstruction.

        Args:
            x0: tensor to be noised.
            t: integer index tensor (one entry per batch element) used to
                gather the schedule coefficients.
            use_original_steps: gather from the full DDPM schedule instead of
                the DDIM subset.
            noise: optional noise tensor; drawn with ``torch.randn_like(x0)``
                when None.

        Returns:
            ``sqrt(alpha)[t] * x0 + sqrt(1 - alpha)[t] * noise``.
        """
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)

        signal_scale = DDIM._extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape)
        noise_scale = DDIM._extract_into_tensor(
            sqrt_one_minus_alphas_cumprod, t, x0.shape
        )
        return signal_scale * x0 + noise_scale * noise

    @torch.no_grad()
    def decode(
        self,
        x_latent,
        cond,
        t_start,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        use_original_steps=False,
        callback=None,
    ):
        """Run the reverse process on ``x_latent`` starting from step ``t_start``.

        Args:
            x_latent: tensor to decode; its device and batch size are reused.
            cond: conditioning forwarded to :meth:`p_sample_ddim`.
            t_start: number of leading schedule entries to iterate over (in
                reverse order).
            use_original_steps: use all DDPM timesteps instead of the DDIM subset.
            callback: called with the loop index after every step.

        Returns:
            The decoded tensor after the last step.
        """
        timesteps = (
            np.arange(self.ddpm_num_timesteps)
            if use_original_steps
            else self.ddim_timesteps
        )
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full(
                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
            )
            # p_sample_ddim's positional order is (x, t, index, ...); cond must
            # be passed by keyword.
            x_dec, _ = self.p_sample_ddim(
                x_dec,
                ts,
                index=index,
                cond=cond,
                use_original_steps=use_original_steps,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
            )
            if callback:
                callback(i)
        return x_dec