# DDHpose/common/ddhpose.py
# Andyen512
# Add model checkpoints and configs
# 1e45055
import math
import random
from typing import List
from collections import namedtuple
from common.arguments import parse_args
import torch
import torch.nn.functional as F
from torch import nn
from common.mixste_ddhpose import *
# Only the model class is exported by ``from ... import *``.
__all__ = ["DDHPose"]

# Result bundle returned by DDHPose.model_predictions_dir_bone.
ModelPrediction = namedtuple(
    'ModelPrediction',
    ['pred_noise_dir', 'pred_noise_bone', 'pred_x_start'],
)

# NOTE: arguments are parsed when this module is imported, so importing it
# has the side effect of running parse_args().
args = parse_args()

# ``args.boneindex_h36m`` is a flat comma-separated list of joint indices;
# consecutive entries are grouped two by two into [first, second] pairs.
boneindextemp = args.boneindex_h36m.split(',')
boneindex = [
    [int(boneindextemp[i]), int(boneindextemp[i + 1])]
    for i in range(0, len(boneindextemp), 2)
]
def exists(x):
    """Return True when ``x`` is anything other than ``None``."""
    return not (x is None)
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    When the fallback ``d`` is callable it is invoked (with no arguments)
    and its result is returned; otherwise ``d`` itself is returned.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
def extract(a, t, x_shape):
    """Pick ``a[t]`` for every entry of ``t`` and make it broadcastable.

    The gathered values are reshaped to ``(t.shape[0], 1, ..., 1)`` with as
    many axes as ``x_shape`` has entries, so the result can be multiplied
    with a tensor of shape ``x_shape``.
    """
    picked = a.gather(-1, t)
    singleton_axes = [1] * (len(x_shape) - 1)
    return picked.reshape(t.shape[0], *singleton_axes)
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine noise schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Returns a float64 tensor with ``timesteps`` beta values, each clipped
    to the range [0, 0.999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    phase = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    cum_alpha = torch.cos(phase) ** 2
    # Normalise so that the cumulative product starts at exactly 1.
    cum_alpha = cum_alpha / cum_alpha[0]
    step_ratio = cum_alpha[1:] / cum_alpha[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)
def getbonedirect(seq, boneindex):
    """Return the unit direction vector of every bone.

    Args:
        seq: joint positions with the joint axis second to last and the
            coordinate axis last. Callers in this file pass
            ``(batch, frames, joints, 3)``; any number of leading axes is
            accepted and the tensor does not need to be contiguous.
        boneindex: sequence of ``[start, end]`` joint index pairs. Each bone
            vector is ``seq[..., end, :] - seq[..., start, :]``.

    Returns:
        Tensor of shape ``(*leading, len(boneindex), coords)`` holding each
        bone vector divided by its Euclidean length. A zero-length bone
        yields NaN because of the division by zero.
    """
    start_idx = torch.as_tensor([pair[0] for pair in boneindex], dtype=torch.long, device=seq.device)
    end_idx = torch.as_tensor([pair[1] for pair in boneindex], dtype=torch.long, device=seq.device)
    # Index the joint axis directly instead of flattening with .view(), which
    # raises on inputs whose leading axes are not stride-compatible.
    bones = seq[..., end_idx, :] - seq[..., start_idx, :]
    lengths = bones.pow(2).sum(dim=-1, keepdim=True).sqrt()
    return bones / lengths
def getbonedirect_test(seq, boneindex):
    """Return the unit direction vector of every bone (sampling-time variant).

    Args:
        seq: joint positions with the joint axis second to last and the
            coordinate axis last. Callers in this file pass a 5-D tensor
            (three leading axes, then joints, then coordinates); any number
            of leading axes is accepted.
        boneindex: sequence of ``[start, end]`` joint index pairs. Each bone
            vector is ``seq[..., end, :] - seq[..., start, :]``.

    Returns:
        Tensor of shape ``(*leading, len(boneindex), coords)`` holding each
        bone vector divided by its Euclidean length. A zero-length bone
        yields NaN because of the division by zero.
    """
    start_idx = torch.as_tensor([pair[0] for pair in boneindex], dtype=torch.long, device=seq.device)
    end_idx = torch.as_tensor([pair[1] for pair in boneindex], dtype=torch.long, device=seq.device)
    # One gather per endpoint replaces the per-bone Python loop + stack.
    bones = seq[..., end_idx, :] - seq[..., start_idx, :]
    lengths = bones.pow(2).sum(dim=-1, keepdim=True).sqrt()
    return bones / lengths
def getbonelength(seq, boneindex):
    """Return the Euclidean length of every bone.

    Args:
        seq: joint positions with the joint axis second to last and the
            coordinate axis last. Callers in this file pass
            ``(batch, frames, joints, 3)``; any number of leading axes is
            accepted and the tensor does not need to be contiguous.
        boneindex: sequence of ``[start, end]`` joint index pairs. Each bone
            vector is ``seq[..., end, :] - seq[..., start, :]``.

    Returns:
        Tensor of shape ``(*leading, len(boneindex), 1)`` with the length of
        each bone.
    """
    start_idx = torch.as_tensor([pair[0] for pair in boneindex], dtype=torch.long, device=seq.device)
    end_idx = torch.as_tensor([pair[1] for pair in boneindex], dtype=torch.long, device=seq.device)
    # Index the joint axis directly instead of flattening with .view(), which
    # raises on inputs whose leading axes are not stride-compatible.
    bones = seq[..., end_idx, :] - seq[..., start_idx, :]
    return bones.pow(2).sum(dim=-1, keepdim=True).sqrt()
def getbonelength_test(seq, boneindex):
    """Return the Euclidean length of every bone (sampling-time variant).

    Args:
        seq: joint positions with the joint axis second to last and the
            coordinate axis last. Callers in this file pass a 5-D tensor
            (three leading axes, then joints, then coordinates); any number
            of leading axes is accepted.
        boneindex: sequence of ``[start, end]`` joint index pairs. Each bone
            vector is ``seq[..., end, :] - seq[..., start, :]``.

    Returns:
        Tensor of shape ``(*leading, len(boneindex), 1)`` with the length of
        each bone.
    """
    start_idx = torch.as_tensor([pair[0] for pair in boneindex], dtype=torch.long, device=seq.device)
    end_idx = torch.as_tensor([pair[1] for pair in boneindex], dtype=torch.long, device=seq.device)
    # One gather per endpoint replaces the per-bone Python loop + stack.
    bones = seq[..., end_idx, :] - seq[..., start_idx, :]
    return bones.pow(2).sum(dim=-1, keepdim=True).sqrt()
class DDHPose(nn.Module):
    """
    Implement DDHPose

    A diffusion model that noises / denoises two quantities derived from a
    3D pose -- per-bone unit directions and per-bone lengths -- and uses a
    ``MixSTE2`` network, conditioned on 2D keypoints, as the denoiser.

    In training mode ``forward`` diffuses the ground-truth bone directions
    and lengths and returns the denoiser output. In evaluation mode it runs
    the DDIM sampling loop and returns the pose predicted at every sampling
    step.

    All tensors created internally are placed on the device of the input
    tensors, so the module works on CPU as well as on any GPU.
    """

    def __init__(self, args, joints_left, joints_right, is_train=True, num_proposals=1, sampling_timesteps=1):
        """Build the diffusion schedule buffers and the denoising network.

        Args:
            args: parsed options; the attributes read here are
                ``number_of_frames``, ``test_time_augmentation``,
                ``timestep``, ``scale``, ``cs`` and ``dep``.
            joints_left: indices of left-side joints; concatenated with
                ``joints_right`` using ``+`` when mirroring.
            joints_right: indices of right-side joints.
            is_train: selects the training branch of ``forward`` and a
                non-zero drop-path rate for the denoiser.
            num_proposals: number of hypotheses sampled per input.
            sampling_timesteps: number of DDIM steps used at sampling time;
                ``None`` means "use all diffusion timesteps".
        """
        super().__init__()

        self.frames = args.number_of_frames
        self.num_proposals = num_proposals
        # Stored for reference only: the mirrored pass below is driven by
        # whether a flipped 2D input is supplied, not by this flag.
        self.flip = args.test_time_augmentation
        self.joints_left = joints_left
        self.joints_right = joints_right
        self.is_train = is_train

        # build diffusion
        timesteps = args.timestep
        self.objective = 'pred_x0'

        betas = cosine_beta_schedule(timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        self.sampling_timesteps = default(sampling_timesteps, timesteps)
        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = 1.
        self.self_condition = False
        self.scale = args.scale
        self.box_renewal = True
        self.use_ensemble = True

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0
        # at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # Denoiser: stochastic depth is only enabled while training.
        drop_path_rate = 0.1 if is_train else 0
        self.dir_bone_estimator = MixSTE2(num_frame=self.frames, num_joints=17, in_chans=2, embed_dim_ratio=args.cs, depth=args.dep,
                                          num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, drop_path_rate=drop_path_rate, is_train=is_train)

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise implied by a noisy sample ``x_t`` and a clean estimate ``x0``.

        Computes ``(sqrt(1/alpha_bar_t) * x_t - x0) / sqrt(1/alpha_bar_t - 1)``
        with the coefficients looked up at timestep ``t``.
        """
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def model_predictions_dir_bone(self, x_dir, x_bone, inputs_2d, input_2d_flip, t):
        """Run the denoiser on the current noisy state and derive noise estimates.

        Args:
            x_dir: current noisy bone directions (5-D, joint axis at dim 3,
                coordinates at dim 4; entry 0 of the joint axis is skipped
                when the noise is computed).
            x_bone: current noisy bone lengths, laid out like ``x_dir``.
            inputs_2d: 2D keypoints passed to the denoiser.
            input_2d_flip: horizontally mirrored 2D keypoints, or ``None``
                to skip the mirrored pass.
            t: timestep index per batch element.

        Returns:
            ``ModelPrediction(pred_noise_dir, pred_noise_bone, pred_x_start)``
            where ``pred_x_start`` is the predicted pose multiplied by
            ``self.scale`` and clamped to ``+-1.1 * self.scale``.
        """
        limit = 1.1 * self.scale

        x_t_dir = torch.clamp(x_dir, min=-1.1 * self.scale, max=limit)
        x_t_dir = x_t_dir / self.scale
        x_t_bone = torch.clamp(x_bone, min=-1.1 * self.scale, max=limit)
        x_t_bone = x_t_bone / self.scale

        pred_pos = self.dir_bone_estimator(inputs_2d, x_t_dir, x_t_bone, t)

        # Mirrored pass. Skipped when no flipped 2D input is available,
        # instead of handing ``None`` to the denoiser.
        if input_2d_flip is not None:
            x_t_dir_flip = x_t_dir.clone()
            # Mirror the first coordinate, then swap left/right entries.
            x_t_dir_flip[:, :, :, :, 0] *= -1
            x_t_dir_flip[:, :, :, self.joints_left + self.joints_right] = x_t_dir_flip[:, :, :,
                                                                          self.joints_right + self.joints_left]
            x_t_bone_flip = x_t_bone.clone()
            x_t_bone_flip[:, :, :, self.joints_left + self.joints_right] = x_t_bone_flip[:, :, :,
                                                                           self.joints_right + self.joints_left]

            pred_pose_flip = self.dir_bone_estimator(input_2d_flip, x_t_dir_flip, x_t_bone_flip, t)

            # Undo the mirroring on the prediction before averaging.
            pred_pose_flip[:, :, :, :, 0] *= -1
            pred_pose_flip[:, :, :, self.joints_left + self.joints_right] = pred_pose_flip[:, :, :,
                                                                            self.joints_right + self.joints_left]
            pred_pos = (pred_pos + pred_pose_flip) / 2

        x_start_dir = getbonedirect_test(pred_pos, boneindex)
        x_start_dir = x_start_dir * self.scale
        x_start_dir = torch.clamp(x_start_dir, min=-1.1 * self.scale, max=limit)
        pred_noise_dir = self.predict_noise_from_start(x_dir[:, :, :, 1:, :], t, x_start_dir)

        x_start_bone = getbonelength_test(pred_pos, boneindex)
        x_start_bone = x_start_bone * self.scale
        x_start_bone = torch.clamp(x_start_bone, min=-1.1 * self.scale, max=limit)
        pred_noise_bone = self.predict_noise_from_start(x_bone[:, :, :, 1:, :], t, x_start_bone)

        x_start_pos = pred_pos * self.scale
        x_start_pos = torch.clamp(x_start_pos, min=-1.1 * self.scale, max=limit)

        return ModelPrediction(pred_noise_dir, pred_noise_bone, x_start_pos)

    def ddim_sample_bone_dir(self, inputs_2d, inputs_3d, clip_denoised=True, do_postprocess=True, input_2d_flip=None):
        """Sample poses with DDIM, starting from Gaussian noise.

        Args:
            inputs_2d: 2D keypoints; ``shape[0]`` is the batch size and
                ``shape[-2]`` the number of joints.
            inputs_3d: unused, kept for interface compatibility.
            clip_denoised: unused, kept for interface compatibility.
            do_postprocess: unused, kept for interface compatibility.
            input_2d_flip: optional mirrored 2D keypoints forwarded to
                ``model_predictions_dir_bone``.

        Returns:
            The pose predicted at every sampling step, stacked along a new
            axis at dim 1.
        """
        device = inputs_2d.device
        batch = inputs_2d.shape[0]
        jt_num = inputs_2d.shape[-2]
        dir_shape = (batch, self.num_proposals, self.frames, jt_num, 3)
        bone_shape = (batch, self.num_proposals, self.frames, jt_num, 1)
        total_timesteps, sampling_timesteps, eta = self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta

        # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img_dir = torch.randn(dir_shape, device=device)
        img_bone = torch.randn(bone_shape, device=device)

        preds_all_pos = []
        for time, time_next in time_pairs:
            time_cond = torch.full((batch,), time, dtype=torch.long, device=device)

            preds_pos = self.model_predictions_dir_bone(img_dir, img_bone, inputs_2d, input_2d_flip, time_cond)
            pred_noise_dir, pred_noise_bone, x_start_pos = preds_pos.pred_noise_dir, preds_pos.pred_noise_bone, preds_pos.pred_x_start

            x_start_dir = getbonedirect_test(x_start_pos, boneindex)
            x_start_bone = getbonelength_test(x_start_pos, boneindex)

            preds_all_pos.append(x_start_pos)

            if time_next < 0:
                # Last pair (0, -1): the prediction is final, nothing left
                # to re-noise.
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise_dir = torch.randn_like(x_start_dir)
            noise_bone = torch.randn_like(x_start_bone)

            img_dir_t = x_start_dir * alpha_next.sqrt() + \
                c * pred_noise_dir + \
                sigma * noise_dir
            img_bone_t = x_start_bone * alpha_next.sqrt() + \
                c * pred_noise_bone + \
                sigma * noise_bone

            # Entry 0 of the joint axis is not a bone and keeps its value;
            # only the bone entries are updated.
            img_dir[:, :, :, 1:] = img_dir_t
            img_bone[:, :, :, 1:] = img_bone_t

        return torch.stack(preds_all_pos, dim=1)

    # forward diffusion
    def q_sample(self, x_start, t, noise=None):
        """Draw ``x_t`` from ``q(x_t | x_0)``.

        Returns ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``;
        fresh Gaussian noise is drawn when ``noise`` is ``None``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def forward(self, input_2d, input_3d, input_2d_flip=None):
        """Sample poses (evaluation) or run one denoising pass (training).

        Args:
            input_2d: 2D keypoints.
            input_3d: ground-truth 3D pose; only used when training.
            input_2d_flip: optional mirrored 2D keypoints; only used when
                sampling.

        Returns:
            In evaluation mode the output of ``ddim_sample_bone_dir``; in
            training mode the denoiser output for the diffused ground truth.
        """
        if not self.is_train:
            return self.ddim_sample_bone_dir(input_2d, input_3d, input_2d_flip=input_2d_flip)

        x_dir, _, x_bone_length, _, t = self.prepare_targets(input_3d)
        # Cast back to float32 before feeding the denoiser.
        x_dir = x_dir.float()
        x_bone_length = x_bone_length.float()
        t = t.squeeze(-1)

        return self.dir_bone_estimator(input_2d, x_dir, x_bone_length, t)

    def prepare_diffusion_concat(self, pose_3d):
        """Diffuse one 3D pose sequence to a random timestep.

        Returns:
            ``(x, noise, t)``: the noisy sample (clamped to
            ``+-1.1 * scale`` and divided by ``scale``), the noise that was
            added, and the sampled timestep of shape ``(1,)``.
        """
        device = pose_3d.device
        t = torch.randint(0, self.num_timesteps, (1,), device=device).long()
        noise = torch.randn(self.frames, pose_3d.shape[1], pose_3d.shape[2], device=device)

        x_start = pose_3d * self.scale

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        x = torch.clamp(x, min=-1.1 * self.scale, max=1.1 * self.scale)
        x = x / self.scale

        return x, noise, t

    def prepare_diffusion_bone_dir(self, dir, bone):
        """Diffuse one sample's bone directions and lengths to a shared random timestep.

        Args:
            dir: bone directions of a single sample (3-D tensor).
            bone: bone lengths of a single sample (3-D tensor).

        Returns:
            ``(x_dir, noise_dir, x_bone, noise_bone, t)``: noisy directions
            and lengths (clamped to ``+-1.1 * scale`` and divided by
            ``scale``), the noise added to each, and the timestep of shape
            ``(1,)``.
        """
        device = dir.device
        t = torch.randint(0, self.num_timesteps, (1,), device=device).long()
        noise_dir = torch.randn(self.frames, dir.shape[1], dir.shape[2], device=device)
        noise_bone = torch.randn(self.frames, bone.shape[1], bone.shape[2], device=device)

        x_start_dir = dir * self.scale
        x_start_bone = bone * self.scale

        # noise sample
        x_dir = self.q_sample(x_start=x_start_dir, t=t, noise=noise_dir)
        x_bone = self.q_sample(x_start=x_start_bone, t=t, noise=noise_bone)

        x_dir = torch.clamp(x_dir, min=-1.1 * self.scale, max=1.1 * self.scale)
        x_dir = x_dir / self.scale
        x_bone = torch.clamp(x_bone, min=-1.1 * self.scale, max=1.1 * self.scale)
        x_bone = x_bone / self.scale

        return x_dir, noise_dir, x_bone, noise_bone, t

    def prepare_targets(self, targets):
        """Turn ground-truth poses into diffused bone directions / lengths.

        Bone directions and lengths are computed from ``targets`` and stored
        from index 1 of the joint axis onwards (index 0 stays zero). Each
        sample of the batch is then diffused independently, with its own
        random timestep.

        Returns:
            ``(diffused_dir, noises_dir, diffused_bone_length,
            noises_bone_length, ts)``, each stacked over the batch.
        """
        device = targets.device
        diffused_dir = []
        noises_dir = []
        diffused_bone_length = []
        noises_bone_length = []
        ts = []

        targets_dir = torch.zeros(targets.shape[0], targets.shape[1], targets.shape[2], 3, device=device)
        targets_bone_length = torch.zeros(targets.shape[0], targets.shape[1], targets.shape[2], 1, device=device)
        targets_dir[:, :, 1:] = getbonedirect(targets, boneindex)
        targets_bone_length[:, :, 1:] = getbonelength(targets, boneindex)

        for sample_dir, sample_bone_length in zip(targets_dir, targets_bone_length):
            d_dir, d_noise_dir, d_bone_length, d_noise_bone_length, d_t = self.prepare_diffusion_bone_dir(sample_dir, sample_bone_length)
            diffused_dir.append(d_dir)
            noises_dir.append(d_noise_dir)
            diffused_bone_length.append(d_bone_length)
            noises_bone_length.append(d_noise_bone_length)
            ts.append(d_t)

        return torch.stack(diffused_dir), torch.stack(noises_dir), \
            torch.stack(diffused_bone_length), torch.stack(noises_bone_length), torch.stack(ts)