Fred808 committed on
Commit
8a45a74
·
verified ·
1 Parent(s): af80243

Upload 10 files

Browse files
utils/PYTORCH3D_LICENSE ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD License
2
+
3
+ For PyTorch3D software
4
+
5
+ Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
6
+
7
+ Redistribution and use in source and binary forms, with or without modification,
8
+ are permitted provided that the following conditions are met:
9
+
10
+ * Redistributions of source code must retain the above copyright notice, this
11
+ list of conditions and the following disclaimer.
12
+
13
+ * Redistributions in binary form must reproduce the above copyright notice,
14
+ this list of conditions and the following disclaimer in the documentation
15
+ and/or other materials provided with the distribution.
16
+
17
+ * Neither the name Facebook nor the names of its contributors may be used to
18
+ endorse or promote products derived from this software without specific
19
+ prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
utils/config.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Root directory containing the SMPL body-model assets.
SMPL_DATA_PATH = "./body_models/smpl"

# Kinematic tree (parent table) and neutral body model shipped with SMPL.
SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl")
SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl")
# Extra joint regressor used at training time.
JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy')

# Number of joint rotations per rotation convention — presumably excluding
# the global root orientation (TODO confirm against the body-model code).
ROT_CONVENTION_TO_ROT_NUMBER = {
    'legacy': 23,
    'no_hands': 21,
    'full_hands': 51,
    'mitten_hands': 33,
}

# Supported SMPL genders.
GENDERS = ['neutral', 'male', 'female']
# Number of SMPL shape (beta) coefficients.
NUM_BETAS = 10
utils/dist_util.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helpers for distributed training.
3
+ """
4
+
5
+ import socket
6
+
7
+ import torch as th
8
+ import torch.distributed as dist
9
+
10
# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8

# Number of times distributed setup may be retried.
# NOTE(review): not referenced anywhere in this file — likely a leftover from
# the original MPI-based setup; confirm before removing.
SETUP_RETRY_COUNT = 3

# Module-level device selection: written by setup_dist(), read by dev().
used_device = 0
17
+
18
def setup_dist(device=0):
    """
    Setup a distributed process group.

    Currently this only records *device* in the module-level ``used_device``
    (read back by ``dev()``) and returns early if torch.distributed is
    already initialized — the original MPI-based group initialization is
    kept below, commented out, for reference.
    """
    global used_device
    used_device = device
    if dist.is_initialized():
        return
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(device) # f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"

    # comm = MPI.COMM_WORLD
    # backend = "gloo" if not th.cuda.is_available() else "nccl"

    # if backend == "gloo":
    #     hostname = "localhost"
    # else:
    #     hostname = socket.gethostbyname(socket.getfqdn())
    # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    # os.environ["RANK"] = str(comm.rank)
    # os.environ["WORLD_SIZE"] = str(comm.size)

    # port = comm.bcast(_find_free_port(), root=used_device)
    # os.environ["MASTER_PORT"] = str(port)
    # dist.init_process_group(backend=backend, init_method="env://")
42
+
43
+
44
def dev():
    """
    Return the torch device to use: the CUDA device selected via
    ``setup_dist()`` when CUDA is available, otherwise the CPU.
    """
    cuda_selected = th.cuda.is_available() and used_device >= 0
    return th.device(f"cuda:{used_device}") if cuda_selected else th.device("cpu")
52
+
53
+
54
def load_state_dict(path, **kwargs):
    """
    Load a PyTorch checkpoint from *path*.

    Thin wrapper around ``th.load``; the name is kept from the original
    MPI version that avoided redundant fetches across ranks.
    """
    return th.load(path, **kwargs)
59
+
60
+
61
def sync_params(params):
    """
    Broadcast every tensor in *params* from rank 0 to all other ranks,
    synchronizing model parameters across the process group.
    """
    with th.no_grad():
        for tensor in params:
            dist.broadcast(tensor, 0)
68
+
69
+
70
+ def _find_free_port():
71
+ try:
72
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
73
+ s.bind(("", 0))
74
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
75
+ return s.getsockname()[1]
76
+ finally:
77
+ s.close()
utils/fixseed.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import random
4
+
5
+
6
def fixseed(seed):
    """Seed all RNGs (python random, numpy, torch) with *seed* and disable
    the nondeterministic cuDNN autotuner, for reproducible runs."""
    torch.backends.cudnn.benchmark = False
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
11
+
12
+
13
+ # SEED = 10
14
+ # EVALSEED = 0
15
+ # # Provoc warning: not fully functionnal yet
16
+ # # torch.set_deterministic(True)
17
+ # torch.backends.cudnn.benchmark = False
18
+ # fixseed(SEED)
utils/loss_util.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusion.nn import mean_flat, sum_flat
2
+ import torch
3
+ import numpy as np
4
+
5
def angle_l2(angle1, angle2):
    """Element-wise squared angular difference, with the difference wrapped
    into [-pi/2, pi/2) so periodic angles compare on the shortest arc."""
    half_pi = torch.pi / 2
    wrapped = (angle1 - angle2 + half_pi) % torch.pi - half_pi
    return wrapped ** 2
9
+
10
def diff_l2(a, b):
    """Element-wise squared difference between *a* and *b*."""
    diff = a - b
    return diff * diff
12
+
13
def masked_l2(a, b, mask, loss_fn=diff_l2, epsilon=1e-8, entries_norm=True):
    """
    Mean of ``loss_fn(a, b)`` over the unmasked elements.

    Assumes ``a.shape == b.shape == (bs, J, Jdim, seqlen)`` and
    ``mask.shape == (bs, 1, 1, seqlen)``; returns one value per batch item.
    """
    per_elem = loss_fn(a, b)
    # Sum of the loss over unmasked elements only.
    masked_sum = sum_flat(per_elem * mask.float())
    entries_per_frame = a.shape[1]
    if len(a.shape) > 3:
        entries_per_frame *= a.shape[2]
    denom = sum_flat(mask)
    if entries_norm:
        # When the mask is per frame (it does not already count the entries
        # in each frame) this normalization is needed; otherwise pass False.
        denom = denom * entries_per_frame
    # epsilon avoids division by zero when everything is masked out.
    return masked_sum / (denom + epsilon)
32
+
33
+
34
def masked_goal_l2(pred_goal, ref_goal, cond, all_goal_joint_names):
    """
    Goal-conditioning loss: masked L2 on the supervised joint locations plus
    a wrapped-angle L2 on the heading channel of the trajectory entry.

    pred_goal/ref_goal: assumed [bs, n_goal_joints + 1, loc_dim], where the
    last row is the 'traj' entry and its first channel is the heading angle
    — TODO confirm against the caller.
    cond['target_joint_names']: per-sample list of joint names to supervise.
    cond['is_heading']: per-sample mask enabling the heading term.
    """
    # Map each sample's joint names to row indices (with 'traj' appended).
    all_goal_joint_names_w_traj = np.append(all_goal_joint_names, 'traj')
    target_joint_idx = [[np.where(all_goal_joint_names_w_traj == j)[0][0] for j in sample_joints] for sample_joints in cond['target_joint_names']]
    # Per-sample boolean mask over all rows except the heading-only last row.
    loc_mask = torch.zeros_like(pred_goal[:,:-1], dtype=torch.bool)
    for sample_idx in range(loc_mask.shape[0]):
        loc_mask[sample_idx, target_joint_idx[sample_idx]] = True
    loc_mask[:, -1, 1] = False # vertical joint of 'traj' is always masked out
    loc_loss = masked_l2(pred_goal[:,:-1], ref_goal[:,:-1], loc_mask, entries_norm=False)

    # Heading supervised only where cond['is_heading'] is set, with the
    # angle-wrapped L2 so e.g. -pi and pi compare as equal.
    heading_loss = masked_l2(pred_goal[:,-1:, :1], ref_goal[:,-1:, :1], cond['is_heading'].unsqueeze(1).unsqueeze(1), loss_fn=angle_l2, entries_norm=False)

    loss = loc_loss + heading_loss
    return loss
utils/misc.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class WeightedSum(nn.Module):
    """Learnable weighted sum over the rows of a 2-D input."""

    def __init__(self, num_rows):
        super().__init__()
        # One learnable weight per input row.
        self.weights = nn.Parameter(torch.randn(num_rows))

    def forward(self, x):
        # Normalize so the weights sum to one (softmax is an alternative).
        normalized = self.weights / self.weights.sum()
        return torch.matmul(normalized, x)
17
+
18
+
19
def wrapped_getattr(self, name, default=None, wrapped_member_name='model'):
    '''Attribute lookup for wrappers of model classes (e.g.
    ClassifierFreeSampleModel): resolve *name* on the wrapper itself first,
    then fall back to the wrapped member, returning *default* if absent.'''
    if not isinstance(self, torch.nn.Module):
        # Easy case: self is a plain object — go straight to the wrapped member.
        wrapped = getattr(self, wrapped_member_name)
        return getattr(wrapped, name, default)
    # nn.Module keeps attributes in _parameters/_buffers/_modules, so we must
    # go through nn.Module.__getattr__ explicitly; a plain getattr on self in
    # a wrapper that overrides __getattr__ could recurse forever.
    try:
        return torch.nn.Module.__getattr__(self, name)
    except AttributeError:
        wrapped = torch.nn.Module.__getattr__(self, wrapped_member_name)
        return getattr(wrapped, name, default)
36
+
37
+
38
def to_numpy(tensor):
    """Convert a torch tensor to a numpy array; numpy arrays pass through.

    Raises ValueError for any other input type.
    """
    if torch.is_tensor(tensor):
        return tensor.cpu().numpy()
    if type(tensor).__module__ == 'numpy':
        return tensor
    raise ValueError("Cannot convert {} to numpy array".format(
        type(tensor)))
45
+
46
+
47
def to_torch(ndarray):
    """Convert a numpy array to a torch tensor; torch tensors pass through.

    Raises ValueError for any other input type.
    """
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray)
    if torch.is_tensor(ndarray):
        return ndarray
    raise ValueError("Cannot convert {} to torch tensor".format(
        type(ndarray)))
54
+
55
+
56
def cleanexit():
    """Terminate the process with status 0, forcing a hard exit if something
    (e.g. a framework) traps the SystemExit raised by sys.exit."""
    import os
    import sys
    try:
        sys.exit(0)
    except SystemExit:
        # os._exit skips cleanup handlers and exits immediately.
        os._exit(0)
63
+
64
def load_model_wo_clip(model, state_dict):
    """Load *state_dict* into *model*, tolerating (only) missing CLIP keys.

    The frozen CLIP encoder is not saved in checkpoints, so 'clip_model.*'
    entries are the only keys allowed to be missing; any unexpected key is
    an error.
    """
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    assert not unexpected
    assert all(k.startswith('clip_model.') for k in missing)
68
+
69
def freeze_joints(x, joints_to_freeze):
    # Freezes selected joint *rotations* as they appear in the first frame
    # x [bs, [root+n_joints], joint_dim(6), seqlen]
    out = x.detach().clone()
    first_frame = out[:, joints_to_freeze, :, :1]
    # Broadcasting copies the first frame across the whole sequence.
    out[:, joints_to_freeze, :, :] = first_frame
    return out
utils/model_util.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model.mdm import MDM
3
+ from diffusion import gaussian_diffusion as gd
4
+ from diffusion.respace import SpacedDiffusion, space_timesteps
5
+ from utils.parser_util import get_cond_mode
6
+ from data_loaders.humanml_utils import HML_EE_JOINT_NAMES
7
+
8
def load_model_wo_clip(model, state_dict):
    """Load a checkpoint into *model*, ignoring CLIP weights and the fixed
    positional-encoding buffers.

    The ``sequence_pos_encoder.pe`` buffers are deterministic (not learned)
    and their length can differ between checkpoint and model, so they are
    dropped before loading to avoid size-mismatch errors.  Missing keys must
    all be CLIP or pos-encoding entries; unexpected keys are an error.
    """
    # pop() instead of del: checkpoints from other model variants may not
    # contain these buffers at all, and del would raise KeyError for them.
    state_dict.pop('sequence_pos_encoder.pe', None)
    state_dict.pop('embed_timestep.sequence_pos_encoder.pe', None)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    assert len(unexpected_keys) == 0
    assert all([k.startswith('clip_model.') or 'sequence_pos_encoder' in k for k in missing_keys])
16
+
17
+
18
def create_model_and_diffusion(args, data):
    """Build the MDM network and its Gaussian diffusion wrapper from parsed
    CLI *args* and the dataset loader *data* (used to infer model dims)."""
    model = MDM(**get_model_args(args, data))
    diffusion = create_gaussian_diffusion(args)
    return model, diffusion
22
+
23
+
24
def get_model_args(args, data):
    """
    Assemble the keyword arguments for the MDM constructor from the parsed
    CLI *args* and the dataset loader *data*.

    Dataset-specific input dimensions are selected here; missing attributes
    on *args* (older checkpoints) fall back to defaults.
    """
    # Fixed defaults.
    clip_version = 'ViT-B/32'
    action_emb = 'tensor'
    cond_mode = get_cond_mode(args)
    num_actions = getattr(data.dataset, 'num_actions', 1)

    # SMPL defaults, overridden per dataset below.
    data_rep, njoints, nfeats = 'rot6d', 25, 6
    all_goal_joint_names = []
    if args.dataset == 'humanml':
        data_rep, njoints, nfeats = 'hml_vec', 263, 1
        all_goal_joint_names = ['pelvis'] + HML_EE_JOINT_NAMES
    elif args.dataset == 'kit':
        data_rep, njoints, nfeats = 'hml_vec', 251, 1

    # Compatibility with old models that predate prefix completion.
    if not hasattr(args, 'pred_len'):
        args.pred_len = 0
        args.context_len = 0

    # Attributes that may be absent on args saved by older code.
    emb_policy = args.__dict__.get('emb_policy', 'add')
    multi_target_cond = args.__dict__.get('multi_target_cond', False)
    multi_encoder_type = args.__dict__.get('multi_encoder_type', 'multi')
    target_enc_layers = args.__dict__.get('target_enc_layers', 1)

    return {'modeltype': '', 'njoints': njoints, 'nfeats': nfeats, 'num_actions': num_actions,
            'translation': True, 'pose_rep': 'rot6d', 'glob': True, 'glob_rot': True,
            'latent_dim': args.latent_dim, 'ff_size': 1024, 'num_layers': args.layers, 'num_heads': 4,
            'dropout': 0.1, 'activation': "gelu", 'data_rep': data_rep, 'cond_mode': cond_mode,
            'cond_mask_prob': args.cond_mask_prob, 'action_emb': action_emb, 'arch': args.arch,
            'emb_trans_dec': args.emb_trans_dec, 'clip_version': clip_version, 'dataset': args.dataset,
            'text_encoder_type': args.text_encoder_type,
            'pos_embed_max_len': args.pos_embed_max_len, 'mask_frames': args.mask_frames,
            'pred_len': args.pred_len, 'context_len': args.context_len, 'emb_policy': emb_policy,
            'all_goal_joint_names': all_goal_joint_names, 'multi_target_cond': multi_target_cond,
            'multi_encoder_type': multi_encoder_type, 'target_enc_layers': target_enc_layers,
            }
72
+
73
+
74
+
75
def create_gaussian_diffusion(args):
    """
    Build a SpacedDiffusion process from CLI *args*.

    Fixed choices: the model predicts x0 (not epsilon), sigmas are fixed
    (not learned), MSE loss, and no timestep respacing during training.
    """
    steps = args.diffusion_steps
    predict_xstart = True  # we always predict x_start (a.k.a. x0), that's our deal!
    learn_sigma = False
    rescale_timesteps = False
    scale_beta = 1.  # no scaling
    timestep_respacing = ''  # can be used for ddim sampling, we don't use it.

    betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta)

    if not timestep_respacing:
        timestep_respacing = [steps]

    # Older arg sets may predate the target-location loss.
    lambda_target_loc = getattr(args, 'lambda_target_loc', 0.)

    mean_type = gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON
    if learn_sigma:
        var_type = gd.ModelVarType.LEARNED_RANGE
    elif args.sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=gd.LossType.MSE,
        rescale_timesteps=rescale_timesteps,
        lambda_vel=args.lambda_vel,
        lambda_rcxyz=args.lambda_rcxyz,
        lambda_fc=args.lambda_fc,
        lambda_target_loc=lambda_target_loc,
    )
117
+
118
def load_saved_model(model, model_path, use_avg: bool=False):  # use_avg_model
    """
    Load a checkpoint file into *model* and return it.

    Prefers the EMA-averaged weights ('model_avg') when *use_avg* is set and
    the checkpoint contains them; otherwise falls back to 'model', or to the
    raw checkpoint dict for legacy files.
    """
    checkpoint = torch.load(model_path, map_location='cpu')
    if use_avg and 'model_avg' in checkpoint:
        print('loading avg model')
        state_dict = checkpoint['model_avg']
    elif 'model' in checkpoint:
        print('loading model without avg')
        state_dict = checkpoint['model']
    else:
        print('checkpoint has no avg model, loading as usual.')
        state_dict = checkpoint
    load_model_wo_clip(model, state_dict)
    return model
utils/parser_util.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ import argparse
3
+ import os
4
+ import json
5
+
6
+
7
def parse_and_load_from_model(parser):
    """
    Parse CLI args and restore model-owned arguments from the checkpoint.

    Dataset/model/diffusion options are registered so they parse, but any
    value the user passes for them is overwritten by the checkpoint's
    args.json — do not try to set them from the command line.
    """
    add_data_options(parser)
    add_model_options(parser)
    add_diffusion_options(parser)
    args = parser.parse_args()

    # Every arg registered under these groups is owned by the checkpoint.
    args_to_overwrite = [
        name
        for group_name in ['dataset', 'model', 'diffusion']
        for name in get_args_per_group_name(parser, args, group_name)
    ]

    if args.model_path != '':  # if not using external results file
        args = load_args_from_model(args, args_to_overwrite)

    # With no condition masking, classifier-free guidance degenerates to 1.
    if args.cond_mask_prob == 0:
        args.guidance_param = 1

    return apply_rules(args)
26
+
27
def load_args_from_model(args, args_to_overwrite):
    """
    Overwrite the *args* entries listed in *args_to_overwrite* with the
    values stored in the checkpoint's args.json (next to the model file).
    """
    model_path = get_model_path_from_args()
    args_path = os.path.join(os.path.dirname(model_path), 'args.json')
    assert os.path.exists(args_path), 'Arguments json file was not found!'
    with open(args_path, 'r') as fr:
        model_args = json.load(fr)

    for a in args_to_overwrite:
        if a in model_args.keys():
            setattr(args, a, model_args[a])

        # Backward compatibility: old checkpoints lack 'unconstrained';
        # derive it from 'cond_mode' instead.
        # NOTE(review): this elif fires for *every* arg missing from the
        # json (re-setting the same value each time), not only for
        # 'unconstrained' — looks intentional but fragile; confirm.
        elif 'cond_mode' in model_args: # backward compitability
            unconstrained = (model_args['cond_mode'] == 'no_cond')
            setattr(args, 'unconstrained', unconstrained)

        else:
            print('Warning: was not able to load [{}], using default value [{}] instead.'.format(a, args.__dict__[a]))
    return args
45
+
46
def apply_rules(args):
    """Post-parse argument fix-ups shared by training and sampling."""
    # Prefix completion: an unspecified pred_len defaults to context_len.
    if args.pred_len == 0:
        args.pred_len = args.context_len
    # A positive target-location loss implies target conditioning.
    if args.lambda_target_loc > 0.:
        args.multi_target_cond = True
    return args
55
+
56
+
57
def get_args_per_group_name(parser, args, group_name):
    """Return the dest names of all options registered under the argparse
    group titled *group_name*.

    Raises ValueError if the parser has no such group.
    """
    for group in parser._action_groups:
        if group.title == group_name:
            # dest names in registration order.
            return [a.dest for a in group._group_actions]
    # Bug fix: the original *returned* a ValueError instance here instead of
    # raising it, so callers silently received an exception object.
    raise ValueError('group_name was not found.')
63
+
64
def get_model_path_from_args():
    """Peek at sys.argv for --model_path without touching the main parser.

    Returns the value of --model_path (None if the flag is absent).
    Raises ValueError if argparse fails to parse the command line.
    """
    dummy_parser = ArgumentParser()
    dummy_parser.add_argument('--model_path')
    try:
        dummy_args, _ = dummy_parser.parse_known_args()
    # Bug fix: the original used a bare `except:`, which also swallowed
    # KeyboardInterrupt and hid unrelated errors.  argparse signals parse
    # failures by raising SystemExit — translate only that.
    except SystemExit:
        raise ValueError('model_path argument must be specified.')
    return dummy_args.model_path
72
+
73
+
74
def add_base_options(parser):
    """Register run-environment arguments shared by all entry points."""
    group = parser.add_argument_group('base')
    # NOTE(review): type=bool parses any non-empty string (e.g. "--cuda False")
    # as True; use action='store_true' or a str2bool helper if CLI override
    # of these flags is ever needed.
    group.add_argument("--cuda", default=True, type=bool, help="Use cuda device, otherwise use CPU.")
    group.add_argument("--device", default=0, type=int, help="Device id to use.")
    group.add_argument("--seed", default=10, type=int, help="For fixing random seed.")
    group.add_argument("--batch_size", default=64, type=int, help="Batch size during training.")
    group.add_argument("--train_platform_type", default='NoPlatform', choices=['NoPlatform', 'ClearmlPlatform', 'TensorboardPlatform', 'WandBPlatform'], type=str,
                       help="Choose platform to log results. NoPlatform means no logging.")
    group.add_argument("--external_mode", default=False, type=bool, help="For backward cometability, do not change or delete.")
83
+
84
+
85
def add_diffusion_options(parser):
    """Register the diffusion-process hyperparameter arguments."""
    group = parser.add_argument_group('diffusion')
    group.add_argument("--noise_schedule", default='cosine', choices=['linear', 'cosine'], type=str,
                       help="Noise schedule type")
    group.add_argument("--diffusion_steps", default=1000, type=int,
                       help="Number of diffusion steps (denoted T in the paper)")
    # NOTE(review): type=bool parses any non-empty string as True; fine for
    # the default but "--sigma_small False" would still yield True.
    group.add_argument("--sigma_small", default=True, type=bool, help="Use smaller sigma values.")
92
+
93
+
94
def add_model_options(parser):
    """Register the MDM architecture and loss-weight arguments.

    Everything in this group is restored from the checkpoint's args.json
    when sampling (see parse_and_load_from_model).
    """
    group = parser.add_argument_group('model')
    group.add_argument("--arch", default='trans_enc',
                       choices=['trans_enc', 'trans_dec', 'gru'], type=str,
                       help="Architecture types as reported in the paper.")
    group.add_argument("--text_encoder_type", default='clip',
                       choices=['clip', 'bert'], type=str, help="Text encoder type.")
    group.add_argument("--emb_trans_dec", action='store_true',
                       help="For trans_dec architecture only, if true, will inject condition as a class token"
                            " (in addition to cross-attention).")
    group.add_argument("--layers", default=8, type=int,
                       help="Number of layers.")
    group.add_argument("--latent_dim", default=512, type=int,
                       help="Transformer/GRU width.")
    group.add_argument("--cond_mask_prob", default=.1, type=float,
                       help="The probability of masking the condition during training."
                            " For classifier-free guidance learning.")
    group.add_argument("--mask_frames", action='store_true', help="If true, will fix Rotem's bug and mask invalid frames.")
    # Auxiliary loss weights (0.0 disables the term).
    group.add_argument("--lambda_rcxyz", default=0.0, type=float, help="Joint positions loss.")
    group.add_argument("--lambda_vel", default=0.0, type=float, help="Joint velocity loss.")
    group.add_argument("--lambda_fc", default=0.0, type=float, help="Foot contact loss.")
    group.add_argument("--lambda_target_loc", default=0.0, type=float, help="For HumanML only, when . L2 with target location.")
    group.add_argument("--unconstrained", action='store_true',
                       help="Model is trained unconditionally. That is, it is constrained by neither text nor action. "
                            "Currently tested on HumanAct12 only.")
    group.add_argument("--pos_embed_max_len", default=5000, type=int,
                       help="Pose embedding max length.")
    group.add_argument("--use_ema", action='store_true',
                       help="If True, will use EMA model averaging.")

    # Multi-target (goal-joint) conditioning options.
    group.add_argument("--multi_target_cond", action='store_true', help="If true, enable multi-target conditioning (aka Sigal's model).")
    group.add_argument("--multi_encoder_type", default='single', choices=['single', 'multi', 'split'], type=str, help="Specifies the encoder type to be used for the multi joint condition.")
    group.add_argument("--target_enc_layers", default=1, type=int, help="Num target encoder layers")

    # Prefix completion model
    group.add_argument("--context_len", default=0, type=int, help="If larger than 0, will do prefix completion.")
    group.add_argument("--pred_len", default=0, type=int, help="If context_len larger than 0, will do prefix completion. If pred_len will not be specified - will use the same length as context_len")
133
+
134
+
135
+
136
+
137
def add_data_options(parser):
    """Register dataset-selection arguments under the 'dataset' group."""
    data_group = parser.add_argument_group('dataset')
    data_group.add_argument(
        "--dataset", default='humanml',
        choices=['humanml', 'kit', 'humanact12', 'uestc'], type=str,
        help="Dataset name (choose from list).")
    data_group.add_argument(
        "--data_dir", default="", type=str,
        help="If empty, will use defaults according to the specified dataset.")
143
+
144
+
145
def add_training_options(parser):
    """Register optimizer, schedule, evaluation and generation-during-training
    arguments used only by the training entry point."""
    group = parser.add_argument_group('training')
    group.add_argument("--save_dir", required=True, type=str,
                       help="Path to save checkpoints and results.")
    group.add_argument("--overwrite", action='store_true',
                       help="If True, will enable to use an already existing save_dir.")
    group.add_argument("--lr", default=1e-4, type=float, help="Learning rate.")
    group.add_argument("--weight_decay", default=0.0, type=float, help="Optimizer weight decay.")
    group.add_argument("--lr_anneal_steps", default=0, type=int, help="Number of learning rate anneal steps.")
    group.add_argument("--eval_batch_size", default=32, type=int,
                       help="Batch size during evaluation loop. Do not change this unless you know what you are doing. "
                            "T2m precision calculation is based on fixed batch size 32.")
    group.add_argument("--eval_split", default='test', choices=['val', 'test'], type=str,
                       help="Which split to evaluate on during training.")
    group.add_argument("--eval_during_training", action='store_true',
                       help="If True, will run evaluation during training.")
    group.add_argument("--eval_rep_times", default=3, type=int,
                       help="Number of repetitions for evaluation loop during training.")
    group.add_argument("--eval_num_samples", default=1_000, type=int,
                       help="If -1, will use all samples in the specified split.")
    group.add_argument("--log_interval", default=1_000, type=int,
                       help="Log losses each N steps")
    group.add_argument("--save_interval", default=50_000, type=int,
                       help="Save checkpoints and run evaluation each N steps")
    group.add_argument("--num_steps", default=600_000, type=int,
                       help="Training will stop after the specified number of steps.")
    group.add_argument("--num_frames", default=60, type=int,
                       help="Limit for the maximal number of frames. In HumanML3D and KIT this field is ignored.")
    group.add_argument("--resume_checkpoint", default="", type=str,
                       help="If not empty, will start from the specified checkpoint (path to model###.pt file).")

    # Sampling motions while training (for visual monitoring).
    group.add_argument("--gen_during_training", action='store_true',
                       help="If True, will generate motions during training, on each save interval.")
    group.add_argument("--gen_num_samples", default=3, type=int,
                       help="Number of samples to sample while generating")
    group.add_argument("--gen_num_repetitions", default=2, type=int,
                       help="Number of repetitions, per sample (text prompt/action)")
    group.add_argument("--gen_guidance_param", default=2.5, type=float,
                       help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")

    # EMA / optimizer extras.
    group.add_argument("--avg_model_beta", default=0.9999, type=float, help="Average model beta (for EMA).")
    group.add_argument("--adam_beta2", default=0.999, type=float, help="Adam beta2.")

    group.add_argument("--target_joint_names", default='DIMP_FINAL', type=str, help="Force single joint configuration by specifing the joints (coma separated). If None - will use the random mode for all end effectors.")
    group.add_argument("--autoregressive", action='store_true', help="If true, and we use a prefix model will generate motions in an autoregressive loop.")
    group.add_argument("--autoregressive_include_prefix", action='store_true', help="If true, include the init prefix in the output, otherwise, will drop it.")
    group.add_argument("--autoregressive_init", default='data', type=str, choices=['data', 'isaac'],
                       help="Sets the source of the init frames, either from the dataset or isaac init poses.")
193
+
194
+
195
def add_sampling_options(parser):
    """Register checkpoint/output/guidance arguments shared by all sampling
    entry points (generate, edit, evaluate)."""
    sampling_group = parser.add_argument_group('sampling')
    sampling_group.add_argument(
        "--model_path", required=True, type=str,
        help="Path to model####.pt file to be sampled.")
    sampling_group.add_argument(
        "--output_dir", default='', type=str,
        help="Path to results dir (auto created by the script). "
             "If empty, will create dir in parallel to checkpoint.")
    sampling_group.add_argument(
        "--num_samples", default=6, type=int,
        help="Maximal number of prompts to sample, "
             "if loading dataset from file, this field will be ignored.")
    sampling_group.add_argument(
        "--num_repetitions", default=3, type=int,
        help="Number of repetitions, per sample (text prompt/action)")
    sampling_group.add_argument(
        "--guidance_param", default=2.5, type=float,
        help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")

    sampling_group.add_argument(
        "--autoregressive", action='store_true',
        help="If true, and we use a prefix model will generate motions in an autoregressive loop.")
    sampling_group.add_argument(
        "--autoregressive_include_prefix", action='store_true',
        help="If true, include the init prefix in the output, otherwise, will drop it.")
    sampling_group.add_argument(
        "--autoregressive_init", default='data', type=str, choices=['data', 'isaac'],
        help="Sets the source of the init frames, either from the dataset or isaac init poses.")
214
+
215
def add_generate_options(parser):
    """Register the prompt-source arguments for free-form generation.

    Text options apply to text-conditioned models, action options to
    action-conditioned ones; generate_args() enforces the split.
    """
    group = parser.add_argument_group('generate')
    group.add_argument("--motion_length", default=6.0, type=float,
                       help="The length of the sampled motion [in seconds]. "
                            "Maximum is 9.8 for HumanML3D (text-to-motion), and 2.0 for HumanAct12 (action-to-motion)")
    group.add_argument("--input_text", default='', type=str,
                       help="Path to a text file lists text prompts to be synthesized. If empty, will take text prompts from dataset.")
    group.add_argument("--dynamic_text_path", default='', type=str,
                       help="For the autoregressive mode only! Path to a text file lists text prompts to be synthesized. If empty, will take text prompts from dataset.")
    group.add_argument("--action_file", default='', type=str,
                       help="Path to a text file that lists names of actions to be synthesized. Names must be a subset of dataset/uestc/info/action_classes.txt if sampling from uestc, "
                            "or a subset of [warm_up,walk,run,jump,drink,lift_dumbbell,sit,eat,turn steering wheel,phone,boxing,throw] if sampling from humanact12. "
                            "If no file is specified, will take action names from dataset.")
    group.add_argument("--text_prompt", default='', type=str,
                       help="A text prompt to be generated. If empty, will take text prompts from dataset.")
    group.add_argument("--action_name", default='', type=str,
                       help="An action name to be generated. If empty, will take text prompts from dataset.")
    group.add_argument("--target_joint_names", default='DIMP_FINAL', type=str, help="Force single joint configuration by specifing the joints (coma separated). If None - will use the random mode for all end effectors.")
233
+
234
+
235
def add_edit_options(parser):
    """Register arguments controlling motion editing (in-betweening or
    upper-body replacement)."""
    edit_group = parser.add_argument_group('edit')
    edit_group.add_argument(
        "--edit_mode", default='in_between', choices=['in_between', 'upper_body'], type=str,
        help="Defines which parts of the input motion will be edited.\n"
             "(1) in_between - suffix and prefix motion taken from input motion, "
             "middle motion is generated.\n"
             "(2) upper_body - lower body joints taken from input motion, "
             "upper body is generated.")
    edit_group.add_argument(
        "--text_condition", default='', type=str,
        help="Editing will be conditioned on this text prompt. "
             "If empty, will perform unconditioned editing.")
    edit_group.add_argument(
        "--prefix_end", default=0.25, type=float,
        help="For in_between editing - Defines the end of input prefix (ratio from all frames).")
    edit_group.add_argument(
        "--suffix_start", default=0.75, type=float,
        help="For in_between editing - Defines the start of input suffix (ratio from all frames).")
250
+
251
+
252
def add_evaluation_options(parser):
    """Register arguments for the evaluation entry point."""
    group = parser.add_argument_group('eval')
    group.add_argument("--model_path", required=True, type=str,
                       help="Path to model####.pt file to be sampled.")
    group.add_argument("--eval_mode", default='wo_mm', choices=['wo_mm', 'mm_short', 'debug', 'full'], type=str,
                       help="wo_mm (t2m only) - 20 repetitions without multi-modality metric; "
                            "mm_short (t2m only) - 5 repetitions with multi-modality metric; "
                            "debug - short run, less accurate results."
                            "full (a2m only) - 20 repetitions.")
    group.add_argument("--autoregressive", action='store_true', help="If true, and we use a prefix model will generate motions in an autoregressive loop.")
    group.add_argument("--autoregressive_include_prefix", action='store_true', help="If true, include the init prefix in the output, otherwise, will drop it.")
    group.add_argument("--autoregressive_init", default='data', type=str, choices=['data', 'isaac'],
                       help="Sets the source of the init frames, either from the dataset or isaac init poses.")
    group.add_argument("--guidance_param", default=2.5, type=float,
                       help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")
267
+
268
+
269
def get_cond_mode(args):
    """Infer the conditioning signal from parsed args: none when
    unconstrained, text for the T2M datasets, action otherwise."""
    if args.unconstrained:
        return 'no_cond'
    if args.dataset in ['kit', 'humanml']:
        return 'text'
    return 'action'
277
+
278
+
279
def train_args():
    """Parse the full training CLI (base/data/model/diffusion/training
    groups) and apply the cross-argument rules."""
    parser = ArgumentParser()
    add_base_options(parser)
    add_data_options(parser)
    add_model_options(parser)
    add_diffusion_options(parser)
    add_training_options(parser)
    return apply_rules(parser.parse_args())
287
+
288
+
289
def generate_args():
    """
    Parse CLI args for generation: user-specified base/sampling/generate
    options, with model/data/diffusion options restored from the checkpoint.
    Validates that text prompts are only used with text-conditioned models
    and action prompts only with action-conditioned ones.
    """
    parser = ArgumentParser()
    # args specified by the user: (all other will be loaded from the model)
    add_base_options(parser)
    add_sampling_options(parser)
    add_generate_options(parser)
    args = parse_and_load_from_model(parser)
    cond_mode = get_cond_mode(args)

    if (args.input_text or args.text_prompt) and cond_mode != 'text':
        raise Exception('Arguments input_text and text_prompt should not be used for an action condition. Please use action_file or action_name.')
    elif (args.action_file or args.action_name) and cond_mode != 'action':
        raise Exception('Arguments action_file and action_name should not be used for a text condition. Please use input_text or text_prompt.')

    return args
304
+
305
+
306
def edit_args():
    """Parse CLI args for motion editing; model-owned options are restored
    from the checkpoint's args.json."""
    parser = ArgumentParser()
    # args specified by the user: (all other will be loaded from the model)
    add_base_options(parser)
    add_sampling_options(parser)
    add_edit_options(parser)
    return parse_and_load_from_model(parser)
313
+
314
+
315
def evaluation_parser():
    """Parse CLI args for evaluation; model-owned options are restored from
    the checkpoint's args.json."""
    parser = ArgumentParser()
    # args specified by the user: (all other will be loaded from the model)
    add_base_options(parser)
    add_evaluation_options(parser)
    return parse_and_load_from_model(parser)
utils/rotation_conversions.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is based on https://github.com/Mathux/ACTOR.git
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
3
+ # Check PYTORCH3D_LICENCE before use
4
+
5
+ import functools
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ """
13
+ The transformation matrices returned from the functions in this file assume
14
+ the points on which the transformation will be applied are column vectors.
15
+ i.e. the R matrix is structured as
16
+
17
+ R = [
18
+ [Rxx, Rxy, Rxz],
19
+ [Ryx, Ryy, Ryz],
20
+ [Rzx, Rzy, Rzz],
21
+ ] # (3, 3)
22
+
23
+ This matrix can be applied to column vectors by post multiplication
24
+ by the points e.g.
25
+
26
+ points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
27
+ transformed_points = R * points
28
+
29
+ To apply the same matrix to points which are row vectors, the R matrix
30
+ can be transposed and pre multiplied by the points:
31
+
32
+ e.g.
33
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
34
+ transformed_points = points * R.transpose(1, 0)
35
+ """
36
+
37
+
38
def quaternion_to_matrix(quaternions):
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    w, x, y, z = torch.unbind(quaternions, -1)
    # scale 2 / |q|^2 makes the formula valid for non-unit quaternions too
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    entries = (
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    )
    return torch.stack(entries, -1).reshape(quaternions.shape[:-1] + (3, 3))
67
+
68
+
69
+ def _copysign(a, b):
70
+ """
71
+ Return a tensor where each element has the absolute value taken from the,
72
+ corresponding element of a, with sign taken from the corresponding
73
+ element of b. This is like the standard copysign floating-point operation,
74
+ but is not careful about negative 0 and NaN.
75
+
76
+ Args:
77
+ a: source tensor.
78
+ b: tensor whose signs will be used, of the same shape as a.
79
+
80
+ Returns:
81
+ Tensor of the same shape as a with the signs of b.
82
+ """
83
+ signs_differ = (a < 0) != (b < 0)
84
+ return torch.where(signs_differ, -a, a)
85
+
86
+
87
+ def _sqrt_positive_part(x):
88
+ """
89
+ Returns torch.sqrt(torch.max(0, x))
90
+ but with a zero subgradient where x is 0.
91
+ """
92
+ ret = torch.zeros_like(x)
93
+ positive_mask = x > 0
94
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
95
+ return ret
96
+
97
+
98
def matrix_to_quaternion(matrix):
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the trailing two dimensions are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        # BUGFIX: message previously printed a stray 'f' before the shape
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    m00 = matrix[..., 0, 0]
    m11 = matrix[..., 1, 1]
    m22 = matrix[..., 2, 2]
    # component magnitudes from the diagonal (trace combinations)
    o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
    x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
    y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
    z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
    # signs recovered from the skew-symmetric part of the matrix
    o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
    o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
    o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
    return torch.stack((o0, o1, o2, o3), -1)
121
+
122
+
123
+ def _axis_angle_rotation(axis: str, angle):
124
+ """
125
+ Return the rotation matrices for one of the rotations about an axis
126
+ of which Euler angles describe, for each value of the angle given.
127
+
128
+ Args:
129
+ axis: Axis label "X" or "Y or "Z".
130
+ angle: any shape tensor of Euler angles in radians
131
+
132
+ Returns:
133
+ Rotation matrices as tensor of shape (..., 3, 3).
134
+ """
135
+
136
+ cos = torch.cos(angle)
137
+ sin = torch.sin(angle)
138
+ one = torch.ones_like(angle)
139
+ zero = torch.zeros_like(angle)
140
+
141
+ if axis == "X":
142
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
143
+ if axis == "Y":
144
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
145
+ if axis == "Z":
146
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
147
+
148
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
149
+
150
+
151
def euler_angles_to_matrix(euler_angles, convention: str):
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    # compose the three single-axis rotations in convention order
    result = None
    for axis, angle in zip(convention, torch.unbind(euler_angles, -1)):
        rot = _axis_angle_rotation(axis, angle)
        result = rot if result is None else torch.matmul(result, rot)
    return result
174
+
175
+
176
def _angle_from_tan(
    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
):
    """
    Extract the first or third Euler angle from the two members of
    the matrix which are positive constant times its sine and cosine.

    Args:
        axis: Axis label "X" or "Y or "Z" for the angle we are finding.
        other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
            convention.
        data: Rotation matrices as tensor of shape (..., 3, 3).
        horizontal: Whether we are looking for the angle for the third axis,
            which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention differ.

    Returns:
        Euler Angles in radians for each matrix in dataset as a tensor
        of shape (...).
    """

    # indices of the two matrix entries holding +/- sin and cos for this axis
    i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        # row case: the roles of the two entries are swapped
        i2, i1 = i1, i2
    # "even" = (axis, other_axis) is a cyclic pair (XY, YZ, ZX); its parity
    # decides which of the two sign conventions below applies
    even = (axis + other_axis) in ["XY", "YZ", "ZX"]
    if horizontal == even:
        return torch.atan2(data[..., i1], data[..., i2])
    if tait_bryan:
        return torch.atan2(-data[..., i2], data[..., i1])
    return torch.atan2(data[..., i2], -data[..., i1])
207
+
208
+
209
+ def _index_from_letter(letter: str):
210
+ if letter == "X":
211
+ return 0
212
+ if letter == "Y":
213
+ return 1
214
+ if letter == "Z":
215
+ return 2
216
+
217
+
218
def matrix_to_euler_angles(matrix, convention: str):
    """
    Convert rotations given as rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: if the convention string or matrix shape is invalid.
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        # BUGFIX: message previously printed a stray 'f' before the shape
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    # Tait-Bryan conventions (distinct first/third axes) read the middle
    # angle from a sine entry; proper Euler conventions from a cosine entry.
    tait_bryan = i0 != i2
    if tait_bryan:
        central_angle = torch.asin(
            matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
        )
    else:
        central_angle = torch.acos(matrix[..., i0, i0])

    o = (
        _angle_from_tan(
            convention[0], convention[1], matrix[..., i2], False, tait_bryan
        ),
        central_angle,
        _angle_from_tan(
            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
        ),
    )
    return torch.stack(o, -1)
258
+
259
+
260
def random_quaternions(
    n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
    """
    Generate random quaternions representing rotations,
    i.e. versors with nonnegative real part.

    Args:
        n: Number of quaternions in a batch to return.
        dtype: Type to return.
        device: Desired device of returned tensor. Default:
            uses the current device for the default tensor type.
        requires_grad: Whether the resulting tensor should have the gradient
            flag set.

    Returns:
        Quaternions as tensor of shape (N, 4).
    """
    q = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
    norms = torch.sqrt((q * q).sum(1))
    # normalize, flipping sign so the real part is nonnegative
    return q / _copysign(norms, q[:, 0])[:, None]
282
+
283
+
284
def random_rotations(
    n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
    """
    Generate random rotations as 3x3 rotation matrices.

    Args:
        n: Number of rotation matrices in a batch to return.
        dtype: Type to return.
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type.
        requires_grad: Whether the resulting tensor should have the gradient
            flag set.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    # draw uniform random unit quaternions and convert them to matrices
    return quaternion_to_matrix(
        random_quaternions(n, dtype=dtype, device=device, requires_grad=requires_grad)
    )
305
+
306
+
307
def random_rotation(
    dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
    """
    Generate a single random 3x3 rotation matrix.

    Args:
        dtype: Type to return
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type
        requires_grad: Whether the resulting tensor should have the gradient
            flag set

    Returns:
        Rotation matrix as tensor of shape (3, 3).
    """
    batch = random_rotations(1, dtype, device, requires_grad)
    return batch[0]
324
+
325
+
326
def standardize_quaternion(quaternions):
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative (q and -q encode the same rotation).

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    negative_real = quaternions[..., 0:1] < 0
    return torch.where(negative_real, -quaternions, quaternions)
339
+
340
+
341
def quaternion_raw_multiply(a, b):
    """
    Multiply two quaternions (Hamilton product, real part first).
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions shape (..., 4).
    """
    w1, x1, y1, z1 = torch.unbind(a, -1)
    w2, x2, y2, z2 = torch.unbind(b, -1)
    return torch.stack(
        (
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        ),
        -1,
    )
360
+
361
+
362
def quaternion_multiply(a, b):
    """
    Multiply two quaternions representing rotations, returning the quaternion
    representing their composition, i.e. the versor with nonnegative real part.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    product = quaternion_raw_multiply(a, b)
    # canonicalize the sign so the real part is nonnegative
    return standardize_quaternion(product)
377
+
378
+
379
def quaternion_invert(quaternion):
    """
    Given a quaternion representing rotation, get the quaternion representing
    its inverse.

    Args:
        quaternion: Quaternions as tensor of shape (..., 4), with real part
            first, which must be versors (unit quaternions).

    Returns:
        The inverse, a tensor of quaternions of shape (..., 4).
    """
    # for unit quaternions the conjugate (negate the vector part) is the inverse
    conj_signs = quaternion.new_tensor([1, -1, -1, -1])
    return quaternion * conj_signs
393
+
394
+
395
def quaternion_apply(quaternion, point):
    """
    Apply the rotation given by a quaternion to a 3D point.
    Usual torch rules for broadcasting apply.

    Args:
        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
        point: Tensor of 3D points of shape (..., 3).

    Returns:
        Tensor of rotated points of shape (..., 3).

    Raises:
        ValueError: if the last dimension of point is not 3.
    """
    if point.size(-1) != 3:
        # BUGFIX: message previously printed a stray 'f' before the shape
        raise ValueError(f"Points are not in 3D, {point.shape}.")
    # embed the point as a pure quaternion (zero real part), then conjugate:
    # q * p * q^-1
    real_parts = point.new_zeros(point.shape[:-1] + (1,))
    point_as_quaternion = torch.cat((real_parts, point), -1)
    out = quaternion_raw_multiply(
        quaternion_raw_multiply(quaternion, point_as_quaternion),
        quaternion_invert(quaternion),
    )
    return out[..., 1:]
416
+
417
+
418
def axis_angle_to_matrix(axis_angle):
    """
    Convert rotations given as axis/angle to rotation matrices.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    # route through the quaternion representation
    quaternions = axis_angle_to_quaternion(axis_angle)
    return quaternion_to_matrix(quaternions)
432
+
433
+
434
def matrix_to_axis_angle(matrix):
    """
    Convert rotations given as rotation matrices to axis/angle.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
        of shape (..., 3), where the magnitude is the angle
        turned anticlockwise in radians around the vector's
        direction.
    """
    # route through the quaternion representation
    quaternions = matrix_to_quaternion(matrix)
    return quaternion_to_axis_angle(quaternions)
448
+
449
+
450
def axis_angle_to_quaternion(axis_angle):
    """
    Convert rotations given as axis/angle to quaternions.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    eps = 1e-6
    small = angles.abs() < eps
    not_small = ~small
    ratio = torch.empty_like(angles)
    ratio[not_small] = torch.sin(half_angles[not_small]) / angles[not_small]
    # Taylor expansion near zero: sin(x/2)/x ~ 1/2 - x^2/48
    ratio[small] = 0.5 - (angles[small] * angles[small]) / 48
    return torch.cat([torch.cos(half_angles), axis_angle * ratio], dim=-1)
480
+
481
+
482
def quaternion_to_axis_angle(quaternions):
    """
    Convert rotations given as quaternions to axis/angle.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
        of shape (..., 3), where the magnitude is the angle
        turned anticlockwise in radians around the vector's
        direction.
    """
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    angles = half_angles * 2
    eps = 1e-6
    small = angles.abs() < eps
    not_small = ~small
    ratio = torch.empty_like(angles)
    ratio[not_small] = torch.sin(half_angles[not_small]) / angles[not_small]
    # Taylor expansion near zero: sin(x/2)/x ~ 1/2 - x^2/48
    ratio[small] = 0.5 - (angles[small] * angles[small]) / 48
    return quaternions[..., 1:] / ratio
511
+
512
+
513
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram--Schmidt orthogonalisation per Section B of [1].

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1 = d6[..., :3]
    a2 = d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    # Gram-Schmidt: remove the b1 component from a2, then normalize
    proj = (b1 * a2).sum(-1, keepdim=True)
    b2 = F.normalize(a2 - proj * b1, dim=-1)
    # third row completes the right-handed orthonormal basis
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)
535
+
536
+
537
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """
    Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
    by dropping the last row. Note that 6D representation is not unique.

    Args:
        matrix: batch of rotation matrices of size (*, 3, 3)

    Returns:
        6D rotation representation, of size (*, 6)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    batch_dims = matrix.size()[:-2]
    # the first two rows (6 numbers) fully determine the rotation
    return matrix[..., :2, :].clone().reshape(batch_dims + (6,))
utils/sampler_util.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from copy import deepcopy
5
+ from utils.misc import wrapped_getattr
6
+ import joblib
7
+
8
+ # A wrapper model for Classifier-free guidance **SAMPLING** only
9
+ # https://arxiv.org/abs/2207.12598
10
class ClassifierFreeSampleModel(nn.Module):
    """Sampling-time wrapper implementing classifier-free guidance
    (https://arxiv.org/abs/2207.12598).

    Runs the wrapped model once conditioned and once unconditioned, and
    returns uncond + scale * (cond - uncond).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model  # model is the actual model to run

        # guidance requires the model to have been trained with condition dropout
        assert self.model.cond_mask_prob > 0, 'Cannot run a guided diffusion on a model that has not been trained with no conditions'

        # pointers to inner model
        self.rot2xyz = self.model.rot2xyz
        self.translation = self.model.translation
        self.njoints = self.model.njoints
        self.nfeats = self.model.nfeats
        self.data_rep = self.model.data_rep
        self.cond_mode = self.model.cond_mode
        self.encode_text = self.model.encode_text

    def forward(self, x, timesteps, y=None):
        """Blend conditional and unconditional predictions.

        y['scale'] holds the per-sample guidance strength; it is reshaped
        to broadcast over the model's 4-D output.
        """
        cond_mode = self.model.cond_mode
        assert cond_mode in ['text', 'action']
        y_uncond = deepcopy(y)
        y_uncond['uncond'] = True  # signals the inner model to drop the condition
        out = self.model(x, timesteps, y)
        out_uncond = self.model(x, timesteps, y_uncond)
        return out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond))

    def __getattr__(self, name, default=None):
        # this method is reached only if name is not in self.__dict__.
        # BUGFIX: forward the caller-supplied default instead of hard-coding None
        return wrapped_getattr(self, name, default=default)
39
+
40
+
41
class AutoRegressiveSampler():
    """Generates a long motion sequence by repeatedly sampling short
    segments and feeding the tail of each segment back as the prefix of
    the next (DiP-style autoregressive sampling)."""

    def __init__(self, args, sample_fn, required_frames=196):
        # sample_fn: the underlying diffusion sampling function (e.g. p_sample_loop)
        self.sample_fn = sample_fn
        self.args = args  # must provide pred_len, context_len, autoregressive_include_prefix
        self.required_frames = required_frames  # total frames to return

    def sample(self, model, shape, **kargs):
        """Run autoregressive sampling until required_frames are produced.

        shape is the full-batch sample shape whose last dim is the frame
        axis; kargs must contain model_kwargs['y'] with 'prefix' and 'text'.
        """
        bs = shape[0]
        # ceil(required_frames / pred_len) segments are needed
        n_iterations = (self.required_frames // self.args.pred_len) + int(self.required_frames % self.args.pred_len > 0)
        samples_buf = []
        cur_prefix = deepcopy(kargs['model_kwargs']['y']['prefix'])  # init with data
        dynamic_text_mode = type(kargs['model_kwargs']['y']['text'][0]) == list  # Text changes on the fly - prompt per prediction is provided as a list (instead of a single prompt)
        if self.args.autoregressive_include_prefix:
            samples_buf.append(cur_prefix)
        # each iteration predicts pred_len frames (only the last dim changes)
        autoregressive_shape = list(deepcopy(shape))
        autoregressive_shape[-1] = self.args.pred_len

        # Autoregressive sampling
        for i in range(n_iterations):

            # Build the current kargs (deep copy so iterations stay independent)
            cur_kargs = deepcopy(kargs)
            cur_kargs['model_kwargs']['y']['prefix'] = cur_prefix
            if dynamic_text_mode:
                # pick the i-th prompt for every batch element
                cur_kargs['model_kwargs']['y']['text'] = [s[i] for s in kargs['model_kwargs']['y']['text']]
                if model.text_encoder_type == 'bert':
                    # slice the precomputed per-iteration BERT embeddings
                    # NOTE(review): assumes text_embed is (embeds[:, :, i], mask[:, i]) — confirm layout against caller
                    cur_kargs['model_kwargs']['y']['text_embed'] = (cur_kargs['model_kwargs']['y']['text_embed'][0][:, :, i], cur_kargs['model_kwargs']['y']['text_embed'][1][:, i])
                else:
                    raise NotImplementedError('DiP model only supports BERT text encoder at the moment. If you implement this, please send a PR!')

            # Sample the next prediction
            sample = self.sample_fn(model, autoregressive_shape, **cur_kargs)

            # Buffer the sample
            samples_buf.append(sample.clone()[..., -self.args.pred_len:])

            # Update the prefix (last context_len frames condition the next segment)
            cur_prefix = sample.clone()[..., -self.args.context_len:]

        # concatenate along the frame axis and trim any overshoot
        full_batch = torch.cat(samples_buf, dim=-1)[..., :self.required_frames]  # 200 -> 196
        return full_batch