import copy import os import pickle from pathlib import Path from functools import partial import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import reduce from p_tqdm import p_map from pytorch3d.transforms import (axis_angle_to_quaternion, quaternion_to_axis_angle) from tqdm import tqdm from dataset.quaternion import ax_from_6v, quat_slerp from vis import skeleton_render from vis import SMPLX_Skeleton from dataset.preprocess import My_Normalizer as Normalizer from .utils import extract, make_beta_schedule def identity(t, *args, **kwargs): return t class EMA: def __init__(self, beta): super().__init__() self.beta = beta def update_model_average(self, ma_model, current_model): for current_params, ma_params in zip( current_model.parameters(), ma_model.parameters() ): old_weight, up_weight = ma_params.data, current_params.data ma_params.data = self.update_average(old_weight, up_weight) def update_average(self, old, new): if old is None: return new return old * self.beta + (1 - self.beta) * new class GaussianDiffusion(nn.Module): def __init__( self, model, opt, horizon, repr_dim, smplx_model, n_timestep=1000, schedule="linear", loss_type="l1", clip_denoised=True, predict_epsilon=True, guidance_weight=3, use_p2=False, cond_drop_prob=0.2, do_normalize=False, ): super().__init__() self.horizon = horizon self.transition_dim = repr_dim self.model = model self.ema = EMA(0.9999) self.master_model = copy.deepcopy(self.model) self.normalizer = None self.do_normalize = do_normalize self.opt = opt self.cond_drop_prob = cond_drop_prob # make a SMPL instance for FK module self.smplx_fk = smplx_model betas = torch.Tensor( make_beta_schedule(schedule=schedule, n_timestep=n_timestep) ) alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]]) self.n_timestep = int(n_timestep) self.clip_denoised = clip_denoised self.predict_epsilon = predict_epsilon self.register_buffer("betas", betas) self.register_buffer("alphas_cumprod", alphas_cumprod) self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev) self.guidance_weight = guidance_weight # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) self.register_buffer( "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod) ) self.register_buffer( "log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod) ) self.register_buffer( "sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod) ) self.register_buffer( "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1) ) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = ( betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) ) self.register_buffer("posterior_variance", posterior_variance) ## log calculation clipped because the posterior variance ## is 0 at the beginning of the diffusion chain self.register_buffer( "posterior_log_variance_clipped", torch.log(torch.clamp(posterior_variance, min=1e-20)), ) self.register_buffer( "posterior_mean_coef1", betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod), ) self.register_buffer( "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod), ) # p2 weighting self.p2_loss_weight_k = 1 self.p2_loss_weight_gamma = 0.5 if use_p2 else 0 self.register_buffer( "p2_loss_weight", (self.p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -self.p2_loss_weight_gamma, ) ## get loss coefficients and initialize objective self.loss_fn = F.mse_loss if loss_type == "l2" else F.l1_loss # ------------------------------------------ sampling ------------------------------------------# def predict_start_from_noise(self, x_t, t, noise): """ if self.predict_epsilon, model output is (scaled) noise; otherwise, model predicts x0 directly """ if self.predict_epsilon: return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) else: return noise def predict_noise_from_start(self, x_t, t, x0): return ( (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) ) def model_predictions(self, x, cond, t, weight=None, clip_x_start = False): weight = weight if weight is not None else self.guidance_weight model_output = self.model.guided_forward(x, cond, t, weight) maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity x_start = model_output x_start = maybe_clip(x_start) pred_noise = self.predict_noise_from_start(x, t, x_start) return pred_noise, x_start def q_posterior(self, x_start, x_t, t): posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = extract(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract( self.posterior_log_variance_clipped, t, x_t.shape ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance(self, x, cond, t): # guidance clipping if t[0] > 1.0 * self.n_timestep: weight = min(self.guidance_weight, 0) elif t[0] < 0.1 * self.n_timestep: weight = min(self.guidance_weight, 1) else: weight = self.guidance_weight x_recon = self.predict_start_from_noise( x, t=t, noise=self.model.guided_forward(x, cond, t, weight) ) if self.clip_denoised: x_recon.clamp_(-1.0, 1.0) else: assert RuntimeError() model_mean, posterior_variance, posterior_log_variance = self.q_posterior( x_start=x_recon, x_t=x, t=t ) return model_mean, posterior_variance, posterior_log_variance, x_recon @torch.no_grad() def p_sample(self, x, cond, t): b, *_, device = *x.shape, x.device model_mean, _, model_log_variance, x_start = self.p_mean_variance( x=x, cond=cond, t=t ) noise = torch.randn_like(model_mean) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape( b, *((1,) * (len(noise.shape) - 1)) ) x_out = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise return x_out, x_start @torch.no_grad() def p_sample_loop( self, shape, cond, noise=None, constraint=None, return_diffusion=False, start_point=None, ): device = self.betas.device # default to diffusion over whole timescale start_point = self.n_timestep if start_point is None else start_point batch_size = shape[0] x = torch.randn(shape, device=device) if noise is None else noise.to(device) cond = cond.to(device) if return_diffusion: diffusion = [x] for i in tqdm(reversed(range(0, start_point))): # fill with i timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long) x, _ = self.p_sample(x, cond, timesteps) if return_diffusion: diffusion.append(x) if return_diffusion: return x, diffusion else: return x @torch.no_grad() def ddim_sample(self, shape, cond, **kwargs): batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.n_timestep, 50, 1 times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps times = list(reversed(times.int().tolist())) time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] x = torch.randn(shape, device = device) cond = cond.to(device) x_start = None for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): time_cond = torch.full((batch,), time, device=device, dtype=torch.long) pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, clip_x_start = self.clip_denoised) if time_next < 0: x = x_start continue alpha = self.alphas_cumprod[time] alpha_next = self.alphas_cumprod[time_next] sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c = (1 - alpha_next - sigma ** 2).sqrt() noise = torch.randn_like(x) x = x_start * alpha_next.sqrt() + \ c * pred_noise + \ sigma * noise return x @torch.no_grad() def long_ddim_sample(self, shape, cond, **kwargs): batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.n_timestep, 50, 1 if batch == 1: return self.ddim_sample(shape, cond) times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps times = list(reversed(times.int().tolist())) weights = np.clip(np.linspace(0, self.guidance_weight * 2, sampling_timesteps), None, self.guidance_weight) time_pairs = list(zip(times[:-1], times[1:], weights)) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] x = torch.randn(shape, device = device) cond = cond.to(device) assert batch > 1 assert x.shape[1] % 2 == 0 half = x.shape[1] // 2 x_start = None for time, time_next, weight in tqdm(time_pairs, desc = 'sampling loop time step'): time_cond = torch.full((batch,), time, device=device, dtype=torch.long) pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, weight=weight, clip_x_start = self.clip_denoised) if time_next < 0: x = x_start continue alpha = self.alphas_cumprod[time] alpha_next = self.alphas_cumprod[time_next] sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c = (1 - alpha_next - sigma ** 2).sqrt() noise = torch.randn_like(x) x = x_start * alpha_next.sqrt() + \ c * pred_noise + \ sigma * noise if time > 0: # the first half of each sequence is the second half of the previous one x[1:, :half] = x[:-1, half:] return x @torch.no_grad() def inpaint_loop( self, shape, cond, noise=None, constraint=None, return_diffusion=False, start_point=None, ): device = self.betas.device batch_size = shape[0] x = torch.randn(shape, device=device) if noise is None else noise.to(device) cond = cond.to(device) if return_diffusion: diffusion = [x] mask = constraint["mask"].to(device) # batch x horizon x channels value = constraint["value"].to(device) # batch x horizon x channels start_point = self.n_timestep if start_point is None else start_point for i in tqdm(reversed(range(0, start_point))): # fill with i timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long) # sample x from step i to step i-1 x, _ = self.p_sample(x, cond, timesteps) # enforce constraint between each denoising step value_ = self.q_sample(value, timesteps - 1) if (i > 0) else x x = value_ * mask + (1.0 - mask) * x if return_diffusion: diffusion.append(x) if return_diffusion: return x, diffusion else: return x @torch.no_grad() def long_inpaint_loop( self, shape, cond, noise=None, constraint=None, return_diffusion=False, start_point=None, ): device = self.betas.device batch_size = shape[0] x = torch.randn(shape, device=device) if noise is None else noise.to(device) cond = cond.to(device) if return_diffusion: diffusion = [x] assert x.shape[1] % 2 == 0 if batch_size == 1: # there's no continuation to do, just do normal return self.p_sample_loop( shape, cond, noise=noise, constraint=constraint, return_diffusion=return_diffusion, start_point=start_point, ) assert batch_size > 1 half = x.shape[1] // 2 start_point = self.n_timestep if start_point is None else start_point for i in tqdm(reversed(range(0, start_point))): # fill with i timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long) # sample x from step i to step i-1 x, _ = self.p_sample(x, cond, timesteps) # enforce constraint between each denoising step if i > 0: # the first half of each sequence is the second half of the previous one x[1:, :half] = x[:-1, half:] if return_diffusion: diffusion.append(x) if return_diffusion: return x, diffusion else: return x @torch.no_grad() def conditional_sample( self, shape, cond, constraint=None, *args, horizon=None, **kwargs ): """ conditions : [ (time, state), ... ] """ device = self.betas.device horizon = horizon or self.horizon return self.p_sample_loop(shape, cond, *args, **kwargs) # ------------------------------------------ training ------------------------------------------# def q_sample(self, x_start, t, noise=None): if noise is None: noise = torch.randn_like(x_start) sample = ( extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) return sample def p_losses(self, x_start, cond, t): noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # 将x0加噪到xt # reconstruct x_recon = self.model(x_noisy, cond, t, cond_drop_prob=self.cond_drop_prob) assert noise.shape == x_recon.shape model_out = x_recon if self.predict_epsilon: target = noise else: target = x_start # full reconstruction loss loss = self.loss_fn(model_out, target, reduction="none") # mse loss loss = reduce(loss, "b ... -> b (...)", "mean") loss = loss * extract(self.p2_loss_weight, t, loss.shape) # split off contact from the rest _, model_out_ = torch.split( model_out, (4, model_out.shape[2] - 4), dim=2 # 前4维是foot contact ) _, target_ = torch.split(target, (4, target.shape[2] - 4), dim=2) # b, length, jxc # velocity loss target_v = target_[:, 1:] - target_[:, :-1] model_out_v = model_out_[:, 1:] - model_out_[:, :-1] v_loss = self.loss_fn(model_out_v, target_v, reduction="none") v_loss = reduce(v_loss, "b ... -> b (...)", "mean") v_loss = v_loss * extract(self.p2_loss_weight, t, v_loss.shape) # FK loss b, s, c = model_out.shape model_contact, model_out = torch.split(model_out, (4, model_out.shape[2] - 4), dim=2) target_contact, target = torch.split(target, (4, target.shape[2] - 4), dim=2) # b, length, jxc model_x = model_out[:, :, :3] # root position model_q = ax_from_6v(model_out[:, :, 3:].reshape(b, s, -1, 6)) target_x = target[:, :, :3] target_q = ax_from_6v(target[:, :, 3:].reshape(b, s, -1, 6)) b, s, nums, c_ = model_q.shape if self.opt.nfeats == 139 or self.opt.nfeats==135: model_xp = self.smplx_fk.forward(model_q, model_x) target_xp = self.smplx_fk.forward(target_q, target_x) else: model_q = model_q.view(b*s, -1) target_q = target_q.view(b*s, -1) model_x = model_x.view(-1, 3) target_x = target_x.view(-1, 3) model_xp = self.smplx_fk.forward(model_q, model_x) target_xp = self.smplx_fk.forward(target_q, target_x) model_xp = model_xp.view(b, s, -1, 3) target_xp = target_xp.view(b, s, -1, 3) fk_loss = self.loss_fn(model_xp, target_xp, reduction="none") fk_loss = reduce(fk_loss, "b ... -> b (...)", "mean") fk_loss = fk_loss * extract(self.p2_loss_weight, t, fk_loss.shape) # foot skate loss foot_idx = [7, 8, 10, 11] # find static indices consistent with model's own predictions static_idx = model_contact > 0.95 # N x S x 4 model_feet = model_xp[:, :, foot_idx] # foot positions (N, S, 4, 3) model_foot_v = torch.zeros_like(model_feet) model_foot_v[:, :-1] = ( model_feet[:, 1:, :, :] - model_feet[:, :-1, :, :] ) # (N, S-1, 4, 3) model_foot_v[~static_idx] = 0 foot_loss = self.loss_fn( model_foot_v, torch.zeros_like(model_foot_v), reduction="none" ) foot_loss = reduce(foot_loss, "b ... -> b (...)", "mean") losses = ( 0.636 * loss.mean(), 2.964 * v_loss.mean(), 0.646 * fk_loss.mean(), 10.942 * foot_loss.mean(), ) return sum(losses), losses def loss(self, x, cond, t_override=None): batch_size = len(x) if t_override is None: t = torch.randint(0, self.n_timestep, (batch_size,), device=x.device).long() else: t = torch.full((batch_size,), t_override, device=x.device).long() return self.p_losses(x, cond, t) def forward(self, x, cond, t_override=None): return self.loss(x, cond, t_override) def partial_denoise(self, x, cond, t): x_noisy = self.noise_to_t(x, t) return self.p_sample_loop(x.shape, cond, noise=x_noisy, start_point=t) def noise_to_t(self, x, timestep): batch_size = len(x) t = torch.full((batch_size,), timestep, device=x.device).long() return self.q_sample(x, t) if timestep > 0 else x def smplxmodel_fk(self, local_q, root_pos): # input b, s, nums, c = local_q.shape local_q = local_q.view(b*s, -1) full_pose = self.smplx_model( betas = torch.zeros([b*s, 10], device=local_q.device, dtype=torch.float32), transl = root_pos.view(b*s, -1), # global translation global_orient = local_q[:, :3], body_pose = local_q[:, 3:66], # 21 jaw_pose = torch.zeros([b*s, 3], device=local_q.device, dtype=torch.float32), # 1 leye_pose = torch.zeros([b*s, 3], device=local_q.device, dtype=torch.float32), # 1 reye_pose= torch.zeros([b*s, 3], device=local_q.device, dtype=torch.float32), # 1 left_hand_pose = local_q[:, 66:111], # 15 right_hand_pose = local_q[:, 111:], # 15 expression = torch.zeros([b*s, 10], device=local_q.device, dtype=torch.float32), return_verts = False ) full_pose = full_pose.joints.view(b, s, -1, 3) # b, s, 55, 3 return full_pose def render_sample( self, shape, cond, normalizer, epoch, render_out, fk_out=None, name=None, sound=True, mode="normal", noise=None, constraint=None, sound_folder="ood_sliced", start_point=None, render=True, # do_normalize=True, ): if isinstance(shape, tuple): if mode == "inpaint": func_class = self.inpaint_loop elif mode == "normal": func_class = self.ddim_sample elif mode == "long": func_class = self.long_ddim_sample else: assert False, "Unrecognized inference mode" samples = ( func_class( shape, cond, noise=noise, constraint=constraint, start_point=start_point, ) .detach() .cpu() ) else: samples = shape if self.do_normalize: with torch.no_grad(): samples = normalizer.unnormalize(samples) if samples.shape[2] == 319 or samples.shape[2] == 151 or samples.shape[2] == 139: # debug if samples.shape[2] == 151: sample_contact, samples = torch.split( samples, (4, samples.shape[2] - 4), dim=2 ) else: sample_contact = None # do the FK all at once b, s, c = samples.shape pos = samples[:, :, :3].to(cond.device) # np.zeros((sample.shape[0], 3)) q = samples[:, :, 3:].reshape(b, s, -1, 6) # debug 24 # go 6d to ax q = ax_from_6v(q).to(cond.device) if self.opt.nfeats == 139 or self.opt.nfeats==135: reshape_size = 66 else: reshape_size = 156 if mode == "long": b, s, c1, c2 = q.shape assert s % 2 == 0 half = s // 2 if b > 1: # if long mode, stitch position using linear interp fade_out = torch.ones((1, s, 1)).to(pos.device) fade_in = torch.ones((1, s, 1)).to(pos.device) fade_out[:, half:, :] = torch.linspace(1, 0, half)[None, :, None].to( pos.device ) fade_in[:, :half, :] = torch.linspace(0, 1, half)[None, :, None].to( pos.device ) pos[:-1] *= fade_out pos[1:] *= fade_in full_pos = torch.zeros((s + half * (b - 1), 3)).to(pos.device) idx = 0 for pos_slice in pos: full_pos[idx : idx + s] += pos_slice idx += half # stitch joint angles with slerp slerp_weight = torch.linspace(0, 1, half)[None, :, None].to(pos.device) left, right = q[:-1, half:], q[1:, :half] # convert to quat left, right = ( axis_angle_to_quaternion(left), axis_angle_to_quaternion(right), ) merged = quat_slerp(left, right, slerp_weight) # (b-1) x half x ... # convert back merged = quaternion_to_axis_angle(merged) full_q = torch.zeros((s + half * (b - 1), c1, c2)).to(pos.device) full_q[:half] += q[0, :half] idx = half for q_slice in merged: full_q[idx : idx + half] += q_slice idx += half full_q[idx : idx + half] += q[-1, half:] # unsqueeze for fk full_pos = full_pos.unsqueeze(0) full_q = full_q.unsqueeze(0) else: full_pos = pos full_q = q if fk_out is not None: outname = f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.pkl' # f'{epoch}_{"_".join(name)}.pkl' # Path(fk_out).mkdir(parents=True, exist_ok=True) pickle.dump( { "smpl_poses": full_q.squeeze(0).reshape((-1, reshape_size)).cpu().numpy(), # local rotations "smpl_trans": full_pos.squeeze(0).cpu().numpy(), # root translation # "full_pose": full_pose[0], # 3d positions }, open(os.path.join(fk_out, outname), "wb"), ) return sample_contact = ( sample_contact.detach().cpu().numpy() if sample_contact is not None else None ) def inner(xx): num, pose = xx filename = name[num] if name is not None else None contact = sample_contact[num] if sample_contact is not None else None skeleton_render( pose, epoch=f"e{epoch}_b{num}", out=render_out, name=filename, sound=sound, contact=contact, ) # p_map(inner, enumerate(poses)) # poses: 2, 150, 52, 3 # print("4") if fk_out is not None and mode != "long": Path(fk_out).mkdir(parents=True, exist_ok=True) # for num, (qq, pos_, filename, pose) in enumerate(zip(q, pos, name, poses)): for num, (qq, pos_, filename) in enumerate(zip(q, pos, name)): filename = os.path.basename(filename).split(".")[0] outname = f"{epoch}_{num}_{filename}.pkl" pickle.dump( { "smpl_poses": qq.reshape((-1, reshape_size)).cpu().numpy(), "smpl_trans": pos_.cpu().numpy(), # "full_pose": pose, }, open(f"{fk_out}/{outname}", "wb"), )