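# NOTE: the original copy of this file began with residue from a scraped web
# page header (author: BonanDing; commit 8652b14 "update lfs"; size 20.2 kB).
# It is kept here as a comment so the module remains valid Python.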
from typing import Optional, Callable
from collections import namedtuple
from omegaconf import DictConfig
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import linear_beta_schedule, cosine_beta_schedule, sigmoid_beta_schedule, extract
from .dit import DiT_models
# Bundle of outputs produced by Diffusion.model_predictions:
#   pred_noise   - noise estimate derived from (or equal to) the network output
#   pred_x_start - clean-sample estimate derived from (or equal to) the network output
#   model_out    - the raw network output, used directly as the regression prediction
ModelPrediction = namedtuple("ModelPrediction", "pred_noise pred_x_start model_out")
class Diffusion(nn.Module):
# Special thanks to lucidrains for the implementation of the base Diffusion model
# https://github.com/lucidrains/denoising-diffusion-pytorch
def __init__(
self,
x_shape: torch.Size,
reference_length: int,
action_cond_dim: int,
pose_cond_dim,
is_causal: bool,
cfg: DictConfig,
is_dit: bool=False,
use_plucker=False,
relative_embedding=False,
state_embed_only_on_qk=False,
use_memory_attention=False,
add_timestamp_embedding=False,
ref_mode='sequential'
):
super().__init__()
self.cfg = cfg
self.x_shape = x_shape
self.action_cond_dim = action_cond_dim
self.timesteps = cfg.timesteps
self.sampling_timesteps = cfg.sampling_timesteps
self.beta_schedule = cfg.beta_schedule
self.schedule_fn_kwargs = cfg.schedule_fn_kwargs
self.objective = cfg.objective
self.use_fused_snr = cfg.use_fused_snr
self.snr_clip = cfg.snr_clip
self.cum_snr_decay = cfg.cum_snr_decay
self.ddim_sampling_eta = cfg.ddim_sampling_eta
self.clip_noise = cfg.clip_noise
self.arch = cfg.architecture
self.stabilization_level = cfg.stabilization_level
self.is_causal = is_causal
self.is_dit = is_dit
self.reference_length = reference_length
self.pose_cond_dim = pose_cond_dim
self.use_plucker = use_plucker
self.relative_embedding = relative_embedding
self.state_embed_only_on_qk = state_embed_only_on_qk
self.use_memory_attention = use_memory_attention
self.add_timestamp_embedding = add_timestamp_embedding
self.ref_mode = ref_mode
self._build_model()
self._build_buffer()
def _build_model(self):
x_channel = self.x_shape[0]
if self.is_dit:
self.model = DiT_models["DiT-S/2"](action_cond_dim=self.action_cond_dim,
pose_cond_dim=self.pose_cond_dim, reference_length=self.reference_length,
use_plucker=self.use_plucker,
relative_embedding=self.relative_embedding,
state_embed_only_on_qk=self.state_embed_only_on_qk,
use_memory_attention=self.use_memory_attention,
add_timestamp_embedding=self.add_timestamp_embedding,
ref_mode=self.ref_mode)
else:
raise NotImplementedError
def _build_buffer(self):
if self.beta_schedule == "linear":
beta_schedule_fn = linear_beta_schedule
elif self.beta_schedule == "cosine":
beta_schedule_fn = cosine_beta_schedule
elif self.beta_schedule == "sigmoid":
beta_schedule_fn = sigmoid_beta_schedule
else:
raise ValueError(f"unknown beta schedule {self.beta_schedule}")
betas = beta_schedule_fn(self.timesteps, **self.schedule_fn_kwargs)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# sampling related parameters
assert self.sampling_timesteps <= self.timesteps
self.is_ddim_sampling = self.sampling_timesteps < self.timesteps
# helper function to register buffer from float64 to float32
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
register_buffer("betas", betas)
register_buffer("alphas_cumprod", alphas_cumprod)
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
register_buffer("posterior_variance", posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
register_buffer(
"posterior_log_variance_clipped",
torch.log(posterior_variance.clamp(min=1e-20)),
)
register_buffer(
"posterior_mean_coef1",
betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
)
register_buffer(
"posterior_mean_coef2",
(1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
)
# calculate p2 reweighting
# register_buffer(
# "p2_loss_weight",
# (self.p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
# ** -self.p2_loss_weight_gamma,
# )
# derive loss weight
# https://arxiv.org/abs/2303.09556
# snr: signal noise ratio
snr = alphas_cumprod / (1 - alphas_cumprod)
clipped_snr = snr.clone()
clipped_snr.clamp_(max=self.snr_clip)
register_buffer("clipped_snr", clipped_snr)
register_buffer("snr", snr)
def add_shape_channels(self, x):
return rearrange(x, f"... -> ...{' 1' * len(self.x_shape)}")
def model_predictions(self, x, t, action_cond=None, current_frame=None,
pose_cond=None, mode="training", reference_length=None, frame_idx=None):
x = x.permute(1,0,2,3,4)
action_cond = action_cond.permute(1,0,2)
if pose_cond is not None and pose_cond[0] is not None:
try:
pose_cond = pose_cond.permute(1,0,2)
except:
pass
t = t.permute(1,0)
model_output = self.model(x, t, action_cond, current_frame=current_frame, pose_cond=pose_cond,
mode=mode, reference_length=reference_length, frame_idx=frame_idx)
model_output = model_output.permute(1,0,2,3,4)
x = x.permute(1,0,2,3,4)
t = t.permute(1,0)
if self.objective == "pred_noise":
pred_noise = torch.clamp(model_output, -self.clip_noise, self.clip_noise)
x_start = self.predict_start_from_noise(x, t, pred_noise)
elif self.objective == "pred_x0":
x_start = model_output
pred_noise = self.predict_noise_from_start(x, t, x_start)
elif self.objective == "pred_v":
v = model_output
x_start = self.predict_start_from_v(x, t, v)
pred_noise = self.predict_noise_from_start(x, t, x_start)
return ModelPrediction(pred_noise, x_start, model_output)
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def predict_noise_from_start(self, x_t, t, x0):
return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / extract(
self.sqrt_recipm1_alphas_cumprod, t, x_t.shape
)
def predict_v(self, x_start, t, noise):
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise
- extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
)
def predict_start_from_v(self, x_t, t, v):
return (
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
- extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_mean_variance(self, x, t, action_cond=None, pose_cond=None, reference_length=None):
model_pred = self.model_predictions(x=x, t=t, action_cond=action_cond,
pose_cond=pose_cond, reference_length=reference_length,
frame_idx=frame_idx)
x_start = model_pred.pred_x_start
return self.q_posterior(x_start=x_start, x_t=x, t=t)
def compute_loss_weights(self, noise_levels: torch.Tensor):
snr = self.snr[noise_levels]
clipped_snr = self.clipped_snr[noise_levels]
normalized_clipped_snr = clipped_snr / self.snr_clip
normalized_snr = snr / self.snr_clip
if not self.use_fused_snr:
# min SNR reweighting
match self.objective:
case "pred_noise":
return clipped_snr / snr
case "pred_x0":
return clipped_snr
case "pred_v":
return clipped_snr / (snr + 1)
cum_snr = torch.zeros_like(normalized_snr)
for t in range(0, noise_levels.shape[0]):
if t == 0:
cum_snr[t] = normalized_clipped_snr[t]
else:
cum_snr[t] = self.cum_snr_decay * cum_snr[t - 1] + (1 - self.cum_snr_decay) * normalized_clipped_snr[t]
cum_snr = F.pad(cum_snr[:-1], (0, 0, 1, 0), value=0.0)
clipped_fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_clipped_snr)
fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_snr)
match self.objective:
case "pred_noise":
return clipped_fused_snr / fused_snr
case "pred_x0":
return clipped_fused_snr * self.snr_clip
case "pred_v":
return clipped_fused_snr * self.snr_clip / (fused_snr * self.snr_clip + 1)
case _:
raise ValueError(f"unknown objective {self.objective}")
def forward(
self,
x: torch.Tensor,
action_cond: Optional[torch.Tensor],
pose_cond,
noise_levels: torch.Tensor,
reference_length,
frame_idx=None
):
noise = torch.randn_like(x)
noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
noised_x = self.q_sample(x_start=x, t=noise_levels, noise=noise)
model_pred = self.model_predictions(x=noised_x, t=noise_levels, action_cond=action_cond,
pose_cond=pose_cond,reference_length=reference_length, frame_idx=frame_idx)
pred = model_pred.model_out
x_pred = model_pred.pred_x_start
if self.objective == "pred_noise":
target = noise
elif self.objective == "pred_x0":
target = x
elif self.objective == "pred_v":
target = self.predict_v(x, noise_levels, noise)
else:
raise ValueError(f"unknown objective {self.objective}")
# 训练的时候每个frame随便给噪声
loss = F.mse_loss(pred, target.detach(), reduction="none")
loss_weight = self.compute_loss_weights(noise_levels)
loss_weight = loss_weight.view(*loss_weight.shape, *((1,) * (loss.ndim - 2)))
loss = loss * loss_weight
return x_pred, loss
def sample_step(
self,
x: torch.Tensor,
action_cond: Optional[torch.Tensor],
pose_cond,
curr_noise_level: torch.Tensor,
next_noise_level: torch.Tensor,
guidance_fn: Optional[Callable] = None,
current_frame=None,
mode="training",
reference_length=None,
frame_idx=None
):
real_steps = torch.linspace(-1, self.timesteps - 1, steps=self.sampling_timesteps + 1, device=x.device).long()
# convert noise levels (0 ~ sampling_timesteps) to real noise levels (-1 ~ timesteps - 1)
curr_noise_level = real_steps[curr_noise_level]
next_noise_level = real_steps[next_noise_level]
if self.is_ddim_sampling:
return self.ddim_sample_step(
x=x,
action_cond=action_cond,
pose_cond=pose_cond,
curr_noise_level=curr_noise_level,
next_noise_level=next_noise_level,
guidance_fn=guidance_fn,
current_frame=current_frame,
mode=mode,
reference_length=reference_length,
frame_idx=frame_idx
)
# FIXME: temporary code for checking ddpm sampling
assert torch.all(
(curr_noise_level - 1 == next_noise_level) | ((curr_noise_level == -1) & (next_noise_level == -1))
), "Wrong noise level given for ddpm sampling."
assert (
self.sampling_timesteps == self.timesteps
), "sampling_timesteps should be equal to timesteps for ddpm sampling."
return self.ddpm_sample_step(
x=x,
action_cond=action_cond,
pose_cond=pose_cond,
curr_noise_level=curr_noise_level,
guidance_fn=guidance_fn,
reference_length=reference_length,
frame_idx=frame_idx
)
def ddpm_sample_step(
self,
x: torch.Tensor,
action_cond: Optional[torch.Tensor],
pose_cond,
curr_noise_level: torch.Tensor,
guidance_fn: Optional[Callable] = None,
reference_length=None,
frame_idx=None,
):
clipped_curr_noise_level = torch.where(
curr_noise_level < 0,
torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
curr_noise_level,
)
# treating as stabilization would require us to scale with sqrt of alpha_cum
orig_x = x.clone().detach()
scaled_context = self.q_sample(
x,
clipped_curr_noise_level,
noise=torch.zeros_like(x),
)
x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)
if guidance_fn is not None:
raise NotImplementedError("Guidance function is not implemented for ddpm sampling yet.")
else:
model_mean, _, model_log_variance = self.p_mean_variance(
x=x,
t=clipped_curr_noise_level,
action_cond=action_cond,
pose_cond=pose_cond,
reference_length=reference_length,
frame_idx=frame_idx
)
noise = torch.where(
self.add_shape_channels(clipped_curr_noise_level > 0),
torch.randn_like(x),
0,
)
noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
x_pred = model_mean + torch.exp(0.5 * model_log_variance) * noise
# only update frames where the noise level decreases
return torch.where(self.add_shape_channels(curr_noise_level == -1), orig_x, x_pred)
def ddim_sample_step(
self,
x: torch.Tensor,
action_cond: Optional[torch.Tensor],
pose_cond,
curr_noise_level: torch.Tensor,
next_noise_level: torch.Tensor,
guidance_fn: Optional[Callable] = None,
current_frame=None,
mode="training",
reference_length=None,
frame_idx=None
):
# convert noise level -1 to self.stabilization_level - 1
clipped_curr_noise_level = torch.where(
curr_noise_level < 0,
torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
curr_noise_level,
)
# treating as stabilization would require us to scale with sqrt of alpha_cum
orig_x = x.clone().detach()
scaled_context = self.q_sample(
x,
clipped_curr_noise_level,
noise=torch.zeros_like(x),
)
x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)
alpha = self.alphas_cumprod[clipped_curr_noise_level]
alpha_next = torch.where(
next_noise_level < 0,
torch.ones_like(next_noise_level),
self.alphas_cumprod[next_noise_level],
)
sigma = torch.where(
next_noise_level < 0,
torch.zeros_like(next_noise_level),
self.ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt(),
)
c = (1 - alpha_next - sigma**2).sqrt()
alpha_next = self.add_shape_channels(alpha_next)
c = self.add_shape_channels(c)
sigma = self.add_shape_channels(sigma)
if guidance_fn is not None:
with torch.enable_grad():
x = x.detach().requires_grad_()
model_pred = self.model_predictions(
x=x,
t=clipped_curr_noise_level,
action_cond=action_cond,
pose_cond=pose_cond,
current_frame=current_frame,
mode=mode,
reference_length=reference_length,
frame_idx=frame_idx
)
guidance_loss = guidance_fn(model_pred.pred_x_start)
grad = -torch.autograd.grad(
guidance_loss,
x,
)[0]
pred_noise = model_pred.pred_noise + (1 - alpha_next).sqrt() * grad
x_start = self.predict_start_from_noise(x, clipped_curr_noise_level, pred_noise)
else:
# print(clipped_curr_noise_level)
model_pred = self.model_predictions(
x=x,
t=clipped_curr_noise_level,
action_cond=action_cond,
pose_cond=pose_cond,
current_frame=current_frame,
mode=mode,
reference_length=reference_length,
frame_idx=frame_idx
)
x_start = model_pred.pred_x_start
pred_noise = model_pred.pred_noise
noise = torch.randn_like(x)
noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
x_pred = x_start * alpha_next.sqrt() + pred_noise * c + sigma * noise
# only update frames where the noise level decreases
mask = curr_noise_level == next_noise_level
x_pred = torch.where(
self.add_shape_channels(mask),
orig_x,
x_pred,
)
return x_pred