# NOTE: original paste carried residue banner lines here ("Spaces:", "Runtime error" x2);
# they were not part of the program and are preserved only as this comment.
import math

import torch
import torch.nn.functional as F
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return a linearly spaced beta (noise-variance) schedule.

    Generalized from the original hard-coded constants: the endpoints are
    now parameters whose defaults reproduce the previous behavior exactly
    (scale was fixed at 1.0, commented "for 100 steps").
    NOTE(review): the classic DDPM recipe scales the endpoints by
    1000/timesteps; confirm whether that scaling was intentionally dropped.

    Args:
        timesteps: number of diffusion steps T.
        beta_start: beta at t=0 (default 1e-4, the previous hard-coded value).
        beta_end: beta at t=T-1 (default 0.02, the previous hard-coded value).

    Returns:
        1-D float64 tensor of shape (timesteps,).
    """
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
class GaussianDiffusion:
    """DDPM-style Gaussian diffusion for fixed-length sequence data.

    Precomputes the beta schedule and every derived coefficient needed for
    the forward process q(x_t | x_0), the true posterior
    q(x_{t-1} | x_t, x_0), and ancestral reverse sampling. Supports an
    autoregressive "fix mode" where the first `fixed_frames` frames are
    clamped to given values, optional classifier-free guidance on the text
    embedding, and optional classifier guidance via a discriminator model.
    """

    def __init__(
        self,
        device,
        fix_mode=False,
        text_emb=False,
        fixed_frames=2,
        seq_len=16,
        timesteps=100,
        beta_schedule='linear',
    ):
        """
        Args:
            device: device all schedule tensors are moved to.
            fix_mode: autoregressive mode — keep the first `fixed_frames`
                frames fixed during sampling and training.
            text_emb: whether text conditioning is used (stored only).
            fixed_frames: number of leading frames to fix in fix_mode.
            seq_len: sequence length (stored only).
            timesteps: number of diffusion steps T.
            beta_schedule: only 'linear' is implemented.

        Raises:
            NotImplementedError: if beta_schedule == 'cosine'.
            ValueError: for any other unknown schedule name.
        """
        self.device = device
        self.fix_mode = fix_mode  # autoregressive
        self.fixed_frames = fixed_frames  # number of frames to fix
        self.timesteps = timesteps
        self.text_emb = text_emb
        self.seq_len = seq_len

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            raise NotImplementedError('cosine schedule is not implemented yet!')
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        self.betas = betas.to(self.device)
        self.alphas = (1. - self.betas).to(self.device)
        # alpha_bar_t = prod_{s<=t} alpha_s  (use the documented kwarg `dim`,
        # not the numpy-style alias `axis`)
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0).to(self.device)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.).to(self.device)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod).to(self.device)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod).to(self.device)
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod).to(self.device)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod).to(self.device)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1).to(self.device)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        ).to(self.device)
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain
        self.posterior_log_variance_clipped = torch.log(
            self.posterior_variance.clamp(min=1e-20)).to(self.device)
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        ).to(self.device)
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(self.alphas)
            / (1.0 - self.alphas_cumprod)
        ).to(self.device)

    def _extract(self, a, t, x_shape):
        """Gather a[t] per batch element, reshaped to broadcast over x_shape.

        Args:
            a: 1-D schedule tensor of length `timesteps`.
            t: long tensor of shape (batch,) with per-sample timesteps.
            x_shape: shape of the tensor the result must broadcast against.

        Returns:
            float tensor of shape (batch, 1, ..., 1) on `self.device`.
        """
        batch_size = t.shape[0]
        out = a.to(t.device).gather(0, t).float()
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(self.device)

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion q(x_t | x_0) via the closed-form 'nice property'."""
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def q_mean_variance(self, x_start, t):
        """Return mean, variance and log-variance of q(x_t | x_0)."""
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return mean/variance of the true posterior q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and predicted noise (the inverse of `q_sample`)."""
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def p_mean_variance(self, model, x_t, t, clip_denoised=True, **kwargs):
        """Compute predicted mean and variance of p(x_{t-1} | x_t).

        Required kwargs: 'text', 'prog_ind', 'joints_orig'. Optional:
        'use_cfg' (classifier-free guidance, default False) and 'cfg_alpha'
        (required only when use_cfg is truthy).
        """
        # predict noise using model
        assert 'text' in kwargs, 'text is required'
        assert 'prog_ind' in kwargs, 'prog_ind is required'
        assert 'joints_orig' in kwargs, 'joints_orig is required'
        pred_noise = model(x_t, t,
                           text_emb=kwargs['text'],
                           prog_ind=kwargs['prog_ind'],
                           joints_orig=kwargs['joints_orig'])
        # classifier-free guidance: blend conditional and unconditional
        # (zeroed text embedding) noise predictions
        if kwargs.get('use_cfg', False):
            pred_noise_empty = model(x_t, t,
                                     text_emb=torch.zeros_like(kwargs['text']),
                                     prog_ind=kwargs['prog_ind'],
                                     joints_orig=kwargs['joints_orig'])
            pred_noise = pred_noise_empty + kwargs['cfg_alpha'] * (pred_noise - pred_noise_empty)
        # get the predicted x_0: different from the algorithm2 in the paper
        x_recon = self.predict_start_from_noise(x_t, t, pred_noise)
        if clip_denoised:
            x_recon = torch.clamp(x_recon, min=-1., max=1.)
        model_mean, posterior_variance, posterior_log_variance = \
            self.q_posterior_mean_variance(x_recon, x_t, t)
        return model_mean, posterior_variance, posterior_log_variance

    # @torch.no_grad()
    def p_sample(self, model, x_t, t, clip_denoised=True, **kwargs):
        """One denoising step: sample x_{t-1} from x_t.

        If kwargs contains 'disc_model', classifier guidance is applied for
        the final 'cg_diffusion_steps' steps, scaled by 'cg_alpha'
        (defaults 5 and 1.0 when not provided).

        Fixes vs. the original:
        - the guidance variables were referenced even when no 'disc_model'
          was supplied, raising NameError for plain sampling;
        - `torch.tensor(model_mean, requires_grad=True)` (deprecated on an
          existing tensor) is replaced by detach().clone().requires_grad_(),
          applied only when the discriminator gradient is actually needed;
        - the bare `except:` is narrowed to KeyError.
        """
        use_cg = 'disc_model' in kwargs
        if use_cg:
            disc_model = kwargs['disc_model']
            try:
                cg_alpha = kwargs['cg_alpha']  # default 1.0
                cg_diffusion_steps = kwargs['cg_diffusion_steps']
            except KeyError:
                print("cg_alpha and cg_diffusion_steps are not provided!")
                print("Using default values: cg_alpha=1.0, cg_diffusion_steps=5")
                cg_alpha = 1.0
                cg_diffusion_steps = 5
        # predict mean and variance
        model_mean, _, model_log_variance = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, **kwargs)
        noise = torch.randn_like(x_t)
        # no noise when t == 0
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))
        # NOTE(review): t.item() assumes batch size 1 here — confirm callers.
        if use_cg and t.item() < cg_diffusion_steps:
            # classifier guidance: nudge the mean along the discriminator
            # gradient w.r.t. the (leaf) mean
            model_mean = model_mean.detach().clone().requires_grad_(True)
            pred_syn = disc_model(model_mean, t)  # y = f(theta, x) theta fixed
            pred_syn.backward()
            grad = model_mean.grad * cg_alpha
            model_mean = model_mean - nonzero_mask * (0.5 * model_log_variance).exp() * grad
        # compute x_{t-1}
        pred_img = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return pred_img

    # @torch.no_grad()
    def p_sample_loop(self, model, shape, fixed_points=None, **kwargs):
        """Full reverse diffusion starting from pure noise.

        Returns the list of intermediate samples, one per timestep (the
        last element is the final sample). In fix mode the first
        `fixed_frames` frames are clamped to `fixed_points` before and
        after every denoising step.
        """
        batch_size = shape[0]
        device = next(model.parameters()).device
        # start from pure noise (for each example in the batch)
        img = torch.randn(shape, device=device)
        # notice that if we are in fixed mode, we need to fix the first frames
        if self.fix_mode:
            assert fixed_points is not None, 'fixed_points is required for fixed mode'
            img[:, :self.fixed_frames, :] = fixed_points
        imgs = []
        for i in reversed(range(0, self.timesteps)):
            img = self.p_sample(model, img,
                                torch.full((batch_size,), i, device=device, dtype=torch.long),
                                **kwargs)
            if self.fix_mode:
                img[:, :self.fixed_frames, :] = fixed_points
            imgs.append(img)
        return imgs

    # @torch.no_grad()
    def sample(self, model, batch_size=1, seq_len=16, channels=135,
               fixed_points=None, clip_denoised=True, **kwargs):
        """Sample new sequences of shape (batch_size, seq_len, channels)."""
        return self.p_sample_loop(model, shape=(batch_size, seq_len, channels),
                                  fixed_points=fixed_points, clip_denoised=clip_denoised, **kwargs)

    def train_losses(self, model, x_start, t, mask=None, **kwargs):
        """Smooth-L1 loss between the true and predicted noise.

        `mask` marks positions that are held fixed: their noise is forced
        to 0 and they are excluded from the loss. Fix vs. the original:
        it read `self.fixed_mode`, an attribute that is never set (the
        attribute assigned in __init__ is `fix_mode`), so any call with
        mask=None raised AttributeError.
        """
        assert not (mask is None and self.fix_mode), 'mask is required for fixed mode'
        if mask is None:
            mask = torch.zeros_like(x_start).to(dtype=torch.bool, device=self.device)
        mask_inv = torch.logical_not(mask)
        # generate random noise; fixed positions get zero noise
        noise = torch.randn_like(x_start).to(device=self.device)
        noise[mask] = 0.
        # get x_t
        x_noisy = self.q_sample(x_start, t, noise=noise)
        predicted_noise = model(x_noisy, t, text_emb=kwargs['text'],
                                prog_ind=kwargs['prog_ind'], joints_orig=kwargs['joints_orig'])
        loss = F.smooth_l1_loss(noise[mask_inv], predicted_noise[mask_inv])
        return loss