| | import math |
| | import random |
| | from typing import List |
| | from collections import namedtuple |
| | from common.arguments import parse_args |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| |
|
| | from common.mixste_ddhpose import * |
| |
|
__all__ = ["DDHPose"]

# Named bundle returned by DDHPose.model_predictions_dir_bone.
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise_dir', 'pred_noise_bone', 'pred_x_start'])

# NOTE(review): parse_args() runs at import time, so merely importing this
# module parses the process command line.
args = parse_args()

# `boneindex_h36m` is a flat comma-separated list "a0,b0,a1,b1,..." that is
# decoded into [a, b] joint-index pairs; bone vectors elsewhere in this module
# are computed as joint[b] - joint[a]. An odd number of entries raises IndexError.
boneindextemp = args.boneindex_h36m.split(',')
boneindex = [
    [int(boneindextemp[pos]), int(boneindextemp[pos + 1])]
    for pos in range(0, len(boneindextemp), 2)
]
| |
|
def exists(x):
    """Return True when *x* is anything other than ``None`` (falsy values count as existing)."""
    missing = x is None
    return not missing
| |
|
| |
|
def default(val, d):
    """Return *val* unless it is ``None``; otherwise fall back to *d*.

    When *d* is callable it is invoked (with no arguments) to produce the
    fallback lazily; otherwise *d* itself is returned.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
| |
|
| |
|
def extract(a, t, x_shape):
    """Gather ``a`` at indices ``t`` and shape the result to broadcast against ``x_shape``.

    Args:
        a: tensor of per-timestep values, gathered along its last dim.
        t: index tensor; ``t.shape[0]`` is taken as the batch size.
        x_shape: shape of the tensor the result will be combined with.

    Returns:
        Tensor of shape ``(t.shape[0], 1, ..., 1)`` with ``len(x_shape)`` dims.
    """
    trailing_ones = (1,) * (len(x_shape) - 1)
    picked = a.gather(-1, t)
    return picked.reshape(t.shape[0], *trailing_ones)
| |
|
| |
|
def cosine_beta_schedule(timesteps, s=0.008):
    """Return the cosine noise schedule (float64 betas of length ``timesteps``).

    As proposed in https://openreview.net/forum?id=-NEXDKk8gZ : the cumulative
    product of alphas follows cos^2 of a shifted, rescaled time grid, and each
    beta is one minus the ratio of consecutive cumulative alphas, clipped to
    ``[0, 0.999]``.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    f = torch.cos(((grid / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alpha_bar = f / f[0]
    ratio = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(1 - ratio, 0, 0.999)
| |
|
def getbonedirect(seq, boneindex):
    """Return unit bone-direction vectors for a batch of pose sequences.

    Args:
        seq: 4-D tensor indexed as (batch, frames, joints, coords); it is
            flattened with ``view``, so it must be viewable (e.g. contiguous).
        boneindex: iterable of ``[a, b]`` joint-index pairs.

    Returns:
        Tensor viewed as (batch, frames, -1, 3) holding, for each pair,
        ``seq[..., b, :] - seq[..., a, :]`` divided by its Euclidean norm.
        A zero-length bone yields NaN (0 / 0).
    """
    batch, frames = seq.size(0), seq.size(1)
    flat = seq.view(-1, seq.size(2), seq.size(3))
    vectors = torch.stack([flat[:, pair[1]] - flat[:, pair[0]] for pair in boneindex], 1)
    norms = vectors.pow(2).sum(2).pow(0.5).unsqueeze(2)
    return (vectors / norms).view(batch, frames, -1, 3)
| |
|
def getbonedirect_test(seq, boneindex):
    """Return unit bone-direction vectors for a 5-D pose tensor.

    Joints are indexed on dim 3 and coordinates on the last dim. For each
    ``[a, b]`` pair in *boneindex* the vector ``joint[b] - joint[a]`` is
    divided by its Euclidean norm. Bones are stacked on dim 3. A zero-length
    bone yields NaN (0 / 0).
    """
    diffs = [seq[:, :, :, pair[1]] - seq[:, :, :, pair[0]] for pair in boneindex]
    vectors = torch.stack(diffs, 3)
    norms = vectors.pow(2).sum(-1).pow(0.5).unsqueeze(-1)
    return vectors / norms
| |
|
def getbonelength(seq, boneindex):
    """Return Euclidean bone lengths for a batch of pose sequences.

    Args:
        seq: 4-D tensor indexed as (batch, frames, joints, coords); flattened
            with ``view``, so it must be viewable (e.g. contiguous).
        boneindex: iterable of ``[a, b]`` joint-index pairs.

    Returns:
        Tensor of shape (batch, frames, len(boneindex), 1) holding
        ``||seq[..., b, :] - seq[..., a, :]||``.
    """
    batch, frames = seq.size(0), seq.size(1)
    flat = seq.view(-1, seq.size(2), seq.size(3))
    vectors = torch.stack([flat[:, pair[1]] - flat[:, pair[0]] for pair in boneindex], 1)
    lengths = vectors.pow(2).sum(2).pow(0.5)
    return lengths.view(batch, frames, lengths.size(1), 1)
| |
|
def getbonelength_test(seq, boneindex):
    """Return Euclidean bone lengths for a 5-D pose tensor.

    Joints are indexed on dim 3 and coordinates on the last dim. For each
    ``[a, b]`` pair the length ``||joint[b] - joint[a]||`` is computed. Bones
    are stacked on dim 3, followed by a trailing size-1 dim.
    """
    vectors = torch.stack([seq[:, :, :, pair[1]] - seq[:, :, :, pair[0]] for pair in boneindex], 3)
    return vectors.pow(2).sum(-1).pow(0.5).unsqueeze(-1)
| |
|
class DDHPose(nn.Module):
    """Implement DDHPose: diffusion over bone directions and bone lengths.

    Training (``is_train=True``): ``forward`` converts the 3D targets into
    per-joint bone directions and lengths, diffuses them to a random timestep
    and returns what ``dir_bone_estimator`` predicts from them and the 2D input.

    Inference (``is_train=False``): ``forward`` runs DDIM sampling from Gaussian
    noise and returns the pose estimate of every sampling step, stacked on dim 1.

    Noise and timestep tensors are created on the device of the registered
    diffusion buffers, i.e. wherever the module itself lives.
    """

    def __init__(self, args, joints_left, joints_right, is_train=True, num_proposals=1, sampling_timesteps=1):
        """
        Args:
            args: parsed options; reads ``number_of_frames``,
                ``test_time_augmentation``, ``timestep``, ``scale``, ``cs`` and ``dep``.
            joints_left, joints_right: joint indices swapped when mirroring for
                flip test-time augmentation. They are concatenated with ``+`` to
                build an index, so presumably Python lists (TODO confirm).
            is_train: selects the training or sampling path of ``forward`` and
                enables drop-path (0.1) in the estimator.
            num_proposals: number of independent noise samples per input at inference.
            sampling_timesteps: number of DDIM steps; ``None`` means every training timestep.
        """
        super().__init__()

        self.frames = args.number_of_frames
        self.num_proposals = num_proposals
        self.flip = args.test_time_augmentation
        self.joints_left = joints_left
        self.joints_right = joints_right
        self.is_train = is_train

        # Cosine noise schedule and the quantities derived from it.
        timesteps = args.timestep
        self.objective = 'pred_x0'
        betas = cosine_beta_schedule(timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # DDIM sampling is used whenever fewer sampling steps than training steps are requested.
        self.sampling_timesteps = default(sampling_timesteps, timesteps)
        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = 1.
        self.self_condition = False
        self.scale = args.scale
        self.box_renewal = True
        self.use_ensemble = True

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # Coefficients for q(x_t | x_0), used by q_sample / predict_noise_from_start.
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # Posterior q(x_{t-1} | x_t, x_0) variance and mean coefficients.
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        # alphas_cumprod_prev[0] == 1 makes the variance 0 at t=0, so clamp before the log.
        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        drop_path_rate = 0.1 if is_train else 0
        self.dir_bone_estimator = MixSTE2(num_frame=self.frames, num_joints=17, in_chans=2, embed_dim_ratio=args.cs,
                                          depth=args.dep, num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
                                          drop_path_rate=drop_path_rate, is_train=is_train)

    def predict_noise_from_start(self, x_t, t, x0):
        """Solve ``x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * eps`` for ``eps``."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def model_predictions_dir_bone(self, x_dir, x_bone, inputs_2d, input_2d_flip, t):
        """Predict the clean pose at timestep ``t`` and the noise implied for directions/lengths.

        Args:
            x_dir: noisy bone directions in the ``scale``-multiplied space; 5-D
                with joints on dim 3 (sliced ``[:, :, :, 1:, :]`` below).
            x_bone: noisy bone lengths, same layout as ``x_dir``.
            inputs_2d: 2D keypoints passed to ``dir_bone_estimator``.
            input_2d_flip: presumably the mirrored 2D input for flip test-time
                augmentation, or ``None`` to skip the flipped pass. The flipped
                prediction is un-mirrored by negating coordinate 0 and swapping
                left/right joints before it is averaged in.
            t: timestep tensor, one entry per batch element.

        Returns:
            ModelPrediction(pred_noise_dir, pred_noise_bone, pred_x_start), where
            ``pred_x_start`` is the predicted pose multiplied by ``scale`` and
            clamped to ``[-1.1 * scale, 1.1 * scale]``.
        """
        x_t_dir = torch.clamp(x_dir, min=-1.1 * self.scale, max=1.1 * self.scale) / self.scale
        x_t_bone = torch.clamp(x_bone, min=-1.1 * self.scale, max=1.1 * self.scale) / self.scale

        pred_pos = self.dir_bone_estimator(inputs_2d, x_t_dir, x_t_bone, t)

        # Flip test-time augmentation; skipped when no mirrored input is supplied
        # (passing None into the estimator would fail).
        if input_2d_flip is not None:
            swapped = self.joints_left + self.joints_right
            mirrored = self.joints_right + self.joints_left

            x_t_dir_flip = x_t_dir.clone()
            x_t_dir_flip[:, :, :, :, 0] *= -1
            x_t_dir_flip[:, :, :, swapped] = x_t_dir_flip[:, :, :, mirrored]
            x_t_bone_flip = x_t_bone.clone()
            x_t_bone_flip[:, :, :, swapped] = x_t_bone_flip[:, :, :, mirrored]

            pred_pose_flip = self.dir_bone_estimator(input_2d_flip, x_t_dir_flip, x_t_bone_flip, t)
            # Undo the mirroring on the output before averaging with the direct pass.
            pred_pose_flip[:, :, :, :, 0] *= -1
            pred_pose_flip[:, :, :, swapped] = pred_pose_flip[:, :, :, mirrored]
            pred_pos = (pred_pos + pred_pose_flip) / 2

        # Bones fill joints 1..J-1, hence the [:, :, :, 1:, :] slices of the noisy inputs.
        x_start_dir = getbonedirect_test(pred_pos, boneindex)
        x_start_dir = x_start_dir * self.scale
        x_start_dir = torch.clamp(x_start_dir, min=-1.1 * self.scale, max=1.1 * self.scale)
        pred_noise_dir = self.predict_noise_from_start(x_dir[:, :, :, 1:, :], t, x_start_dir)

        x_start_bone = getbonelength_test(pred_pos, boneindex)
        x_start_bone = x_start_bone * self.scale
        x_start_bone = torch.clamp(x_start_bone, min=-1.1 * self.scale, max=1.1 * self.scale)
        pred_noise_bone = self.predict_noise_from_start(x_bone[:, :, :, 1:, :], t, x_start_bone)

        x_start_pos = pred_pos * self.scale
        x_start_pos = torch.clamp(x_start_pos, min=-1.1 * self.scale, max=1.1 * self.scale)

        return ModelPrediction(pred_noise_dir, pred_noise_bone, x_start_pos)

    def ddim_sample_bone_dir(self, inputs_2d, inputs_3d, clip_denoised=True, do_postprocess=True, input_2d_flip=None):
        """Run DDIM sampling from Gaussian noise and return the pose estimate of every step.

        Args:
            inputs_2d: 2D keypoints; dim 0 is the batch and dim -2 the joints.
            inputs_3d, clip_denoised, do_postprocess: unused; kept for interface compatibility.
            input_2d_flip: optional mirrored 2D input for flip test-time augmentation.

        Returns:
            ``pred_x_start`` of each of the ``self.sampling_timesteps`` steps,
            stacked along dim 1.
        """
        batch = inputs_2d.shape[0]
        jt_num = inputs_2d.shape[-2]
        device = self.betas.device

        dir_shape = (batch, self.num_proposals, self.frames, jt_num, 3)
        bone_shape = (batch, self.num_proposals, self.frames, jt_num, 1)
        total_timesteps = self.num_timesteps
        sampling_timesteps = self.sampling_timesteps
        eta = self.ddim_sampling_eta

        # Evenly spaced timesteps from T-1 down to -1; consecutive pairs are (t, t_next).
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        img_dir = torch.randn(dir_shape, device=device)
        img_bone = torch.randn(bone_shape, device=device)

        preds_all_pos = []
        for time, time_next in time_pairs:
            time_cond = torch.full((batch,), time, dtype=torch.long, device=device)

            preds_pos = self.model_predictions_dir_bone(img_dir, img_bone, inputs_2d, input_2d_flip, time_cond)
            pred_noise_dir = preds_pos.pred_noise_dir
            pred_noise_bone = preds_pos.pred_noise_bone
            x_start_pos = preds_pos.pred_x_start

            # NOTE(review): x_start_dir is recomputed from the already-scaled pose, so it is
            # a unit direction, whereas model_predictions_dir_bone multiplies directions by
            # self.scale. The two agree only when self.scale == 1 -- confirm this is intended.
            x_start_dir = getbonedirect_test(x_start_pos, boneindex)
            x_start_bone = getbonelength_test(x_start_pos, boneindex)

            preds_all_pos.append(x_start_pos)

            if time_next < 0:
                img_dir = x_start_dir
                img_bone = x_start_bone
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            # DDIM update with stochasticity controlled by eta.
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise_dir = torch.randn_like(x_start_dir)
            noise_bone = torch.randn_like(x_start_bone)

            img_dir_t = x_start_dir * alpha_next.sqrt() + c * pred_noise_dir + sigma * noise_dir
            img_bone_t = x_start_bone * alpha_next.sqrt() + c * pred_noise_bone + sigma * noise_bone

            # Only joints 1: carry bones; joint 0 keeps its initial noise.
            img_dir[:, :, :, 1:] = img_dir_t
            img_bone[:, :, :, 1:] = img_bone_t

        return torch.stack(preds_all_pos, dim=1)

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to timestep ``t``: ``sqrt(a_bar_t) * x_start + sqrt(1 - a_bar_t) * noise``.

        Fresh standard-normal noise shaped like ``x_start`` is drawn when ``noise`` is None.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def forward(self, input_2d, input_3d, input_2d_flip=None):
        """Sample poses at inference, or run one denoising step on diffused targets when training.

        Args:
            input_2d: 2D keypoints fed to the estimator.
            input_3d: 3D targets (used only when training; passed through otherwise).
            input_2d_flip: optional mirrored 2D input for flip test-time augmentation.
        """
        if not self.is_train:
            return self.ddim_sample_bone_dir(input_2d, input_3d, input_2d_flip=input_2d_flip)

        x_dir, _, x_bone_length, _, t = self.prepare_targets(input_3d)
        x_dir = x_dir.float()
        x_bone_length = x_bone_length.float()
        # prepare_targets stacks one (1,)-shaped timestep per sample -> (batch, 1); flatten to (batch,).
        t = t.squeeze(-1)

        return self.dir_bone_estimator(input_2d, x_dir, x_bone_length, t)

    def prepare_diffusion_concat(self, pose_3d):
        """Diffuse one pose tensor to a single random timestep (not used by the other methods here).

        Returns:
            ``(x, noise, t)``: the noised input, clamped to ``[-1.1, 1.1] * scale`` and divided
            by ``scale``; the noise of shape ``(self.frames, pose_3d.shape[1], pose_3d.shape[2])``;
            and the ``(1,)`` timestep tensor.
        """
        device = self.betas.device
        t = torch.randint(0, self.num_timesteps, (1,), device=device).long()
        noise = torch.randn(self.frames, pose_3d.shape[1], pose_3d.shape[2], device=device)

        x_start = pose_3d * self.scale
        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        x = torch.clamp(x, min=-1.1 * self.scale, max=1.1 * self.scale)
        x = x / self.scale

        return x, noise, t

    def prepare_diffusion_bone_dir(self, dir, bone):
        """Diffuse one sample's bone directions and lengths to a shared random timestep.

        Returns:
            ``(x_dir, noise_dir, x_bone, noise_bone, t)``; the noised tensors are clamped to
            ``[-1.1, 1.1] * scale`` and divided by ``scale``, and ``t`` has shape ``(1,)``.
        """
        device = self.betas.device
        t = torch.randint(0, self.num_timesteps, (1,), device=device).long()
        noise_dir = torch.randn(self.frames, dir.shape[1], dir.shape[2], device=device)
        noise_bone = torch.randn(self.frames, bone.shape[1], bone.shape[2], device=device)

        x_start_dir = dir * self.scale
        x_start_bone = bone * self.scale

        x_dir = self.q_sample(x_start=x_start_dir, t=t, noise=noise_dir)
        x_bone = self.q_sample(x_start=x_start_bone, t=t, noise=noise_bone)

        x_dir = torch.clamp(x_dir, min=-1.1 * self.scale, max=1.1 * self.scale) / self.scale
        x_bone = torch.clamp(x_bone, min=-1.1 * self.scale, max=1.1 * self.scale) / self.scale

        return x_dir, noise_dir, x_bone, noise_bone, t

    def prepare_targets(self, targets):
        """Build diffused bone-direction / bone-length training inputs from 3D poses.

        ``targets`` is indexed as (samples, frames, joints, coords). Joint 0 gets a zero
        direction and length, and the bones computed from ``boneindex`` fill joints 1:.
        The slice assignment requires ``len(boneindex) == joints - 1``; the bone order is
        assumed to follow the joint order (TODO confirm). Each sample is diffused with its
        own random timestep.

        Returns:
            Stacked ``(diffused_dir, noises_dir, diffused_bone_length,
            noises_bone_length, ts)``; ``ts`` has shape ``(samples, 1)``.
        """
        device = self.betas.device
        num_samples, num_frames, num_joints = targets.shape[:3]

        targets_dir = torch.zeros(num_samples, num_frames, num_joints, 3, device=device)
        targets_bone_length = torch.zeros(num_samples, num_frames, num_joints, 1, device=device)
        targets_dir[:, :, 1:] = getbonedirect(targets, boneindex)
        targets_bone_length[:, :, 1:] = getbonelength(targets, boneindex)

        diffused_dir, noises_dir = [], []
        diffused_bone_length, noises_bone_length = [], []
        ts = []
        for sample_dir, sample_bone_length in zip(targets_dir, targets_bone_length):
            d_dir, d_noise_dir, d_bone_length, d_noise_bone_length, d_t = \
                self.prepare_diffusion_bone_dir(sample_dir, sample_bone_length)

            diffused_dir.append(d_dir)
            noises_dir.append(d_noise_dir)
            diffused_bone_length.append(d_bone_length)
            noises_bone_length.append(d_noise_bone_length)
            ts.append(d_t)

        return torch.stack(diffused_dir), torch.stack(noises_dir), \
            torch.stack(diffused_bone_length), torch.stack(noises_bone_length), torch.stack(ts)
| |
|
| |
|
| |
|