Upload 10 files

- utils/PYTORCH3D_LICENSE +30 -0
- utils/config.py +17 -0
- utils/dist_util.py +77 -0
- utils/fixseed.py +18 -0
- utils/loss_util.py +46 -0
- utils/misc.py +74 -0
- utils/model_util.py +132 -0
- utils/parser_util.py +320 -0
- utils/rotation_conversions.py +552 -0
- utils/sampler_util.py +81 -0
utils/PYTORCH3D_LICENSE
ADDED
@@ -0,0 +1,30 @@
+BSD License
+
+For PyTorch3D software
+
+Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name Facebook nor the names of its contributors may be used to
+  endorse or promote products derived from this software without specific
+  prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
utils/config.py
ADDED
@@ -0,0 +1,17 @@
+import os
+
+SMPL_DATA_PATH = "./body_models/smpl"
+
+SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl")
+SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl")
+JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy')
+
+ROT_CONVENTION_TO_ROT_NUMBER = {
+    'legacy': 23,
+    'no_hands': 21,
+    'full_hands': 51,
+    'mitten_hands': 33,
+}
+
+GENDERS = ['neutral', 'male', 'female']
+NUM_BETAS = 10
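These constants are consumed wherever the repo builds SMPL body models. A minimal usage sketch (assuming the body_models/smpl assets exist on disk; the pose-size arithmetic is standard SMPL axis-angle layout, not from this diff):

    import os
    from utils.config import SMPL_MODEL_PATH, ROT_CONVENTION_TO_ROT_NUMBER, NUM_BETAS

    n_joint_rots = ROT_CONVENTION_TO_ROT_NUMBER['no_hands']  # 21 body joints, hands excluded
    pose_size = 3 * (n_joint_rots + 1)                       # +1 for the global root orientation
    print(f"SMPL model found: {os.path.exists(SMPL_MODEL_PATH)}")
    print(f"axis-angle pose size: {pose_size}, shape betas: {NUM_BETAS}")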
utils/dist_util.py
ADDED
@@ -0,0 +1,77 @@
+"""
+Helpers for distributed training.
+"""
+
+import socket
+
+import torch as th
+import torch.distributed as dist
+
+# Change this to reflect your cluster layout.
+# The GPU for a given rank is (rank % GPUS_PER_NODE).
+GPUS_PER_NODE = 8
+
+SETUP_RETRY_COUNT = 3
+
+used_device = 0
+
+def setup_dist(device=0):
+    """
+    Setup a distributed process group.
+    """
+    global used_device
+    used_device = device
+    if dist.is_initialized():
+        return
+    # os.environ["CUDA_VISIBLE_DEVICES"] = str(device)  # f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
+
+    # comm = MPI.COMM_WORLD
+    # backend = "gloo" if not th.cuda.is_available() else "nccl"
+
+    # if backend == "gloo":
+    #     hostname = "localhost"
+    # else:
+    #     hostname = socket.gethostbyname(socket.getfqdn())
+    # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
+    # os.environ["RANK"] = str(comm.rank)
+    # os.environ["WORLD_SIZE"] = str(comm.size)
+
+    # port = comm.bcast(_find_free_port(), root=used_device)
+    # os.environ["MASTER_PORT"] = str(port)
+    # dist.init_process_group(backend=backend, init_method="env://")
+
+
+def dev():
+    """
+    Get the device to use for torch.distributed.
+    """
+    global used_device
+    if th.cuda.is_available() and used_device >= 0:
+        return th.device(f"cuda:{used_device}")
+    return th.device("cpu")
+
+
+def load_state_dict(path, **kwargs):
+    """
+    Load a PyTorch file without redundant fetches across MPI ranks.
+    """
+    return th.load(path, **kwargs)
+
+
+def sync_params(params):
+    """
+    Synchronize a sequence of Tensors across ranks from rank 0.
+    """
+    for p in params:
+        with th.no_grad():
+            dist.broadcast(p, 0)
+
+
+def _find_free_port():
+    try:
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.bind(("", 0))
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        return s.getsockname()[1]
+    finally:
+        s.close()
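With the MPI machinery commented out, these helpers reduce to device bookkeeping. A single-process usage sketch (the checkpoint path in the last line is illustrative):

    from utils import dist_util

    dist_util.setup_dist(device=0)  # records the device; no process group is created in this trimmed version
    device = dist_util.dev()        # cuda:0 if CUDA is available, otherwise cpu
    print(device)
    # state = dist_util.load_state_dict('save/model000600000.pt', map_location='cpu')  # thin th.load wrapper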
utils/fixseed.py
ADDED
@@ -0,0 +1,18 @@
+import numpy as np
+import torch
+import random
+
+
+def fixseed(seed):
+    torch.backends.cudnn.benchmark = False
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+
+# SEED = 10
+# EVALSEED = 0
+# # Warning: not fully functional yet
+# # torch.set_deterministic(True)
+# torch.backends.cudnn.benchmark = False
+# fixseed(SEED)
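Usage is a one-liner at the top of a script, before models or data loaders are built; note it also disables cudnn benchmarking, trading some speed for repeatability:

    from utils.fixseed import fixseed

    fixseed(10)  # 10 matches the --seed default in utils/parser_util.py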
utils/loss_util.py
ADDED
@@ -0,0 +1,46 @@
+from diffusion.nn import mean_flat, sum_flat
+import torch
+import numpy as np
+
+def angle_l2(angle1, angle2):
+    a = angle1 - angle2
+    a = (a + (torch.pi/2)) % torch.pi - (torch.pi/2)
+    return a ** 2
+
+def diff_l2(a, b):
+    return (a - b) ** 2
+
+def masked_l2(a, b, mask, loss_fn=diff_l2, epsilon=1e-8, entries_norm=True):
+    # assuming a.shape == b.shape == bs, J, Jdim, seqlen
+    # assuming mask.shape == bs, 1, 1, seqlen
+    loss = loss_fn(a, b)
+    loss = sum_flat(loss * mask.float())  # gives \sigma_euclidean over unmasked elements
+    n_entries = a.shape[1]
+    if len(a.shape) > 3:
+        n_entries *= a.shape[2]
+    non_zero_elements = sum_flat(mask)
+    if entries_norm:
+        # In case the mask is per frame and does not specify the number of entries
+        # per frame, this normalization is needed; otherwise set it to False.
+        non_zero_elements *= n_entries
+    # print('mask', mask.shape)
+    # print('non_zero_elements', non_zero_elements)
+    # print('loss', loss)
+    mse_loss_val = loss / (non_zero_elements + epsilon)  # add epsilon to avoid division by zero
+    # print('mse_loss_val', mse_loss_val)
+    return mse_loss_val
+
+
+def masked_goal_l2(pred_goal, ref_goal, cond, all_goal_joint_names):
+    all_goal_joint_names_w_traj = np.append(all_goal_joint_names, 'traj')
+    target_joint_idx = [[np.where(all_goal_joint_names_w_traj == j)[0][0] for j in sample_joints] for sample_joints in cond['target_joint_names']]
+    loc_mask = torch.zeros_like(pred_goal[:, :-1], dtype=torch.bool)
+    for sample_idx in range(loc_mask.shape[0]):
+        loc_mask[sample_idx, target_joint_idx[sample_idx]] = True
+    loc_mask[:, -1, 1] = False  # vertical joint of 'traj' is always masked out
+    loc_loss = masked_l2(pred_goal[:, :-1], ref_goal[:, :-1], loc_mask, entries_norm=False)
+
+    heading_loss = masked_l2(pred_goal[:, -1:, :1], ref_goal[:, -1:, :1], cond['is_heading'].unsqueeze(1).unsqueeze(1), loss_fn=angle_l2, entries_norm=False)
+
+    loss = loc_loss + heading_loss
+    return loss
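A toy check of masked_l2, a minimal sketch assuming the repo's diffusion.nn (which provides sum_flat) is importable: a [bs, J, Jdim, seqlen] signal with the last frame masked out.

    import torch
    from utils.loss_util import masked_l2

    a = torch.ones(1, 2, 3, 4)    # bs=1, J=2, Jdim=3, seqlen=4
    b = torch.zeros(1, 2, 3, 4)
    mask = torch.ones(1, 1, 1, 4)
    mask[..., -1] = 0             # drop the last frame from the loss
    loss = masked_l2(a, b, mask)  # per-sample mean squared error over unmasked entries
    print(loss)                   # ~tensor([1.]) since every unmasked squared diff is 1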
utils/misc.py
ADDED
@@ -0,0 +1,74 @@
+import torch
+import torch.nn as nn
+
+
+class WeightedSum(nn.Module):
+    def __init__(self, num_rows):
+        super(WeightedSum, self).__init__()
+        # Initialize learnable weights
+        self.weights = nn.Parameter(torch.randn(num_rows))
+
+    def forward(self, x):
+        # Ensure weights are normalized (optional)
+        normalized_weights = self.weights / self.weights.sum()  # torch.softmax(self.weights, dim=0)
+        # Compute the weighted sum of the rows
+        weighted_sum = torch.matmul(normalized_weights, x)
+        return weighted_sum
+
+
+def wrapped_getattr(self, name, default=None, wrapped_member_name='model'):
+    ''' should be called from wrappers of model classes such as ClassifierFreeSampleModel '''
+
+    if isinstance(self, torch.nn.Module):
+        # for descendants of nn.Module, name may be in self.__dict__[_parameters/_buffers/_modules],
+        # so we activate nn.Module.__getattr__ first.
+        # Otherwise, we might encounter an infinite loop.
+        try:
+            attr = torch.nn.Module.__getattr__(self, name)
+        except AttributeError:
+            wrapped_member = torch.nn.Module.__getattr__(self, wrapped_member_name)
+            attr = getattr(wrapped_member, name, default)
+    else:
+        # the easy case, where self is not derived from nn.Module
+        wrapped_member = getattr(self, wrapped_member_name)
+        attr = getattr(wrapped_member, name, default)
+    return attr
+
+
+def to_numpy(tensor):
+    if torch.is_tensor(tensor):
+        return tensor.cpu().numpy()
+    elif type(tensor).__module__ != 'numpy':
+        raise ValueError("Cannot convert {} to numpy array".format(
+            type(tensor)))
+    return tensor
+
+
+def to_torch(ndarray):
+    if type(ndarray).__module__ == 'numpy':
+        return torch.from_numpy(ndarray)
+    elif not torch.is_tensor(ndarray):
+        raise ValueError("Cannot convert {} to torch tensor".format(
+            type(ndarray)))
+    return ndarray
+
+
+def cleanexit():
+    import sys
+    import os
+    try:
+        sys.exit(0)
+    except SystemExit:
+        os._exit(0)
+
+def load_model_wo_clip(model, state_dict):
+    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+    assert len(unexpected_keys) == 0
+    assert all([k.startswith('clip_model.') for k in missing_keys])
+
+def freeze_joints(x, joints_to_freeze):
+    # Freezes selected joint *rotations* as they appear in the first frame
+    # x [bs, [root+n_joints], joint_dim(6), seqlen]
+    frozen = x.detach().clone()
+    frozen[:, joints_to_freeze, :, :] = frozen[:, joints_to_freeze, :, :1]
+    return frozen
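A sketch of how wrapped_getattr is meant to be used: a thin nn.Module wrapper (the class and attribute names below are illustrative, not from the diff) that transparently falls through to the wrapped model.

    import torch.nn as nn
    from utils.misc import wrapped_getattr

    class Wrapper(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def __getattr__(self, name):
            # Resolve via nn.Module first, then fall through to self.model
            return wrapped_getattr(self, name, wrapped_member_name='model')

    inner = nn.Linear(4, 4)
    inner.some_flag = True
    print(Wrapper(inner).some_flag)  # True -- resolved through the wrapped model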
utils/model_util.py
ADDED
@@ -0,0 +1,132 @@
+import torch
+from model.mdm import MDM
+from diffusion import gaussian_diffusion as gd
+from diffusion.respace import SpacedDiffusion, space_timesteps
+from utils.parser_util import get_cond_mode
+from data_loaders.humanml_utils import HML_EE_JOINT_NAMES
+
+def load_model_wo_clip(model, state_dict):
+    # assert (state_dict['sequence_pos_encoder.pe'][:model.sequence_pos_encoder.pe.shape[0]] == model.sequence_pos_encoder.pe).all()  # TEST
+    # assert (state_dict['embed_timestep.sequence_pos_encoder.pe'][:model.embed_timestep.sequence_pos_encoder.pe.shape[0]] == model.embed_timestep.sequence_pos_encoder.pe).all()  # TEST
+    del state_dict['sequence_pos_encoder.pe']  # no need to load it (fixed), and causes size mismatch for older models
+    del state_dict['embed_timestep.sequence_pos_encoder.pe']  # no need to load it (fixed), and causes size mismatch for older models
+    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+    assert len(unexpected_keys) == 0
+    assert all([k.startswith('clip_model.') or 'sequence_pos_encoder' in k for k in missing_keys])
+
+
+def create_model_and_diffusion(args, data):
+    model = MDM(**get_model_args(args, data))
+    diffusion = create_gaussian_diffusion(args)
+    return model, diffusion
+
+
+def get_model_args(args, data):
+
+    # default args
+    clip_version = 'ViT-B/32'
+    action_emb = 'tensor'
+    cond_mode = get_cond_mode(args)
+    if hasattr(data.dataset, 'num_actions'):
+        num_actions = data.dataset.num_actions
+    else:
+        num_actions = 1
+
+    # SMPL defaults
+    data_rep = 'rot6d'
+    njoints = 25
+    nfeats = 6
+    all_goal_joint_names = []
+
+    if args.dataset == 'humanml':
+        data_rep = 'hml_vec'
+        njoints = 263
+        nfeats = 1
+        all_goal_joint_names = ['pelvis'] + HML_EE_JOINT_NAMES
+    elif args.dataset == 'kit':
+        data_rep = 'hml_vec'
+        njoints = 251
+        nfeats = 1
+
+    # Compatibility with old models
+    if not hasattr(args, 'pred_len'):
+        args.pred_len = 0
+        args.context_len = 0
+
+    emb_policy = args.__dict__.get('emb_policy', 'add')
+    multi_target_cond = args.__dict__.get('multi_target_cond', False)
+    multi_encoder_type = args.__dict__.get('multi_encoder_type', 'multi')
+    target_enc_layers = args.__dict__.get('target_enc_layers', 1)
+
+    return {'modeltype': '', 'njoints': njoints, 'nfeats': nfeats, 'num_actions': num_actions,
+            'translation': True, 'pose_rep': 'rot6d', 'glob': True, 'glob_rot': True,
+            'latent_dim': args.latent_dim, 'ff_size': 1024, 'num_layers': args.layers, 'num_heads': 4,
+            'dropout': 0.1, 'activation': "gelu", 'data_rep': data_rep, 'cond_mode': cond_mode,
+            'cond_mask_prob': args.cond_mask_prob, 'action_emb': action_emb, 'arch': args.arch,
+            'emb_trans_dec': args.emb_trans_dec, 'clip_version': clip_version, 'dataset': args.dataset,
+            'text_encoder_type': args.text_encoder_type,
+            'pos_embed_max_len': args.pos_embed_max_len, 'mask_frames': args.mask_frames,
+            'pred_len': args.pred_len, 'context_len': args.context_len, 'emb_policy': emb_policy,
+            'all_goal_joint_names': all_goal_joint_names, 'multi_target_cond': multi_target_cond, 'multi_encoder_type': multi_encoder_type, 'target_enc_layers': target_enc_layers,
+            }
+
+
+def create_gaussian_diffusion(args):
+    # default params
+    predict_xstart = True  # we always predict x_start (a.k.a. x0), that's our deal!
+    steps = args.diffusion_steps
+    scale_beta = 1.  # no scaling
+    timestep_respacing = ''  # can be used for ddim sampling, we don't use it.
+    learn_sigma = False
+    rescale_timesteps = False
+
+    betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta)
+    loss_type = gd.LossType.MSE
+
+    if not timestep_respacing:
+        timestep_respacing = [steps]
+
+    if hasattr(args, 'lambda_target_loc'):
+        lambda_target_loc = args.lambda_target_loc
+    else:
+        lambda_target_loc = 0.
+
+    return SpacedDiffusion(
+        use_timesteps=space_timesteps(steps, timestep_respacing),
+        betas=betas,
+        model_mean_type=(
+            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
+        ),
+        model_var_type=(
+            (
+                gd.ModelVarType.FIXED_LARGE
+                if not args.sigma_small
+                else gd.ModelVarType.FIXED_SMALL
+            )
+            if not learn_sigma
+            else gd.ModelVarType.LEARNED_RANGE
+        ),
+        loss_type=loss_type,
+        rescale_timesteps=rescale_timesteps,
+        lambda_vel=args.lambda_vel,
+        lambda_rcxyz=args.lambda_rcxyz,
+        lambda_fc=args.lambda_fc,
+        lambda_target_loc=lambda_target_loc,
+    )
+
+def load_saved_model(model, model_path, use_avg: bool = False):  # use_avg_model
+    state_dict = torch.load(model_path, map_location='cpu')
+    # Use average model when possible
+    if use_avg and 'model_avg' in state_dict.keys():
+        # if use_avg_model:
+        print('loading avg model')
+        state_dict = state_dict['model_avg']
+    else:
+        if 'model' in state_dict:
+            print('loading model without avg')
+            state_dict = state_dict['model']
+        else:
+            print('checkpoint has no avg model, loading as usual.')
+    load_model_wo_clip(model, state_dict)
+    return model
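create_gaussian_diffusion only reads a handful of fields from args, so it can be invoked standalone. A minimal sketch (the SimpleNamespace lists exactly those fields, with values matching the parser defaults above):

    from types import SimpleNamespace
    from utils.model_util import create_gaussian_diffusion

    args = SimpleNamespace(
        diffusion_steps=1000, noise_schedule='cosine', sigma_small=True,
        lambda_vel=0.0, lambda_rcxyz=0.0, lambda_fc=0.0, lambda_target_loc=0.0,
    )
    diffusion = create_gaussian_diffusion(args)  # SpacedDiffusion predicting x_start over all 1000 steps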
utils/parser_util.py
ADDED
@@ -0,0 +1,320 @@
+from argparse import ArgumentParser
+import argparse
+import os
+import json
+
+
+def parse_and_load_from_model(parser):
+    # args according to the loaded model
+    # do not try to specify them from cmd line since they will be overwritten
+    add_data_options(parser)
+    add_model_options(parser)
+    add_diffusion_options(parser)
+    args = parser.parse_args()
+    args_to_overwrite = []
+    for group_name in ['dataset', 'model', 'diffusion']:
+        args_to_overwrite += get_args_per_group_name(parser, args, group_name)
+
+    # load args from model
+    if args.model_path != '':  # if not using external results file
+        args = load_args_from_model(args, args_to_overwrite)
+
+    if args.cond_mask_prob == 0:
+        args.guidance_param = 1
+
+    return apply_rules(args)
+
+def load_args_from_model(args, args_to_overwrite):
+    model_path = get_model_path_from_args()
+    args_path = os.path.join(os.path.dirname(model_path), 'args.json')
+    assert os.path.exists(args_path), 'Arguments json file was not found!'
+    with open(args_path, 'r') as fr:
+        model_args = json.load(fr)
+
+    for a in args_to_overwrite:
+        if a in model_args.keys():
+            setattr(args, a, model_args[a])
+
+        elif 'cond_mode' in model_args:  # backward compatibility
+            unconstrained = (model_args['cond_mode'] == 'no_cond')
+            setattr(args, 'unconstrained', unconstrained)
+
+        else:
+            print('Warning: was not able to load [{}], using default value [{}] instead.'.format(a, args.__dict__[a]))
+    return args
+
+def apply_rules(args):
+    # For prefix completion
+    if args.pred_len == 0:
+        args.pred_len = args.context_len
+
+    # For target conditioning
+    if args.lambda_target_loc > 0.:
+        args.multi_target_cond = True
+    return args
+
+
+def get_args_per_group_name(parser, args, group_name):
+    for group in parser._action_groups:
+        if group.title == group_name:
+            group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
+            return list(argparse.Namespace(**group_dict).__dict__.keys())
+    return ValueError('group_name was not found.')
+
+def get_model_path_from_args():
+    try:
+        dummy_parser = ArgumentParser()
+        dummy_parser.add_argument('--model_path')
+        dummy_args, _ = dummy_parser.parse_known_args()
+        return dummy_args.model_path
+    except:
+        raise ValueError('model_path argument must be specified.')
+
+
+def add_base_options(parser):
+    group = parser.add_argument_group('base')
+    group.add_argument("--cuda", default=True, type=bool, help="Use cuda device, otherwise use CPU.")
+    group.add_argument("--device", default=0, type=int, help="Device id to use.")
+    group.add_argument("--seed", default=10, type=int, help="For fixing random seed.")
+    group.add_argument("--batch_size", default=64, type=int, help="Batch size during training.")
+    group.add_argument("--train_platform_type", default='NoPlatform', choices=['NoPlatform', 'ClearmlPlatform', 'TensorboardPlatform', 'WandBPlatform'], type=str,
+                       help="Choose platform to log results. NoPlatform means no logging.")
+    group.add_argument("--external_mode", default=False, type=bool, help="For backward compatibility, do not change or delete.")
+
+
+def add_diffusion_options(parser):
+    group = parser.add_argument_group('diffusion')
+    group.add_argument("--noise_schedule", default='cosine', choices=['linear', 'cosine'], type=str,
+                       help="Noise schedule type")
+    group.add_argument("--diffusion_steps", default=1000, type=int,
+                       help="Number of diffusion steps (denoted T in the paper)")
+    group.add_argument("--sigma_small", default=True, type=bool, help="Use smaller sigma values.")
+
+
+def add_model_options(parser):
+    group = parser.add_argument_group('model')
+    group.add_argument("--arch", default='trans_enc',
+                       choices=['trans_enc', 'trans_dec', 'gru'], type=str,
+                       help="Architecture types as reported in the paper.")
+    group.add_argument("--text_encoder_type", default='clip',
+                       choices=['clip', 'bert'], type=str, help="Text encoder type.")
+    group.add_argument("--emb_trans_dec", action='store_true',
+                       help="For trans_dec architecture only, if true, will inject condition as a class token"
+                            " (in addition to cross-attention).")
+    group.add_argument("--layers", default=8, type=int,
+                       help="Number of layers.")
+    group.add_argument("--latent_dim", default=512, type=int,
+                       help="Transformer/GRU width.")
+    group.add_argument("--cond_mask_prob", default=.1, type=float,
+                       help="The probability of masking the condition during training."
+                            " For classifier-free guidance learning.")
+    group.add_argument("--mask_frames", action='store_true', help="If true, will fix Rotem's bug and mask invalid frames.")
+    group.add_argument("--lambda_rcxyz", default=0.0, type=float, help="Joint positions loss.")
+    group.add_argument("--lambda_vel", default=0.0, type=float, help="Joint velocity loss.")
+    group.add_argument("--lambda_fc", default=0.0, type=float, help="Foot contact loss.")
+    group.add_argument("--lambda_target_loc", default=0.0, type=float, help="For HumanML only. L2 loss with the target location.")
+    group.add_argument("--unconstrained", action='store_true',
+                       help="Model is trained unconditionally. That is, it is constrained by neither text nor action. "
+                            "Currently tested on HumanAct12 only.")
+    group.add_argument("--pos_embed_max_len", default=5000, type=int,
+                       help="Pose embedding max length.")
+    group.add_argument("--use_ema", action='store_true',
+                       help="If True, will use EMA model averaging.")
+
+    group.add_argument("--multi_target_cond", action='store_true', help="If true, enable multi-target conditioning (aka Sigal's model).")
+    group.add_argument("--multi_encoder_type", default='single', choices=['single', 'multi', 'split'], type=str, help="Specifies the encoder type to be used for the multi joint condition.")
+    group.add_argument("--target_enc_layers", default=1, type=int, help="Num target encoder layers")
+
+    # Prefix completion model
+    group.add_argument("--context_len", default=0, type=int, help="If larger than 0, will do prefix completion.")
+    group.add_argument("--pred_len", default=0, type=int, help="If context_len is larger than 0, will do prefix completion. If pred_len is not specified, will use the same length as context_len.")
+
+
+def add_data_options(parser):
+    group = parser.add_argument_group('dataset')
+    group.add_argument("--dataset", default='humanml', choices=['humanml', 'kit', 'humanact12', 'uestc'], type=str,
+                       help="Dataset name (choose from list).")
+    group.add_argument("--data_dir", default="", type=str,
+                       help="If empty, will use defaults according to the specified dataset.")
+
+
+def add_training_options(parser):
+    group = parser.add_argument_group('training')
+    group.add_argument("--save_dir", required=True, type=str,
+                       help="Path to save checkpoints and results.")
+    group.add_argument("--overwrite", action='store_true',
+                       help="If True, will enable using an already existing save_dir.")
+    group.add_argument("--lr", default=1e-4, type=float, help="Learning rate.")
+    group.add_argument("--weight_decay", default=0.0, type=float, help="Optimizer weight decay.")
+    group.add_argument("--lr_anneal_steps", default=0, type=int, help="Number of learning rate anneal steps.")
+    group.add_argument("--eval_batch_size", default=32, type=int,
+                       help="Batch size during evaluation loop. Do not change this unless you know what you are doing. "
+                            "T2m precision calculation is based on fixed batch size 32.")
+    group.add_argument("--eval_split", default='test', choices=['val', 'test'], type=str,
+                       help="Which split to evaluate on during training.")
+    group.add_argument("--eval_during_training", action='store_true',
+                       help="If True, will run evaluation during training.")
+    group.add_argument("--eval_rep_times", default=3, type=int,
+                       help="Number of repetitions for evaluation loop during training.")
+    group.add_argument("--eval_num_samples", default=1_000, type=int,
+                       help="If -1, will use all samples in the specified split.")
+    group.add_argument("--log_interval", default=1_000, type=int,
+                       help="Log losses each N steps")
+    group.add_argument("--save_interval", default=50_000, type=int,
+                       help="Save checkpoints and run evaluation each N steps")
+    group.add_argument("--num_steps", default=600_000, type=int,
+                       help="Training will stop after the specified number of steps.")
+    group.add_argument("--num_frames", default=60, type=int,
+                       help="Limit for the maximal number of frames. In HumanML3D and KIT this field is ignored.")
+    group.add_argument("--resume_checkpoint", default="", type=str,
+                       help="If not empty, will start from the specified checkpoint (path to model###.pt file).")
+
+    group.add_argument("--gen_during_training", action='store_true',
+                       help="If True, will generate motions during training, on each save interval.")
+    group.add_argument("--gen_num_samples", default=3, type=int,
+                       help="Number of samples to sample while generating")
+    group.add_argument("--gen_num_repetitions", default=2, type=int,
+                       help="Number of repetitions, per sample (text prompt/action)")
+    group.add_argument("--gen_guidance_param", default=2.5, type=float,
+                       help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")
+
+    group.add_argument("--avg_model_beta", default=0.9999, type=float, help="Average model beta (for EMA).")
+    group.add_argument("--adam_beta2", default=0.999, type=float, help="Adam beta2.")
+
+    group.add_argument("--target_joint_names", default='DIMP_FINAL', type=str, help="Force a single joint configuration by specifying the joints (comma separated). If None, will use the random mode for all end effectors.")
+    group.add_argument("--autoregressive", action='store_true', help="If true, and we use a prefix model, will generate motions in an autoregressive loop.")
+    group.add_argument("--autoregressive_include_prefix", action='store_true', help="If true, include the init prefix in the output; otherwise, will drop it.")
+    group.add_argument("--autoregressive_init", default='data', type=str, choices=['data', 'isaac'],
+                       help="Sets the source of the init frames, either from the dataset or isaac init poses.")
+
+
+def add_sampling_options(parser):
+    group = parser.add_argument_group('sampling')
+    group.add_argument("--model_path", required=True, type=str,
+                       help="Path to model####.pt file to be sampled.")
+    group.add_argument("--output_dir", default='', type=str,
+                       help="Path to results dir (auto created by the script). "
+                            "If empty, will create dir in parallel to checkpoint.")
+    group.add_argument("--num_samples", default=6, type=int,
+                       help="Maximal number of prompts to sample, "
+                            "if loading dataset from file, this field will be ignored.")
+    group.add_argument("--num_repetitions", default=3, type=int,
+                       help="Number of repetitions, per sample (text prompt/action)")
+    group.add_argument("--guidance_param", default=2.5, type=float,
+                       help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")
+
+    group.add_argument("--autoregressive", action='store_true', help="If true, and we use a prefix model, will generate motions in an autoregressive loop.")
+    group.add_argument("--autoregressive_include_prefix", action='store_true', help="If true, include the init prefix in the output; otherwise, will drop it.")
+    group.add_argument("--autoregressive_init", default='data', type=str, choices=['data', 'isaac'],
+                       help="Sets the source of the init frames, either from the dataset or isaac init poses.")
+
+def add_generate_options(parser):
+    group = parser.add_argument_group('generate')
+    group.add_argument("--motion_length", default=6.0, type=float,
+                       help="The length of the sampled motion [in seconds]. "
+                            "Maximum is 9.8 for HumanML3D (text-to-motion), and 2.0 for HumanAct12 (action-to-motion)")
+    group.add_argument("--input_text", default='', type=str,
+                       help="Path to a text file that lists text prompts to be synthesized. If empty, will take text prompts from dataset.")
+    group.add_argument("--dynamic_text_path", default='', type=str,
+                       help="For the autoregressive mode only! Path to a text file that lists text prompts to be synthesized. If empty, will take text prompts from dataset.")
+    group.add_argument("--action_file", default='', type=str,
+                       help="Path to a text file that lists names of actions to be synthesized. Names must be a subset of dataset/uestc/info/action_classes.txt if sampling from uestc, "
+                            "or a subset of [warm_up,walk,run,jump,drink,lift_dumbbell,sit,eat,turn steering wheel,phone,boxing,throw] if sampling from humanact12. "
+                            "If no file is specified, will take action names from dataset.")
+    group.add_argument("--text_prompt", default='', type=str,
+                       help="A text prompt to be generated. If empty, will take text prompts from dataset.")
+    group.add_argument("--action_name", default='', type=str,
+                       help="An action name to be generated. If empty, will take text prompts from dataset.")
+    group.add_argument("--target_joint_names", default='DIMP_FINAL', type=str, help="Force a single joint configuration by specifying the joints (comma separated). If None, will use the random mode for all end effectors.")
+
+
+def add_edit_options(parser):
+    group = parser.add_argument_group('edit')
+    group.add_argument("--edit_mode", default='in_between', choices=['in_between', 'upper_body'], type=str,
+                       help="Defines which parts of the input motion will be edited.\n"
+                            "(1) in_between - suffix and prefix motion taken from input motion, "
+                            "middle motion is generated.\n"
+                            "(2) upper_body - lower body joints taken from input motion, "
+                            "upper body is generated.")
+    group.add_argument("--text_condition", default='', type=str,
+                       help="Editing will be conditioned on this text prompt. "
+                            "If empty, will perform unconditioned editing.")
+    group.add_argument("--prefix_end", default=0.25, type=float,
+                       help="For in_between editing - Defines the end of input prefix (ratio from all frames).")
+    group.add_argument("--suffix_start", default=0.75, type=float,
+                       help="For in_between editing - Defines the start of input suffix (ratio from all frames).")
+
+
+def add_evaluation_options(parser):
+    group = parser.add_argument_group('eval')
+    group.add_argument("--model_path", required=True, type=str,
+                       help="Path to model####.pt file to be sampled.")
+    group.add_argument("--eval_mode", default='wo_mm', choices=['wo_mm', 'mm_short', 'debug', 'full'], type=str,
+                       help="wo_mm (t2m only) - 20 repetitions without multi-modality metric; "
+                            "mm_short (t2m only) - 5 repetitions with multi-modality metric; "
+                            "debug - short run, less accurate results; "
+                            "full (a2m only) - 20 repetitions.")
+    group.add_argument("--autoregressive", action='store_true', help="If true, and we use a prefix model, will generate motions in an autoregressive loop.")
+    group.add_argument("--autoregressive_include_prefix", action='store_true', help="If true, include the init prefix in the output; otherwise, will drop it.")
+    group.add_argument("--autoregressive_init", default='data', type=str, choices=['data', 'isaac'],
+                       help="Sets the source of the init frames, either from the dataset or isaac init poses.")
+    group.add_argument("--guidance_param", default=2.5, type=float,
+                       help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")
+
+
+def get_cond_mode(args):
+    if args.unconstrained:
+        cond_mode = 'no_cond'
+    elif args.dataset in ['kit', 'humanml']:
+        cond_mode = 'text'
+    else:
+        cond_mode = 'action'
+    return cond_mode
+
+
+def train_args():
+    parser = ArgumentParser()
+    add_base_options(parser)
+    add_data_options(parser)
+    add_model_options(parser)
+    add_diffusion_options(parser)
+    add_training_options(parser)
+    return apply_rules(parser.parse_args())
+
+
+def generate_args():
+    parser = ArgumentParser()
+    # args specified by the user: (all others will be loaded from the model)
+    add_base_options(parser)
+    add_sampling_options(parser)
+    add_generate_options(parser)
+    args = parse_and_load_from_model(parser)
+    cond_mode = get_cond_mode(args)
+
+    if (args.input_text or args.text_prompt) and cond_mode != 'text':
+        raise Exception('Arguments input_text and text_prompt should not be used for an action condition. Please use action_file or action_name.')
+    elif (args.action_file or args.action_name) and cond_mode != 'action':
+        raise Exception('Arguments action_file and action_name should not be used for a text condition. Please use input_text or text_prompt.')
+
+    return args
+
+
+def edit_args():
+    parser = ArgumentParser()
+    # args specified by the user: (all others will be loaded from the model)
+    add_base_options(parser)
+    add_sampling_options(parser)
+    add_edit_options(parser)
+    return parse_and_load_from_model(parser)
+
+
+def evaluation_parser():
+    parser = ArgumentParser()
+    # args specified by the user: (all others will be loaded from the model)
+    add_base_options(parser)
+    add_evaluation_options(parser)
+    return parse_and_load_from_model(parser)
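A sketch of the sampling-side flow: evaluation_parser (like generate_args and edit_args) reloads most settings from the args.json saved next to the checkpoint, so only --model_path is mandatory on the command line. The path below is illustrative, and an args.json must exist beside it for the assert in load_args_from_model to pass:

    import sys
    from utils.parser_util import evaluation_parser

    sys.argv = ['eval.py', '--model_path', './save/my_run/model000600000.pt']
    args = evaluation_parser()  # CLI args merged with the saved args.json
    print(args.eval_mode, args.guidance_param)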
utils/rotation_conversions.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This code is based on https://github.com/Mathux/ACTOR.git
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
| 3 |
+
# Check PYTORCH3D_LICENCE before use
|
| 4 |
+
|
| 5 |
+
import functools
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
The transformation matrices returned from the functions in this file assume
|
| 14 |
+
the points on which the transformation will be applied are column vectors.
|
| 15 |
+
i.e. the R matrix is structured as
|
| 16 |
+
|
| 17 |
+
R = [
|
| 18 |
+
[Rxx, Rxy, Rxz],
|
| 19 |
+
[Ryx, Ryy, Ryz],
|
| 20 |
+
[Rzx, Rzy, Rzz],
|
| 21 |
+
] # (3, 3)
|
| 22 |
+
|
| 23 |
+
This matrix can be applied to column vectors by post multiplication
|
| 24 |
+
by the points e.g.
|
| 25 |
+
|
| 26 |
+
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
|
| 27 |
+
transformed_points = R * points
|
| 28 |
+
|
| 29 |
+
To apply the same matrix to points which are row vectors, the R matrix
|
| 30 |
+
can be transposed and pre multiplied by the points:
|
| 31 |
+
|
| 32 |
+
e.g.
|
| 33 |
+
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
|
| 34 |
+
transformed_points = points * R.transpose(1, 0)
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def quaternion_to_matrix(quaternions):
|
| 39 |
+
"""
|
| 40 |
+
Convert rotations given as quaternions to rotation matrices.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
quaternions: quaternions with real part first,
|
| 44 |
+
as tensor of shape (..., 4).
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 48 |
+
"""
|
| 49 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
| 50 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
| 51 |
+
|
| 52 |
+
o = torch.stack(
|
| 53 |
+
(
|
| 54 |
+
1 - two_s * (j * j + k * k),
|
| 55 |
+
two_s * (i * j - k * r),
|
| 56 |
+
two_s * (i * k + j * r),
|
| 57 |
+
two_s * (i * j + k * r),
|
| 58 |
+
1 - two_s * (i * i + k * k),
|
| 59 |
+
two_s * (j * k - i * r),
|
| 60 |
+
two_s * (i * k - j * r),
|
| 61 |
+
two_s * (j * k + i * r),
|
| 62 |
+
1 - two_s * (i * i + j * j),
|
| 63 |
+
),
|
| 64 |
+
-1,
|
| 65 |
+
)
|
| 66 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _copysign(a, b):
|
| 70 |
+
"""
|
| 71 |
+
Return a tensor where each element has the absolute value taken from the,
|
| 72 |
+
corresponding element of a, with sign taken from the corresponding
|
| 73 |
+
element of b. This is like the standard copysign floating-point operation,
|
| 74 |
+
but is not careful about negative 0 and NaN.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
a: source tensor.
|
| 78 |
+
b: tensor whose signs will be used, of the same shape as a.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
Tensor of the same shape as a with the signs of b.
|
| 82 |
+
"""
|
| 83 |
+
signs_differ = (a < 0) != (b < 0)
|
| 84 |
+
return torch.where(signs_differ, -a, a)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _sqrt_positive_part(x):
|
| 88 |
+
"""
|
| 89 |
+
Returns torch.sqrt(torch.max(0, x))
|
| 90 |
+
but with a zero subgradient where x is 0.
|
| 91 |
+
"""
|
| 92 |
+
ret = torch.zeros_like(x)
|
| 93 |
+
positive_mask = x > 0
|
| 94 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 95 |
+
return ret
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def matrix_to_quaternion(matrix):
|
| 99 |
+
"""
|
| 100 |
+
Convert rotations given as rotation matrices to quaternions.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
| 107 |
+
"""
|
| 108 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
| 109 |
+
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
| 110 |
+
m00 = matrix[..., 0, 0]
|
| 111 |
+
m11 = matrix[..., 1, 1]
|
| 112 |
+
m22 = matrix[..., 2, 2]
|
| 113 |
+
o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
|
| 114 |
+
x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
|
| 115 |
+
y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
|
| 116 |
+
z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
|
| 117 |
+
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
|
| 118 |
+
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
|
| 119 |
+
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
|
| 120 |
+
return torch.stack((o0, o1, o2, o3), -1)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _axis_angle_rotation(axis: str, angle):
|
| 124 |
+
"""
|
| 125 |
+
Return the rotation matrices for one of the rotations about an axis
|
| 126 |
+
of which Euler angles describe, for each value of the angle given.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
axis: Axis label "X" or "Y or "Z".
|
| 130 |
+
angle: any shape tensor of Euler angles in radians
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
cos = torch.cos(angle)
|
| 137 |
+
sin = torch.sin(angle)
|
| 138 |
+
one = torch.ones_like(angle)
|
| 139 |
+
zero = torch.zeros_like(angle)
|
| 140 |
+
|
| 141 |
+
if axis == "X":
|
| 142 |
+
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
| 143 |
+
if axis == "Y":
|
| 144 |
+
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
| 145 |
+
if axis == "Z":
|
| 146 |
+
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
| 147 |
+
|
| 148 |
+
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def euler_angles_to_matrix(euler_angles, convention: str):
|
| 152 |
+
"""
|
| 153 |
+
Convert rotations given as Euler angles in radians to rotation matrices.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
euler_angles: Euler angles in radians as tensor of shape (..., 3).
|
| 157 |
+
convention: Convention string of three uppercase letters from
|
| 158 |
+
{"X", "Y", and "Z"}.
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 162 |
+
"""
|
| 163 |
+
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
|
| 164 |
+
raise ValueError("Invalid input euler angles.")
|
| 165 |
+
if len(convention) != 3:
|
| 166 |
+
raise ValueError("Convention must have 3 letters.")
|
| 167 |
+
if convention[1] in (convention[0], convention[2]):
|
| 168 |
+
raise ValueError(f"Invalid convention {convention}.")
|
| 169 |
+
for letter in convention:
|
| 170 |
+
if letter not in ("X", "Y", "Z"):
|
| 171 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
| 172 |
+
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
|
| 173 |
+
return functools.reduce(torch.matmul, matrices)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _angle_from_tan(
|
| 177 |
+
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
| 178 |
+
):
|
| 179 |
+
"""
|
| 180 |
+
Extract the first or third Euler angle from the two members of
|
| 181 |
+
the matrix which are positive constant times its sine and cosine.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
|
| 185 |
+
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
|
| 186 |
+
convention.
|
| 187 |
+
data: Rotation matrices as tensor of shape (..., 3, 3).
|
| 188 |
+
horizontal: Whether we are looking for the angle for the third axis,
|
| 189 |
+
which means the relevant entries are in the same row of the
|
| 190 |
+
rotation matrix. If not, they are in the same column.
|
| 191 |
+
tait_bryan: Whether the first and third axes in the convention differ.
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
Euler Angles in radians for each matrix in dataset as a tensor
|
| 195 |
+
of shape (...).
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
| 199 |
+
if horizontal:
|
| 200 |
+
i2, i1 = i1, i2
|
| 201 |
+
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
| 202 |
+
if horizontal == even:
|
| 203 |
+
return torch.atan2(data[..., i1], data[..., i2])
|
| 204 |
+
if tait_bryan:
|
| 205 |
+
return torch.atan2(-data[..., i2], data[..., i1])
|
| 206 |
+
return torch.atan2(data[..., i2], -data[..., i1])
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _index_from_letter(letter: str):
|
| 210 |
+
if letter == "X":
|
| 211 |
+
return 0
|
| 212 |
+
if letter == "Y":
|
| 213 |
+
return 1
|
| 214 |
+
if letter == "Z":
|
| 215 |
+
return 2
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def matrix_to_euler_angles(matrix, convention: str):
|
| 219 |
+
"""
|
| 220 |
+
Convert rotations given as rotation matrices to Euler angles in radians.
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 224 |
+
convention: Convention string of three uppercase letters.
|
| 225 |
+
|
| 226 |
+
Returns:
|
| 227 |
+
Euler angles in radians as tensor of shape (..., 3).
|
| 228 |
+
"""
|
| 229 |
+
if len(convention) != 3:
|
| 230 |
+
raise ValueError("Convention must have 3 letters.")
|
| 231 |
+
if convention[1] in (convention[0], convention[2]):
|
| 232 |
+
raise ValueError(f"Invalid convention {convention}.")
|
| 233 |
+
for letter in convention:
|
| 234 |
+
if letter not in ("X", "Y", "Z"):
|
| 235 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
| 236 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
| 237 |
+
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
| 238 |
+
i0 = _index_from_letter(convention[0])
|
| 239 |
+
i2 = _index_from_letter(convention[2])
|
| 240 |
+
tait_bryan = i0 != i2
|
| 241 |
+
if tait_bryan:
|
| 242 |
+
central_angle = torch.asin(
|
| 243 |
+
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
|
| 244 |
+
)
|
| 245 |
+
else:
|
| 246 |
+
central_angle = torch.acos(matrix[..., i0, i0])
|
| 247 |
+
|
| 248 |
+
o = (
|
| 249 |
+
_angle_from_tan(
|
| 250 |
+
convention[0], convention[1], matrix[..., i2], False, tait_bryan
|
| 251 |
+
),
|
| 252 |
+
central_angle,
|
| 253 |
+
_angle_from_tan(
|
| 254 |
+
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
|
| 255 |
+
),
|
| 256 |
+
)
|
| 257 |
+
return torch.stack(o, -1)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def random_quaternions(
|
| 261 |
+
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
| 262 |
+
):
|
| 263 |
+
"""
|
| 264 |
+
Generate random quaternions representing rotations,
|
| 265 |
+
i.e. versors with nonnegative real part.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
n: Number of quaternions in a batch to return.
|
| 269 |
+
dtype: Type to return.
|
| 270 |
+
device: Desired device of returned tensor. Default:
|
| 271 |
+
uses the current device for the default tensor type.
|
| 272 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
| 273 |
+
flag set.
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
Quaternions as tensor of shape (N, 4).
|
| 277 |
+
"""
|
| 278 |
+
o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
|
| 279 |
+
s = (o * o).sum(1)
|
| 280 |
+
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
| 281 |
+
return o
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def random_rotations(
    n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
    """
    Generate random rotations as 3x3 rotation matrices.

    Args:
        n: Number of rotation matrices in a batch to return.
        dtype: Type to return.
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type.
        requires_grad: Whether the resulting tensor should have the gradient
            flag set.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    quaternions = random_quaternions(
        n, dtype=dtype, device=device, requires_grad=requires_grad
    )
    return quaternion_to_matrix(quaternions)


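A quick property-check sketch (editorial, not part of the file): every sampled matrix should be a proper rotation, i.e. orthogonal with determinant +1.

    import torch

    R = random_rotations(8, dtype=torch.float64)
    eye = torch.eye(3, dtype=torch.float64).expand(8, 3, 3)
    assert torch.allclose(R @ R.transpose(-1, -2), eye, atol=1e-10)
    assert torch.allclose(torch.det(R), torch.ones(8, dtype=torch.float64), atol=1e-10)
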
def random_rotation(
    dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
    """
    Generate a single random 3x3 rotation matrix.

    Args:
        dtype: Type to return.
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type.
        requires_grad: Whether the resulting tensor should have the gradient
            flag set.

    Returns:
        Rotation matrix as tensor of shape (3, 3).
    """
    return random_rotations(1, dtype, device, requires_grad)[0]


def standardize_quaternion(quaternions):
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non-negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)


def quaternion_raw_multiply(a, b):
    """
    Multiply two quaternions.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    aw, ax, ay, az = torch.unbind(a, -1)
    bw, bx, by, bz = torch.unbind(b, -1)
    ow = aw * bw - ax * bx - ay * by - az * bz
    ox = aw * bx + ax * bw + ay * bz - az * by
    oy = aw * by - ax * bz + ay * bw + az * bx
    oz = aw * bz + ax * by - ay * bx + az * bw
    return torch.stack((ow, ox, oy, oz), -1)


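A consistency sketch (editorial; `quaternion_to_matrix` is defined earlier in this file): the Hamilton product composes rotations, so converting the product to a matrix should match multiplying the individual matrices.

    import torch

    qa, qb = random_quaternions(4), random_quaternions(4)
    lhs = quaternion_to_matrix(quaternion_raw_multiply(qa, qb))
    rhs = quaternion_to_matrix(qa) @ quaternion_to_matrix(qb)
    assert torch.allclose(lhs, rhs, atol=1e-5)
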
def quaternion_multiply(a, b):
    """
    Multiply two quaternions representing rotations, returning the quaternion
    representing their composition, i.e. the versor with nonnegative real part.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    ab = quaternion_raw_multiply(a, b)
    return standardize_quaternion(ab)


def quaternion_invert(quaternion):
    """
    Given a quaternion representing rotation, get the quaternion representing
    its inverse.

    Args:
        quaternion: Quaternions as tensor of shape (..., 4), with real part
            first, which must be versors (unit quaternions).

    Returns:
        The inverse, a tensor of quaternions of shape (..., 4).
    """
    return quaternion * quaternion.new_tensor([1, -1, -1, -1])


def quaternion_apply(quaternion, point):
    """
    Apply the rotation given by a quaternion to a 3D point.
    Usual torch rules for broadcasting apply.

    Args:
        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
        point: Tensor of 3D points of shape (..., 3).

    Returns:
        Tensor of rotated points of shape (..., 3).
    """
    if point.size(-1) != 3:
        raise ValueError(f"Points are not in 3D, {point.shape}.")
    real_parts = point.new_zeros(point.shape[:-1] + (1,))
    point_as_quaternion = torch.cat((real_parts, point), -1)
    out = quaternion_raw_multiply(
        quaternion_raw_multiply(quaternion, point_as_quaternion),
        quaternion_invert(quaternion),
    )
    return out[..., 1:]


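A sketch comparing the two rotation paths (editorial; `quaternion_to_matrix` is defined earlier in this file): applying a quaternion directly should agree with converting it to a matrix first.

    import torch

    q = random_quaternions(5)
    pts = torch.randn(5, 3)
    by_quat = quaternion_apply(q, pts)
    by_mat = torch.einsum("nij,nj->ni", quaternion_to_matrix(q), pts)
    assert torch.allclose(by_quat, by_mat, atol=1e-5)
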
def axis_angle_to_matrix(axis_angle):
    """
    Convert rotations given as axis/angle to rotation matrices.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))


def matrix_to_axis_angle(matrix):
    """
    Convert rotations given as rotation matrices to axis/angle.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.
    """
    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))


def axis_angle_to_quaternion(axis_angle):
    """
    Convert rotations given as axis/angle to quaternions.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        Quaternions with real part first, as tensor of shape (..., 4).
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = 0.5 * angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6,
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    quaternions = torch.cat(
        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
    )
    return quaternions


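A round-trip sketch (editorial; `quaternion_to_axis_angle` is defined just below): for rotation angles below pi the two conversions invert each other.

    import torch

    aa = torch.randn(6, 3)
    aa = 0.8 * aa / aa.norm(dim=-1, keepdim=True)    # unit axes scaled to 0.8 rad
    q = axis_angle_to_quaternion(aa)
    assert torch.allclose(quaternion_to_axis_angle(q), aa, atol=1e-5)
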
def quaternion_to_axis_angle(quaternions):
    """
    Convert rotations given as quaternions to axis/angle.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.
    """
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    angles = 2 * half_angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6,
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    return quaternions[..., 1:] / sin_half_angles_over_angles


def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram-Schmidt orthogonalisation per Section B of [1].
    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)


def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """
    Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
    by dropping the last row. Note that the 6D representation is not unique.
    Args:
        matrix: batch of rotation matrices of size (*, 3, 3)

    Returns:
        6D rotation representation, of size (*, 6)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
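A round-trip sketch (editorial): on valid rotation matrices, dropping the last row and re-orthogonalising recovers the original, since the third row of a proper rotation is the cross product of the first two.

    import torch

    R = random_rotations(3)
    R2 = rotation_6d_to_matrix(matrix_to_rotation_6d(R))
    assert torch.allclose(R2, R, atol=1e-5)
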
utils/sampler_util.py
ADDED
@@ -0,0 +1,81 @@
import numpy as np
import torch
import torch.nn as nn
from copy import deepcopy
from utils.misc import wrapped_getattr
import joblib


# A wrapper model for classifier-free guidance **SAMPLING** only
# https://arxiv.org/abs/2207.12598
class ClassifierFreeSampleModel(nn.Module):

    def __init__(self, model):
        super().__init__()
        self.model = model  # model is the actual model to run

        assert self.model.cond_mask_prob > 0, (
            'Cannot run classifier-free guidance on a model that was not '
            'trained with condition masking (cond_mask_prob > 0)'
        )

        # pointers to inner model
        self.rot2xyz = self.model.rot2xyz
        self.translation = self.model.translation
        self.njoints = self.model.njoints
        self.nfeats = self.model.nfeats
        self.data_rep = self.model.data_rep
        self.cond_mode = self.model.cond_mode
        self.encode_text = self.model.encode_text

    def forward(self, x, timesteps, y=None):
        cond_mode = self.model.cond_mode
        assert cond_mode in ['text', 'action']
        y_uncond = deepcopy(y)
        y_uncond['uncond'] = True
        out = self.model(x, timesteps, y)
        out_uncond = self.model(x, timesteps, y_uncond)
        return out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond))

    def __getattr__(self, name, default=None):
        # This method is reached only if name is not in self.__dict__.
        return wrapped_getattr(self, name, default=default)


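A usage sketch for the wrapper above (editorial; `model`, `x`, `timesteps`, and the conditioning dict `y` are assumed to come from the surrounding sampling code). With scale 1.0 the output equals the conditional prediction; larger scales extrapolate away from the unconditional one.

    guided = ClassifierFreeSampleModel(model)
    y['scale'] = torch.full((x.shape[0],), 2.5)   # per-sample guidance weight
    out = guided(x, timesteps, y=y)               # = uncond + 2.5 * (cond - uncond)
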
class AutoRegressiveSampler():
    def __init__(self, args, sample_fn, required_frames=196):
        self.sample_fn = sample_fn
        self.args = args
        self.required_frames = required_frames

    def sample(self, model, shape, **kwargs):
        bs = shape[0]
        n_iterations = (self.required_frames // self.args.pred_len) + int(self.required_frames % self.args.pred_len > 0)
        samples_buf = []
        cur_prefix = deepcopy(kwargs['model_kwargs']['y']['prefix'])  # init with data
        # Text changes on the fly - a prompt per prediction is provided as a list (instead of a single prompt)
        dynamic_text_mode = isinstance(kwargs['model_kwargs']['y']['text'][0], list)
        if self.args.autoregressive_include_prefix:
            samples_buf.append(cur_prefix)
        autoregressive_shape = list(deepcopy(shape))
        autoregressive_shape[-1] = self.args.pred_len

        # Autoregressive sampling
        for i in range(n_iterations):

            # Build the current kwargs
            cur_kwargs = deepcopy(kwargs)
            cur_kwargs['model_kwargs']['y']['prefix'] = cur_prefix
            if dynamic_text_mode:
                cur_kwargs['model_kwargs']['y']['text'] = [s[i] for s in kwargs['model_kwargs']['y']['text']]
                if model.text_encoder_type == 'bert':
                    cur_kwargs['model_kwargs']['y']['text_embed'] = (
                        cur_kwargs['model_kwargs']['y']['text_embed'][0][:, :, i],
                        cur_kwargs['model_kwargs']['y']['text_embed'][1][:, i],
                    )
                else:
                    raise NotImplementedError('DiP model only supports the BERT text encoder at the moment. If you implement this, please send a PR!')

            # Sample the next prediction
            sample = self.sample_fn(model, autoregressive_shape, **cur_kwargs)

            # Buffer the sample
            samples_buf.append(sample.clone()[..., -self.args.pred_len:])

            # Update the prefix
            cur_prefix = sample.clone()[..., -self.args.context_len:]

        full_batch = torch.cat(samples_buf, dim=-1)[..., :self.required_frames]  # e.g. 200 -> 196
        return full_batch
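A usage sketch (editorial; the callable and argument names are assumptions inferred from how `sample` reads its inputs): `sample_fn` is a diffusion sampling loop with signature `sample_fn(model, shape, **kwargs)`, `args` provides `pred_len`, `context_len`, and `autoregressive_include_prefix`, and `model_kwargs['y']` carries `prefix` and `text`.

    ar = AutoRegressiveSampler(args, sample_fn, required_frames=196)
    shape = (batch_size, model.njoints, model.nfeats, args.pred_len)
    motion = ar.sample(model, shape, model_kwargs=model_kwargs)   # (..., 196) frames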