Spaces:
Runtime error
Runtime error
| # Latent Motion Diffusion Model | |
| import torch | |
| import torch.nn as nn | |
| from .lmdm_modules.model import MotionDecoder | |
| from .lmdm_modules.utils import extract, make_beta_schedule | |
class LMDM(nn.Module):
    """Latent Motion Diffusion Model: a ``MotionDecoder`` denoiser plus a DDIM sampler.

    The wrapped network predicts the clean sample (``x_start``) through
    ``MotionDecoder.guided_forward``; the matching noise estimate is derived
    from it with the cumulative-alpha buffers registered in ``init_diff``.

    Args:
        motion_feat_dim: Size of the last dimension of a motion sample.
        audio_feat_dim: Passed to ``MotionDecoder`` as ``cond_feature_dim``.
        seq_frames: Number of frames in a generated sequence.
        checkpoint: Accepted for interface compatibility but not used by the
            constructor; call ``load_model`` to load weights.
        device: Device on which ``setup`` and ``ddim_sample`` create tensors.
            NOTE(review): the module itself is not moved to this device here;
            presumably the caller does ``.to(device)`` -- confirm.
        clip_denoised: If true, predicted ``x_start`` is clamped to [-1, 1].
        multi_cond_frame: Forwarded unchanged to ``MotionDecoder``.
    """

    def __init__(
        self,
        motion_feat_dim=265,
        audio_feat_dim=1024 + 35,
        seq_frames=80,
        checkpoint='',
        device='cuda',
        clip_denoised=False,  # clip denoised (-1,1)
        multi_cond_frame=False,
    ):
        super().__init__()

        self.motion_feat_dim = motion_feat_dim
        self.audio_feat_dim = audio_feat_dim
        self.seq_frames = seq_frames
        self.device = device

        self.n_timestep = 1000
        self.clip_denoised = clip_denoised
        # Weight handed to ``guided_forward``
        # (presumably the classifier-free guidance scale -- TODO confirm).
        self.guidance_weight = 2

        self.model = MotionDecoder(
            nfeats=motion_feat_dim,
            seq_len=seq_frames,
            latent_dim=512,
            ff_size=1024,
            num_layers=8,
            num_heads=8,
            dropout=0.1,
            cond_feature_dim=audio_feat_dim,
            multi_cond_frame=multi_cond_frame,
        )

        self.init_diff()

        # ``None`` means "no sampling schedule has been built yet" (see ``setup``).
        self.sampling_timesteps = None

    def init_diff(self):
        """Register the diffusion-schedule buffers derived from a cosine beta schedule."""
        n_timestep = self.n_timestep
        betas = torch.Tensor(
            make_beta_schedule(schedule="cosine", n_timestep=n_timestep)
        )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        self.register_buffer("alphas_cumprod", alphas_cumprod)
        # sqrt(1 / a_bar - 1)
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )
        # sqrt(1 / (1 - a_bar))
        self.register_buffer(
            "sqrt_recip1m_alphas_cumprod",
            torch.sqrt(1.0 / (1.0 - alphas_cumprod)),
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Return the noise estimate implied by ``x_t`` and a predicted ``x0``.

        Computes ``x_t * sqrt(1 / (1 - a_bar_t)) - x0 / sqrt(1 / a_bar_t - 1)``.
        ``extract`` is assumed to pick the coefficient for timestep ``t`` and
        shape it to broadcast against ``x_t`` -- TODO confirm against its
        definition.
        """
        a = extract(self.sqrt_recip1m_alphas_cumprod, t, x_t.shape)
        b = extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        return (a * x_t - x0 / b)

    def maybe_clip(self, x):
        """Clamp ``x`` to [-1, 1] when ``clip_denoised`` is set; otherwise return it as is."""
        if self.clip_denoised:
            return torch.clamp(x, min=-1., max=1.)
        return x

    def model_predictions(self, x, cond_frame, cond, t):
        """Run the guided denoiser once.

        Returns:
            ``(pred_noise, x_start)``: the (optionally clipped) prediction of
            the clean sample and the noise estimate derived from it.
        """
        weight = self.guidance_weight
        x_start = self.model.guided_forward(x, cond_frame, cond, t, weight)
        x_start = self.maybe_clip(x_start)
        pred_noise = self.predict_noise_from_start(x, t, x_start)
        return pred_noise, x_start

    def forward(self, x, cond_frame, cond, time_cond):
        """Single denoising evaluation; returns ``(pred_noise, x_start)``."""
        pred_noise, x_start = self.model_predictions(x, cond_frame, cond, time_cond)
        return pred_noise, x_start

    def load_model(self, ckpt_path):
        """Load denoiser weights from ``ckpt_path`` and switch to eval mode.

        Only ``self.model`` is restored, from the checkpoint entry
        ``"model_state_dict"``. Returns ``self`` so calls can be chained.
        """
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.eval()
        return self

    def setup(self, sampling_timesteps=50):
        """Precompute the DDIM schedule for ``sampling_timesteps`` steps.

        The schedule (timestep pairs, per-step coefficients and the per-step
        noise tensors) is cached on the instance. Calling again with the same
        step count is a no-op, so the same noise tensors are reused by every
        subsequent ``ddim_sample`` call until the step count changes.

        Raises:
            ValueError: If ``sampling_timesteps`` is smaller than 1.
        """
        if sampling_timesteps < 1:
            raise ValueError(
                f"sampling_timesteps must be >= 1, got {sampling_timesteps}"
            )
        if self.sampling_timesteps == sampling_timesteps:
            return

        total_timesteps = self.n_timestep
        device = self.device
        eta = 1  # scales sigma; 1 keeps the full stochastic term
        shape = (1, self.seq_frames, self.motion_feat_dim)

        # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
        time_pairs = list(zip(times[:-1], times[1:]))

        # Build everything in locals first and publish at the end, so that a
        # failure part-way through (e.g. an unavailable device) does not leave
        # a half-initialised schedule that a later call would treat as cached.
        time_cond_list = []
        alpha_next_sqrt_list = []
        sigma_list = []
        c_list = []
        noise_list = []

        for time, time_next in time_pairs:
            time_cond_list.append(
                torch.full((1,), time, device=device, dtype=torch.long)
            )
            if time_next < 0:
                # Final step: the sampler returns x_start directly, so no
                # update coefficients are needed.
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            noise = torch.randn(shape, device=device)

            alpha_next_sqrt_list.append(alpha_next.sqrt())
            sigma_list.append(sigma)
            c_list.append(c)
            noise_list.append(noise)

        self.time_pairs = time_pairs
        self.time_cond_list = time_cond_list
        self.alpha_next_sqrt_list = alpha_next_sqrt_list
        self.sigma_list = sigma_list
        self.c_list = c_list
        self.noise_list = noise_list
        self.sampling_timesteps = sampling_timesteps

    @torch.no_grad()
    def ddim_sample(self, kp_cond, aud_cond, sampling_timesteps):
        """Generate one motion sequence with the cached DDIM schedule.

        Runs without gradient tracking: this is an inference-only loop and
        keeping the autograd graph across every step only costs memory.

        Args:
            kp_cond: Passed to the denoiser as ``cond_frame``.
            aud_cond: Passed to the denoiser as ``cond``.
            sampling_timesteps: Number of DDIM steps; see ``setup``.

        Returns:
            Tensor of shape ``(1, seq_frames, motion_feat_dim)`` (pred_kp_seq).
        """
        self.setup(sampling_timesteps)

        cond_frame = kp_cond
        cond = aud_cond

        shape = (1, self.seq_frames, self.motion_feat_dim)
        x = torch.randn(shape, device=self.device)

        # ``setup`` produces a strictly terminal ``time_next == -1``: only the
        # last pair has no update coefficients, so the coefficient lists line
        # up with the step index for every earlier pair.
        steps = zip(self.time_pairs, self.time_cond_list)
        for step, ((_, time_next), time_cond) in enumerate(steps):
            pred_noise, x_start = self.model_predictions(x, cond_frame, cond, time_cond)
            if time_next < 0:
                x = x_start
                continue

            alpha_next_sqrt = self.alpha_next_sqrt_list[step]
            c = self.c_list[step]
            sigma = self.sigma_list[step]
            noise = self.noise_list[step]
            x = x_start * alpha_next_sqrt + c * pred_noise + sigma * noise

        return x  # pred_kp_seq