import numpy as np
import torch
import torch.nn as nn
from tqdm import trange
from torchvision.transforms import Compose
class Diffusion(nn.Module):
    def __init__(self, nn_backbone, device, n_timesteps=1000, in_channels=3,
                 image_size=128, out_channels=6, motion_transforms=None):
        super(Diffusion, self).__init__()
        self.nn_backbone = nn_backbone
        self.n_timesteps = n_timesteps
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.x_shape = (image_size, image_size)
        self.device = device
        # fall back to an identity transform if no motion transforms are given
        self.motion_transforms = motion_transforms if motion_transforms else Compose([])
        self.timesteps = torch.arange(n_timesteps)
        self.beta = self.get_beta_schedule()
        self.set_params()
    def sample(self, x_cond, audio_emb, n_audio_motion_embs=2, n_motion_frames=2, motion_channels=3):
        with torch.no_grad():
            n_frames = audio_emb.shape[1]
            # start every output frame from pure Gaussian noise
            xT = torch.randn(x_cond.shape[0], n_frames, self.in_channels,
                             self.x_shape[0], self.x_shape[1]).to(x_cond.device)

            # sliding window of audio embedding indices around the current frame
            audio_ids = [0] * n_audio_motion_embs
            for i in range(n_audio_motion_embs + 1):
                audio_ids += [i]

            # initialize the motion window with copies of the conditioning frame
            motion_frames = [self.motion_transforms(x_cond) for _ in range(n_motion_frames)]
            motion_frames = torch.cat(motion_frames, dim=1)

            samples = []
            for i in trange(n_frames, desc='Sampling'):
                sample_frame = self.sample_loop(xT[:, i].to(x_cond.device), x_cond,
                                                motion_frames, audio_emb[:, audio_ids])
                samples.append(sample_frame.unsqueeze(1))
                # shift the motion window and the audio window forward by one frame
                motion_frames = torch.cat([motion_frames[:, motion_channels:, :],
                                           self.motion_transforms(sample_frame)], dim=1)
                audio_ids = audio_ids[1:] + [min(i + n_audio_motion_embs + 1, n_frames - 1)]
        return torch.cat(samples, dim=1)
    def sample_loop(self, xT, x_cond, motion_frames, audio_emb):
        # reverse diffusion: denoise from x_T down to x_0 over the (possibly spaced) timesteps
        xt = xT
        for i, t in reversed(list(enumerate(self.timesteps))):
            timesteps = torch.tensor([t] * xT.shape[0]).to(xT.device)
            timesteps_ids = torch.tensor([i] * xT.shape[0]).to(xT.device)
            nn_out = self.nn_backbone(xt, timesteps, x_cond, motion_frames=motion_frames, audio_emb=audio_emb)
            mean, logvar = self.get_p_params(xt, timesteps_ids, nn_out)
            # no noise is added at the final step (t == 0)
            noise = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)
            xt = mean + noise * torch.exp(logvar / 2)
        return xt
    def get_p_params(self, xt, timesteps, nn_out):
        if self.in_channels == self.out_channels:
            # fixed variance: the network predicts only the noise
            eps_pred = nn_out
            p_logvar = self.expand(torch.log(self.beta[timesteps]))
        else:
            # learned variance: interpolate between log(beta) and log(beta_tilde)
            eps_pred, nu = nn_out.chunk(2, 1)
            nu = (nu + 1) / 2
            p_logvar = nu * self.expand(torch.log(self.beta[timesteps])) + (1 - nu) * self.expand(self.log_beta_tilde_clipped[timesteps])
        p_mean, _ = self.get_q_params(xt, timesteps, eps_pred=eps_pred)
        return p_mean, p_logvar
    def get_q_params(self, xt, timesteps, eps_pred=None, x0=None):
        if x0 is None:
            # predict x0 from xt and eps_pred
            coef1_x0 = self.expand(self.coef1_x0[timesteps])
            coef2_x0 = self.expand(self.coef2_x0[timesteps])
            x0 = coef1_x0 * xt - coef2_x0 * eps_pred
            x0 = x0.clamp(-1, 1)
        # q(x_{t-1} | x_t, x_0)
        coef1_q = self.expand(self.coef1_q[timesteps])
        coef2_q = self.expand(self.coef2_q[timesteps])
        q_mean = coef1_q * x0 + coef2_q * xt
        q_logvar = self.expand(self.log_beta_tilde_clipped[timesteps])
        return q_mean, q_logvar
    def get_beta_schedule(self, max_beta=0.999):
        # cosine noise schedule, with betas clipped at max_beta
        alpha_bar = lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2
        betas = []
        for i in range(self.n_timesteps):
            t1 = i / self.n_timesteps
            t2 = (i + 1) / self.n_timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return torch.tensor(betas).float()
    def set_params(self):
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.alpha_bar_prev = torch.cat([torch.ones(1,), self.alpha_bar[:-1]])
        self.beta_tilde = self.beta * (1.0 - self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        # replace the first element before taking the log, since beta_tilde[0] == 0
        self.log_beta_tilde_clipped = torch.log(torch.cat([self.beta_tilde[1, None], self.beta_tilde[1:]]))
        # to calculate x0 from eps_pred
        self.coef1_x0 = torch.sqrt(1.0 / self.alpha_bar)
        self.coef2_x0 = torch.sqrt(1.0 / self.alpha_bar - 1)
        # for q(x_{t-1} | x_t, x_0)
        self.coef1_q = self.beta * torch.sqrt(self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        self.coef2_q = (1.0 - self.alpha_bar_prev) * torch.sqrt(self.alpha) / (1.0 - self.alpha_bar)
    def space(self, n_timesteps_new):
        # change parameters for spaced timesteps during sampling
        self.timesteps = self.space_timesteps(self.n_timesteps, n_timesteps_new)
        self.n_timesteps = n_timesteps_new
        self.beta = self.get_spaced_beta()
        self.set_params()
    def space_timesteps(self, n_timesteps, target_timesteps):
        # pick target_timesteps evenly spaced indices out of the original n_timesteps
        frac_stride = (n_timesteps - 1) / (target_timesteps - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(target_timesteps):
            taken_steps.append(round(cur_idx))
            cur_idx += frac_stride
        return taken_steps
    def get_spaced_beta(self):
        # recompute betas so the kept timesteps preserve the original alpha_bar values
        last_alpha_cumprod = 1.0
        new_beta = []
        for i, alpha_cumprod in enumerate(self.alpha_bar):
            if i in self.timesteps:
                new_beta.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
        return torch.tensor(new_beta)
    def expand(self, arr, dim=4):
        # broadcast per-sample scalars, e.g. (B,) -> (B, 1, 1, 1), and move them to self.device
        while arr.dim() < dim:
            arr = arr[:, None]
        return arr.to(self.device)
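
# Minimal usage sketch (an illustrative assumption, not part of the original file):
# `DummyBackbone`, the tensor shapes, and the audio embedding size are hypothetical
# placeholders; any nn.Module that matches the call signature used in `sample_loop`
# and returns `out_channels` maps per pixel would work in its place.
if __name__ == '__main__':
    class DummyBackbone(nn.Module):
        def forward(self, xt, timesteps, x_cond, motion_frames=None, audio_emb=None):
            # predict (eps, nu): same spatial size as xt, out_channels=6 channels
            return torch.randn(xt.shape[0], 6, xt.shape[2], xt.shape[3], device=xt.device)

    device = 'cpu'
    diffusion = Diffusion(DummyBackbone(), device, n_timesteps=1000, image_size=32)
    diffusion.space(10)                              # sample with 10 spaced timesteps instead of 1000
    x_cond = torch.rand(1, 3, 32, 32)                # hypothetical conditioning identity frame
    audio_emb = torch.randn(1, 4, 512)               # hypothetical (batch, n_frames, audio_dim)
    video = diffusion.sample(x_cond, audio_emb)      # (batch, n_frames, in_channels, H, W)
    print(video.shape)                               # torch.Size([1, 4, 3, 32, 32])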