# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/beb2f2d8dd9b4f2bd5be4719f37082fe061ee450/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
import math
import copy
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple
from multiprocessing import cpu_count
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torchvision import transforms as T, utils
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from PIL import Image
from tqdm.auto import tqdm
from typing import Any, Dict, List, Optional, Tuple, Union
# constants

# Named pair holding both quantities the reverse process can express:
# the predicted noise (epsilon) and the predicted clean sample (x_0).
ModelPrediction = namedtuple("ModelPrediction", "pred_noise pred_x_start")

# helper functions
def exists(x):
    """Return True when *x* carries a value, i.e. is not None."""
    return x is not None
def default(val, d):
    """Return *val* unless it is None; otherwise fall back to *d*.

    If the fallback *d* is callable it is invoked lazily, so expensive
    defaults are only computed when actually needed.
    """
    if val is not None:
        return val
    return d() if callable(d) else d
def extract(a, t, x_shape):
    """Gather per-timestep constants and shape them for broadcasting.

    Picks ``a[t]`` for each element of the batch index tensor ``t`` and
    reshapes the result to ``(batch, 1, 1, ...)`` with as many trailing
    singleton dims as needed to broadcast against a tensor of shape
    ``x_shape``.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
def linear_beta_schedule(timesteps):
    """Linear beta schedule, rescaled from the 1000-step DDPM reference.

    The endpoints (1e-4, 0.02) are scaled by ``1000 / timesteps`` so the
    total amount of noise injected stays comparable when fewer or more
    steps are used.
    """
    scale = 1000 / timesteps
    start, end = scale * 0.0001, scale * 0.02
    return torch.linspace(start, end, timesteps, dtype=torch.float64)
def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    # Evaluate the squared-cosine alpha-bar curve on timesteps + 1 points,
    # normalize so it starts at 1, then convert consecutive ratios to betas.
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    abar = torch.cos((grid / timesteps + s) / (1 + s) * math.pi * 0.5) ** 2
    abar = abar / abar[0]
    betas = 1 - abar[1:] / abar[:-1]
    # Clip keeps the final betas away from 1 (numerical safety at t = T).
    return torch.clip(betas, 0, 0.999)
class GaussianDiffusion(nn.Module):
    """DDPM-style Gaussian diffusion over pose tensors.

    Precomputes the beta schedule and all derived per-timestep constants
    as float32 buffers, and implements:

    * forward diffusion ``q_sample`` (noise a clean sample x_0 to x_t),
    * ancestral sampling ``p_sample_loop`` / ``sample`` (denoise from
      pure Gaussian noise),
    * the training objective ``p_losses`` (entry point: ``forward``).

    The denoising network is NOT created here: ``self.model`` — called as
    ``self.model(x, t, z)`` with conditioning ``z`` — is expected to be
    attached by a subclass or by the caller before ``forward``/``sample``
    is used.
    """

    def __init__(
        self,
        timesteps=100,
        sampling_timesteps=None,
        beta_1=0.0001,
        beta_T=0.1,
        loss_type="l1",
        objective="pred_noise",
        beta_schedule="custom",
        p2_loss_weight_gamma=0.0,
        p2_loss_weight_k=1,
    ):
        """
        Args:
            timesteps: number of diffusion steps T used at training time.
            sampling_timesteps: steps used at sampling time; defaults to
                ``timesteps`` when None.
            beta_1, beta_T: endpoints for the "custom" linear schedule.
            loss_type: "l1" or "l2" elementwise reconstruction loss.
            objective: "pred_noise" (model outputs epsilon) or "pred_x0"
                (model outputs the clean sample).
            beta_schedule: "linear", "cosine" or "custom".
            p2_loss_weight_gamma: gamma of P2 loss reweighting
                (0.0 disables reweighting, i.e. weight == 1 everywhere).
            p2_loss_weight_k: k of P2 loss reweighting.
        """
        super().__init__()
        assert objective in {
            "pred_noise",
            "pred_x0",
        }, "objective must be either pred_noise (predict noise) \
            or pred_x0 (predict image start)"
        self.timesteps = timesteps
        self.sampling_timesteps = sampling_timesteps
        self.beta_1 = beta_1
        self.beta_T = beta_T
        self.loss_type = loss_type
        self.objective = objective
        self.beta_schedule = beta_schedule
        self.p2_loss_weight_gamma = p2_loss_weight_gamma
        self.p2_loss_weight_k = p2_loss_weight_k
        self.init_diff_hyper(
            self.timesteps,
            self.sampling_timesteps,
            self.beta_1,
            self.beta_T,
            self.loss_type,
            self.objective,
            self.beta_schedule,
            self.p2_loss_weight_gamma,
            self.p2_loss_weight_k,
        )

    def init_diff_hyper(
        self,
        timesteps,
        sampling_timesteps,
        beta_1,
        beta_T,
        loss_type,
        objective,
        beta_schedule,
        p2_loss_weight_gamma,
        p2_loss_weight_k,
    ):
        """Build the beta schedule and register all derived constants.

        Everything is computed in float64 for accuracy and stored as
        float32 buffers so it moves with the module across devices and
        is saved in checkpoints.

        Raises:
            ValueError: if ``beta_schedule`` is not one of
                "linear" / "cosine" / "custom".
        """
        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == "custom":
            betas = torch.linspace(
                beta_1, beta_T, timesteps, dtype=torch.float64
            )
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # sampling related parameters
        self.sampling_timesteps = default(
            sampling_timesteps, timesteps
        )  # default num sampling timesteps to number of timesteps at training
        assert self.sampling_timesteps <= timesteps

        # helper to register a buffer downcast from float64 to float32
        def register_buffer(name, val):
            self.register_buffer(name, val.to(torch.float32))

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        register_buffer(
            "log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)
        )
        register_buffer(
            "sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)
        )
        register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer("posterior_variance", posterior_variance)
        # below: log calculation clipped because the posterior variance is 0
        # at the beginning of the diffusion chain
        register_buffer(
            "posterior_log_variance_clipped",
            torch.log(posterior_variance.clamp(min=1e-20)),
        )
        register_buffer(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        register_buffer(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev)
            * torch.sqrt(alphas)
            / (1.0 - alphas_cumprod),
        )

        # calculate p2 reweighting (https://arxiv.org/abs/2204.00227);
        # gamma == 0 makes every weight exactly 1.
        register_buffer(
            "p2_loss_weight",
            (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
            ** -p2_loss_weight_gamma,
        )

    # helper functions
    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and the predicted noise epsilon."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise epsilon from x_t and a predicted x_0."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
        ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def q_posterior(self, x_start, x_t, t):
        """Mean / variance / clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return (
            posterior_mean,
            posterior_variance,
            posterior_log_variance_clipped,
        )

    def q_sample(self, x_start, t, noise=None):
        """Sample x_t ~ q(x_t | x_0): closed-form forward diffusion."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def model_predictions(self, x, t, z, x_self_cond=None):
        """Run the denoiser and normalize its output to (noise, x_0).

        Whatever the training objective, callers always get both the
        predicted noise and the predicted clean sample.
        """
        model_output = self.model(x, t, z)
        if self.objective == "pred_noise":
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, model_output)
        elif self.objective == "pred_x0":
            pred_noise = self.predict_noise_from_start(x, t, model_output)
            x_start = model_output
        else:
            # Previously fell through with unbound locals (UnboundLocalError);
            # fail loudly and consistently with p_losses instead.
            raise ValueError(f"unknown objective {self.objective}")
        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(
        self,
        x: torch.Tensor,  # B x N_x x dim
        t: int,
        z: torch.Tensor,
        x_self_cond=None,
        clip_denoised=False,
    ):
        """Mean and variance of p(x_{t-1} | x_t) under the model.

        Raises:
            NotImplementedError: if ``clip_denoised`` is requested —
                poses have no natural value range to clip to.
        """
        preds = self.model_predictions(x, t, z)
        x_start = preds.pred_x_start
        if clip_denoised:
            raise NotImplementedError(
                "We don't clip the output because \
                pose does not have a clear bound."
            )
        (
            model_mean,
            posterior_variance,
            posterior_log_variance,
        ) = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    def p_sample(
        self,
        x: torch.Tensor,  # B x N_x x dim
        t: int,
        z: torch.Tensor,
        x_self_cond=None,
        clip_denoised=False,
        cond_fn=None,
        cond_start_step=0,
    ):
        """One reverse-diffusion step: sample x_{t-1} given x_t.

        Returns (x_{t-1}, predicted x_0).
        """
        batched_times = torch.full(
            (x.shape[0],), t, device=x.device, dtype=torch.long
        )
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(
            x=x,
            t=batched_times,
            z=z,
            x_self_cond=x_self_cond,
            clip_denoised=clip_denoised,
        )
        if cond_fn is not None and t < cond_start_step:
            # Guidance phase: let cond_fn rewrite the mean and sample
            # deterministically (no added noise).
            # NOTE(review): guidance fires for t BELOW cond_start_step,
            # i.e. the late (low-noise) end of sampling — confirm intended.
            model_mean = cond_fn(model_mean, t)
            noise = 0.0
        else:
            noise = torch.randn_like(x) if t > 0 else 0.0  # no noise if t == 0
        pred = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred, x_start

    def p_sample_loop(
        self,
        shape,
        z: torch.Tensor,
        cond_fn=None,
        cond_start_step=0,
    ):
        """Full ancestral sampling loop from pure noise.

        Returns (final sample, trajectory) where the trajectory stacks
        the initial noise plus one entry per denoising step.
        """
        batch, device = shape[0], self.betas.device
        # Init here
        pose = torch.randn(shape, device=device)
        pose_process = [pose.unsqueeze(0)]
        for t in reversed(range(0, self.num_timesteps)):
            pose, _ = self.p_sample(
                x=pose,
                t=t,
                z=z,
                cond_fn=cond_fn,
                cond_start_step=cond_start_step,
            )
            pose_process.append(pose.unsqueeze(0))
        return pose, torch.cat(pose_process)

    def sample(self, shape, z, cond_fn=None, cond_start_step=0):
        """Public sampling entry point (ancestral sampling only for now)."""
        # TODO: add more variants (e.g. DDIM using sampling_timesteps)
        sample_fn = self.p_sample_loop
        return sample_fn(
            shape, z=z, cond_fn=cond_fn, cond_start_step=cond_start_step
        )

    def p_losses(
        self,
        x_start,
        t,
        z=None,
        noise=None,
    ):
        """Training loss at timesteps ``t`` for clean samples ``x_start``.

        Returns a dict with the per-sample "loss" (P2-reweighted), the
        "noise" used, the model's clean-sample prediction "x_0_pred",
        the noised input "x_t" and the timesteps "t".

        Raises:
            ValueError: if ``self.objective`` is unknown.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x, t, z)
        if self.objective == "pred_noise":
            target = noise
            x_0_pred = self.predict_start_from_noise(x, t, model_out)
        elif self.objective == "pred_x0":
            target = x_start
            x_0_pred = model_out
        else:
            raise ValueError(f"unknown objective {self.objective}")
        loss = self.loss_fn(model_out, target, reduction="none")
        loss = reduce(loss, "b ... -> b (...)", "mean")
        loss = loss * extract(self.p2_loss_weight, t, loss.shape)
        return {
            "loss": loss,
            "noise": noise,
            "x_0_pred": x_0_pred,
            "x_t": x,
            "t": t,
        }

    def forward(self, pose, z=None, *args, **kwargs):
        """Draw uniform random timesteps per sample and compute p_losses."""
        b = len(pose)
        t = torch.randint(
            0, self.num_timesteps, (b,), device=pose.device
        ).long()
        return self.p_losses(pose, t, *args, z=z, **kwargs)

    @property
    def loss_fn(self):
        """Elementwise loss function selected by ``loss_type``.

        Must be a property: p_losses calls ``self.loss_fn(out, target,
        reduction="none")`` — as a plain method that call raised
        TypeError (three args passed to a zero-argument method).

        Raises:
            ValueError: for an unknown ``loss_type``.
        """
        if self.loss_type == "l1":
            return F.l1_loss
        elif self.loss_type == "l2":
            return F.mse_loss
        else:
            raise ValueError(f"invalid loss type {self.loss_type}")