from typing import Optional, Callable
from collections import namedtuple
from omegaconf import DictConfig
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import linear_beta_schedule, cosine_beta_schedule, sigmoid_beta_schedule, extract
from .dit import DiT_models


ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start", "model_out"])


class Diffusion(nn.Module):
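    """Frame-wise diffusion wrapper around a DiT backbone.

    Implements the forward (noising) process, a (fused) min-SNR-weighted training
    loss, and per-frame DDPM/DDIM sampling steps in which each frame of a sequence
    may sit at its own noise level.

    Minimal construction sketch (hypothetical config values; the real `cfg` comes
    from the project's OmegaConf configs):

        >>> from omegaconf import OmegaConf
        >>> cfg = OmegaConf.create({
        ...     "timesteps": 1000, "sampling_timesteps": 50,
        ...     "beta_schedule": "cosine", "schedule_fn_kwargs": {},
        ...     "objective": "pred_v", "use_fused_snr": True,
        ...     "snr_clip": 5.0, "cum_snr_decay": 0.96,
        ...     "ddim_sampling_eta": 0.0, "clip_noise": 20.0,
        ...     "architecture": "dit", "stabilization_level": 10,
        ... })
        >>> diffusion = Diffusion(
        ...     x_shape=torch.Size([3, 32, 32]), reference_length=4,
        ...     action_cond_dim=16, pose_cond_dim=9, is_causal=True,
        ...     cfg=cfg, is_dit=True,
        ... )
    """
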
    def __init__(
        self,
        x_shape: torch.Size,
        reference_length: int,
        action_cond_dim: int,
        pose_cond_dim,
        is_causal: bool,
        cfg: DictConfig,
        is_dit: bool = False,
        use_plucker=False,
        relative_embedding=False,
        state_embed_only_on_qk=False,
        use_memory_attention=False,
        add_timestamp_embedding=False,
        ref_mode="sequential",
    ):
        super().__init__()
        self.cfg = cfg

        self.x_shape = x_shape
        self.action_cond_dim = action_cond_dim
        self.timesteps = cfg.timesteps
        self.sampling_timesteps = cfg.sampling_timesteps
        self.beta_schedule = cfg.beta_schedule
        self.schedule_fn_kwargs = cfg.schedule_fn_kwargs
        self.objective = cfg.objective
        self.use_fused_snr = cfg.use_fused_snr
        self.snr_clip = cfg.snr_clip
        self.cum_snr_decay = cfg.cum_snr_decay
        self.ddim_sampling_eta = cfg.ddim_sampling_eta
        self.clip_noise = cfg.clip_noise
        self.arch = cfg.architecture
        self.stabilization_level = cfg.stabilization_level
        self.is_causal = is_causal
        self.is_dit = is_dit
        self.reference_length = reference_length
        self.pose_cond_dim = pose_cond_dim
        self.use_plucker = use_plucker
        self.relative_embedding = relative_embedding
        self.state_embed_only_on_qk = state_embed_only_on_qk
        self.use_memory_attention = use_memory_attention
        self.add_timestamp_embedding = add_timestamp_embedding
        self.ref_mode = ref_mode

        self._build_model()
        self._build_buffer()

    def _build_model(self):
        x_channel = self.x_shape[0]
        if self.is_dit:
            self.model = DiT_models["DiT-S/2"](
                action_cond_dim=self.action_cond_dim,
                pose_cond_dim=self.pose_cond_dim,
                reference_length=self.reference_length,
                use_plucker=self.use_plucker,
                relative_embedding=self.relative_embedding,
                state_embed_only_on_qk=self.state_embed_only_on_qk,
                use_memory_attention=self.use_memory_attention,
                add_timestamp_embedding=self.add_timestamp_embedding,
                ref_mode=self.ref_mode,
            )
        else:
            raise NotImplementedError("only the DiT backbone (is_dit=True) is implemented")

    def _build_buffer(self):
        if self.beta_schedule == "linear":
            beta_schedule_fn = linear_beta_schedule
        elif self.beta_schedule == "cosine":
            beta_schedule_fn = cosine_beta_schedule
        elif self.beta_schedule == "sigmoid":
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f"unknown beta schedule {self.beta_schedule}")

        betas = beta_schedule_fn(self.timesteps, **self.schedule_fn_kwargs)

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
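
        # Sanity sketch (assuming the standard schedules from .utils): `betas` has
        # shape (timesteps,), and `alphas_cumprod` decays monotonically from
        # roughly 1 - betas[0] at t = 0 toward 0 at t = timesteps - 1.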

        assert self.sampling_timesteps <= self.timesteps
        self.is_ddim_sampling = self.sampling_timesteps < self.timesteps

        # Register every schedule constant as a float32 buffer so it is saved with
        # the module and moves across devices together with it.
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # Forward-process coefficients: q(x_t | x_0) = N(sqrt(a_bar_t) * x_0, (1 - a_bar_t) * I).
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
        register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))

        # Posterior q(x_{t-1} | x_t, x_0) of the forward process.
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        register_buffer("posterior_variance", posterior_variance)

        # The log-variance is clipped because the posterior variance is zero at t = 0.
        register_buffer(
            "posterior_log_variance_clipped",
            torch.log(posterior_variance.clamp(min=1e-20)),
        )
        register_buffer(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        register_buffer(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
        )

        # Per-timestep signal-to-noise ratio SNR(t) = a_bar_t / (1 - a_bar_t), plus a
        # copy clamped to snr_clip for min-SNR-style loss weighting.
        snr = alphas_cumprod / (1 - alphas_cumprod)
        clipped_snr = snr.clone()
        clipped_snr.clamp_(max=self.snr_clip)

        register_buffer("clipped_snr", clipped_snr)
        register_buffer("snr", snr)

    def add_shape_channels(self, x):
        # Append one singleton axis per element of x_shape so per-frame scalars
        # broadcast against (..., *x_shape) tensors.
        return rearrange(x, f"... -> ...{' 1' * len(self.x_shape)}")

    def model_predictions(self, x, t, action_cond=None, current_frame=None,
                          pose_cond=None, mode="training", reference_length=None, frame_idx=None):
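        """Run the backbone once and convert its output per `self.objective`.

        The leading two dimensions of `x` (and of `t`) are swapped before the
        backbone call and swapped back afterwards, so the backbone sees the
        transposed frame/batch layout. Returns a ModelPrediction with the noise
        estimate, the x_0 estimate, and the raw network output.
        """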
        x = x.permute(1, 0, 2, 3, 4)
        if action_cond is not None:
            action_cond = action_cond.permute(1, 0, 2)
        if pose_cond is not None and pose_cond[0] is not None:
            try:
                pose_cond = pose_cond.permute(1, 0, 2)
            except (AttributeError, RuntimeError):
                # pose_cond may be a non-tensor container; leave it unpermuted.
                pass
        t = t.permute(1, 0)
        model_output = self.model(x, t, action_cond, current_frame=current_frame, pose_cond=pose_cond,
                                  mode=mode, reference_length=reference_length, frame_idx=frame_idx)
        # Restore the original layout before the x_0 / noise conversions below.
        model_output = model_output.permute(1, 0, 2, 3, 4)
        x = x.permute(1, 0, 2, 3, 4)
        t = t.permute(1, 0)

        if self.objective == "pred_noise":
            pred_noise = torch.clamp(model_output, -self.clip_noise, self.clip_noise)
            x_start = self.predict_start_from_noise(x, t, pred_noise)
        elif self.objective == "pred_x0":
            x_start = model_output
            pred_noise = self.predict_noise_from_start(x, t, x_start)
        elif self.objective == "pred_v":
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            pred_noise = self.predict_noise_from_start(x, t, x_start)
        else:
            raise ValueError(f"unknown objective {self.objective}")

        return ModelPrediction(pred_noise, x_start, model_output)

    def predict_start_from_noise(self, x_t, t, noise):
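        """Recover x_0 from x_t and a noise estimate.

        With a_t = sqrt(alphas_cumprod[t]) and s_t = sqrt(1 - alphas_cumprod[t]),
        the forward process gives x_t = a_t * x_0 + s_t * eps, hence

            x_0 = x_t / a_t - (s_t / a_t) * eps

        where 1 / a_t = sqrt_recip_alphas_cumprod and s_t / a_t = sqrt_recipm1_alphas_cumprod.
        `predict_noise_from_start` below solves the same identity for eps.
        """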
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / extract(
            self.sqrt_recipm1_alphas_cumprod, t, x_t.shape
        )

    def predict_v(self, x_start, t, noise):
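        """v-parameterization target: v = a_t * eps - s_t * x_0.

        Combined with x_t = a_t * x_0 + s_t * eps, this yields
        x_0 = a_t * x_t - s_t * v, which is what `predict_start_from_v` computes.
        """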
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise
            - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise=None):
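        """Diffuse x_start to noise level t: x_t = a_t * x_0 + s_t * eps.

        Hypothetical shapes (frame-major layout assumed, with x_shape == (3, 32, 32)):

            >>> x0 = torch.randn(8, 2, 3, 32, 32)   # (frames, batch, C, H, W)
            >>> t = torch.randint(0, 1000, (8, 2))  # one noise level per frame
            >>> xt = diffusion.q_sample(x0, t)      # same shape as x0
        """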
        if noise is None:
            noise = torch.randn_like(x_start)
            noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_mean_variance(self, x, t, action_cond=None, pose_cond=None, reference_length=None, frame_idx=None):
        model_pred = self.model_predictions(x=x, t=t, action_cond=action_cond,
                                            pose_cond=pose_cond, reference_length=reference_length,
                                            frame_idx=frame_idx)
        x_start = model_pred.pred_x_start
        return self.q_posterior(x_start=x_start, x_t=x, t=t)

    def compute_loss_weights(self, noise_levels: torch.Tensor):
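        """Per-frame loss weights derived from the SNR at each noise level.

        Without fused SNR this is min-SNR-style weighting, e.g. for pred_v:
        weight = min(SNR, snr_clip) / (SNR + 1). With fused SNR, a decayed running
        average of earlier frames' clipped SNR additionally modulates each frame's
        weight along the sequence.
        """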
        snr = self.snr[noise_levels]
        clipped_snr = self.clipped_snr[noise_levels]
        normalized_clipped_snr = clipped_snr / self.snr_clip
        normalized_snr = snr / self.snr_clip

        if not self.use_fused_snr:
            match self.objective:
                case "pred_noise":
                    return clipped_snr / snr
                case "pred_x0":
                    return clipped_snr
                case "pred_v":
                    return clipped_snr / (snr + 1)
                case _:
                    raise ValueError(f"unknown objective {self.objective}")

        # Fused SNR: exponential moving average of the normalized clipped SNR,
        # accumulated along the leading (frame) dimension.
        cum_snr = torch.zeros_like(normalized_snr)
        for t in range(noise_levels.shape[0]):
            if t == 0:
                cum_snr[t] = normalized_clipped_snr[t]
            else:
                cum_snr[t] = self.cum_snr_decay * cum_snr[t - 1] + (1 - self.cum_snr_decay) * normalized_clipped_snr[t]

        # Shift by one step: frame t is weighted by the SNR accumulated up to t - 1.
        cum_snr = F.pad(cum_snr[:-1], (0, 0, 1, 0), value=0.0)
        clipped_fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_clipped_snr)
        fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_snr)

        match self.objective:
            case "pred_noise":
                return clipped_fused_snr / fused_snr
            case "pred_x0":
                return clipped_fused_snr * self.snr_clip
            case "pred_v":
                return clipped_fused_snr * self.snr_clip / (fused_snr * self.snr_clip + 1)
            case _:
                raise ValueError(f"unknown objective {self.objective}")

    def forward(
        self,
        x: torch.Tensor,
        action_cond: Optional[torch.Tensor],
        pose_cond,
        noise_levels: torch.Tensor,
        reference_length,
        frame_idx=None,
    ):
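        """One training step: noise `x` per frame, run the model, return weighted loss.

        `noise_levels` holds one diffusion timestep per frame, so frames of the
        same sequence can be trained at different noise levels. Returns the model's
        x_0 prediction and the un-reduced, SNR-weighted MSE loss.
        """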
        noise = torch.randn_like(x)
        noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
        noised_x = self.q_sample(x_start=x, t=noise_levels, noise=noise)

        model_pred = self.model_predictions(x=noised_x, t=noise_levels, action_cond=action_cond,
                                            pose_cond=pose_cond, reference_length=reference_length,
                                            frame_idx=frame_idx)

        pred = model_pred.model_out
        x_pred = model_pred.pred_x_start

        if self.objective == "pred_noise":
            target = noise
        elif self.objective == "pred_x0":
            target = x
        elif self.objective == "pred_v":
            target = self.predict_v(x, noise_levels, noise)
        else:
            raise ValueError(f"unknown objective {self.objective}")

        loss = F.mse_loss(pred, target.detach(), reduction="none")
        loss_weight = self.compute_loss_weights(noise_levels)
        # Broadcast the per-frame weight over the channel/spatial dims of the loss.
        loss_weight = loss_weight.view(*loss_weight.shape, *((1,) * (loss.ndim - 2)))
        loss = loss * loss_weight

        return x_pred, loss

    def sample_step(
        self,
        x: torch.Tensor,
        action_cond: Optional[torch.Tensor],
        pose_cond,
        curr_noise_level: torch.Tensor,
        next_noise_level: torch.Tensor,
        guidance_fn: Optional[Callable] = None,
        current_frame=None,
        mode="training",
        reference_length=None,
        frame_idx=None,
    ):
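        """Advance each frame from `curr_noise_level` to `next_noise_level`.

        Both level tensors index the coarse grid {0, ..., sampling_timesteps}; they
        are mapped onto the real timestep grid {-1, ..., timesteps - 1}, where -1
        marks a fully denoised frame. Dispatches to DDIM when
        sampling_timesteps < timesteps, otherwise to DDPM.
        """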
        real_steps = torch.linspace(-1, self.timesteps - 1, steps=self.sampling_timesteps + 1, device=x.device).long()

        # Map sampling-grid indices onto real timesteps.
        curr_noise_level = real_steps[curr_noise_level]
        next_noise_level = real_steps[next_noise_level]

        if self.is_ddim_sampling:
            return self.ddim_sample_step(
                x=x,
                action_cond=action_cond,
                pose_cond=pose_cond,
                curr_noise_level=curr_noise_level,
                next_noise_level=next_noise_level,
                guidance_fn=guidance_fn,
                current_frame=current_frame,
                mode=mode,
                reference_length=reference_length,
                frame_idx=frame_idx,
            )

        # DDPM sampling steps down exactly one noise level at a time.
        assert torch.all(
            (curr_noise_level - 1 == next_noise_level) | ((curr_noise_level == -1) & (next_noise_level == -1))
        ), "Wrong noise level given for ddpm sampling."
        assert (
            self.sampling_timesteps == self.timesteps
        ), "sampling_timesteps should be equal to timesteps for ddpm sampling."

        return self.ddpm_sample_step(
            x=x,
            action_cond=action_cond,
            pose_cond=pose_cond,
            curr_noise_level=curr_noise_level,
            guidance_fn=guidance_fn,
            reference_length=reference_length,
            frame_idx=frame_idx,
        )

    def ddpm_sample_step(
        self,
        x: torch.Tensor,
        action_cond: Optional[torch.Tensor],
        pose_cond,
        curr_noise_level: torch.Tensor,
        guidance_fn: Optional[Callable] = None,
        reference_length=None,
        frame_idx=None,
    ):
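        """Ancestral DDPM update: x_prev = model_mean + exp(0.5 * log_var) * z.

        Frames already fully denoised (level -1) are rescaled to the stabilization
        level so they act as clean context for the backbone, and are returned
        unchanged at the end of the step.
        """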
        clipped_curr_noise_level = torch.where(
            curr_noise_level < 0,
            torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
            curr_noise_level,
        )

        # Rescale fully denoised frames to the stabilization level (zero added noise).
        orig_x = x.clone().detach()
        scaled_context = self.q_sample(
            x,
            clipped_curr_noise_level,
            noise=torch.zeros_like(x),
        )
        x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)

        if guidance_fn is not None:
            raise NotImplementedError("Guidance function is not implemented for ddpm sampling yet.")

        model_mean, _, model_log_variance = self.p_mean_variance(
            x=x,
            t=clipped_curr_noise_level,
            action_cond=action_cond,
            pose_cond=pose_cond,
            reference_length=reference_length,
            frame_idx=frame_idx,
        )

        # No noise is added at t = 0.
        noise = torch.where(
            self.add_shape_channels(clipped_curr_noise_level > 0),
            torch.randn_like(x),
            0,
        )
        noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
        x_pred = model_mean + torch.exp(0.5 * model_log_variance) * noise

        # Fully denoised frames pass through unchanged.
        return torch.where(self.add_shape_channels(curr_noise_level == -1), orig_x, x_pred)

    def ddim_sample_step(
        self,
        x: torch.Tensor,
        action_cond: Optional[torch.Tensor],
        pose_cond,
        curr_noise_level: torch.Tensor,
        next_noise_level: torch.Tensor,
        guidance_fn: Optional[Callable] = None,
        current_frame=None,
        mode="training",
        reference_length=None,
        frame_idx=None,
    ):
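        """DDIM update: x_next = sqrt(a_next) * x_0 + c * eps_hat + sigma * z,
        with c = sqrt(1 - a_next - sigma^2) and sigma given by the eta-scaled DDIM
        variance. If `guidance_fn` is provided, eps_hat is shifted along the
        gradient of the guidance loss on the x_0 prediction before the update.
        """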
        # Treat fully denoised frames (level -1) as context at the stabilization level.
        clipped_curr_noise_level = torch.where(
            curr_noise_level < 0,
            torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
            curr_noise_level,
        )

        orig_x = x.clone().detach()
        scaled_context = self.q_sample(
            x,
            clipped_curr_noise_level,
            noise=torch.zeros_like(x),
        )
        x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)

        # DDIM coefficients; a next level of -1 means a final, noise-free step.
        alpha = self.alphas_cumprod[clipped_curr_noise_level]
        alpha_next = torch.where(
            next_noise_level < 0,
            torch.ones_like(next_noise_level),
            self.alphas_cumprod[next_noise_level],
        )
        sigma = torch.where(
            next_noise_level < 0,
            torch.zeros_like(next_noise_level),
            self.ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt(),
        )
        c = (1 - alpha_next - sigma**2).sqrt()

        alpha_next = self.add_shape_channels(alpha_next)
        c = self.add_shape_channels(c)
        sigma = self.add_shape_channels(sigma)

        if guidance_fn is not None:
            # Nudge the predicted noise along the gradient of the guidance loss
            # with respect to the current sample.
            with torch.enable_grad():
                x = x.detach().requires_grad_()

                model_pred = self.model_predictions(
                    x=x,
                    t=clipped_curr_noise_level,
                    action_cond=action_cond,
                    pose_cond=pose_cond,
                    current_frame=current_frame,
                    mode=mode,
                    reference_length=reference_length,
                    frame_idx=frame_idx,
                )

                guidance_loss = guidance_fn(model_pred.pred_x_start)
                grad = -torch.autograd.grad(
                    guidance_loss,
                    x,
                )[0]

            pred_noise = model_pred.pred_noise + (1 - alpha_next).sqrt() * grad
            x_start = self.predict_start_from_noise(x, clipped_curr_noise_level, pred_noise)

        else:
            model_pred = self.model_predictions(
                x=x,
                t=clipped_curr_noise_level,
                action_cond=action_cond,
                pose_cond=pose_cond,
                current_frame=current_frame,
                mode=mode,
                reference_length=reference_length,
                frame_idx=frame_idx,
            )
            x_start = model_pred.pred_x_start
            pred_noise = model_pred.pred_noise

        noise = torch.randn_like(x)
        noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)

        x_pred = x_start * alpha_next.sqrt() + pred_noise * c + sigma * noise

        # Frames whose noise level does not change in this step are left untouched.
        mask = curr_noise_level == next_noise_level
        x_pred = torch.where(
            self.add_shape_channels(mask),
            orig_x,
            x_pred,
        )

        return x_pred
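

# Minimal rollout sketch (hypothetical shapes and loop; frame-major layout and the
# real conditioning/rollout logic live in the caller):
#
#   x = torch.randn(T, B, *diffusion.x_shape)
#   for step in range(diffusion.sampling_timesteps, 0, -1):
#       curr = torch.full((T, B), step, dtype=torch.long, device=x.device)
#       x = diffusion.sample_step(x, action_cond, pose_cond, curr, curr - 1,
#                                 reference_length=diffusion.reference_length)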