Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2020 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: ps-license@tuebingen.mpg.de | |
| from typing import Optional | |
| import torch | |
| from torch import Tensor | |
| from einops import rearrange | |
| from mGPT.utils.easyconvert import rep_to_rep, nfeats_of, to_matrix | |
| import mGPT.utils.geometry_tools as geometry_tools | |
| from .base import Rots2Rfeats | |
class Globalvelandy(Rots2Rfeats):
    """Map rotations + root translation to per-frame feature vectors and back.

    Each frame's feature vector is the concatenation of:

    * the root coordinate along the gravity axis, ``trans[..., 2]`` (1 value),
    * the frame-to-frame difference of the two other root coordinates,
      ``trans[..., [0, 1]]`` (2 values),
    * the joint rotations converted to ``pose_rep`` and flattened over joints.
    """

    def __init__(self,
                 path: Optional[str] = None,
                 normalization: bool = False,
                 pose_rep: str = "rot6d",
                 canonicalize: bool = False,
                 offset: bool = True,
                 **kwargs) -> None:
        """Store the feature configuration.

        Args:
            path: Forwarded unchanged to the ``Rots2Rfeats`` constructor.
            normalization: Forwarded unchanged to the ``Rots2Rfeats``
                constructor.
            pose_rep: Name of the rotation representation used inside the
                feature vector.
            canonicalize: If True, ``forward`` re-orients the sequence using
                the vertical-axis rotation of the first root pose.
            offset: If True (and ``canonicalize`` is set), ``pi / 2`` is added
                to that vertical-axis rotation.
            **kwargs: Accepted and ignored.
        """
        super().__init__(path=path, normalization=normalization)
        self.canonicalize = canonicalize
        self.pose_rep = pose_rep
        # Number of values per joint rotation in `pose_rep`.
        self.nfeats = nfeats_of(pose_rep)
        self.offset = offset

    def forward(self, data, data_rep='matrix', first_frame=None) -> Tensor:
        """Build the feature tensor from a rotations/translation structure.

        Args:
            data: Object exposing ``rots`` and ``trans`` tensors. The frame
                axis of ``trans`` is dimension ``-2``.
            data_rep: Rotation representation in which ``data.rots`` is given.
            first_frame: Optional velocity to use for the first frame, shaped
                like one frame of the planar velocity. Defaults to zeros.

        Returns:
            The features (passed through ``self.normalize``), laid out as
            ``[root_y, vel_x, vel_y, flattened poses]`` on the last axis.
        """
        poses, trans = data.rots, data.trans

        # Extract the root gravity axis; for SMPL it is the last coordinate.
        root_y = trans[..., 2]
        trajectory = trans[..., [0, 1]]

        # Frame-to-frame difference of the planar trajectory.
        vel_trajectory = torch.diff(trajectory, dim=-2)

        # Pad the first frame so the number of frames is preserved. The
        # padding is derived from `trajectory` (not from the differences) so
        # that a single-frame sequence, whose difference is empty, still works.
        if first_frame is None:
            first_frame = torch.zeros_like(trajectory[..., :1, :])
        vel_trajectory = torch.cat((first_frame, vel_trajectory), dim=-2)

        # Optionally canonicalize the orientation of the whole sequence.
        if self.canonicalize:
            matrix_poses = rep_to_rep(data_rep, 'matrix', poses)
            global_orient = matrix_poses[..., 0, :, :]

            # NOTE(review): `poses[0, 0, ...]` takes index 0 on the two
            # leading axes - presumably first frame / root joint of an
            # unbatched sequence; confirm before feeding batched input.
            # Clone because the result is edited in place below and the
            # conversion may hand back (a view of) the caller's tensor.
            rot2d = rep_to_rep(data_rep, 'rotvec', poses[0, 0, ...]).clone()

            # Keep only the component of the first rotation along the
            # vertical axis.
            rot2d[..., :2] = 0

            if self.offset:
                # Add a quarter turn on top of it.
                rot2d[..., 2] += torch.pi / 2

            rot2d = rep_to_rep('rotvec', 'matrix', rot2d)

            # Apply the transpose of `rot2d` to every global orientation.
            global_orient = torch.einsum("...kj,...kl->...jl", rot2d,
                                         global_orient)

            matrix_poses = torch.cat(
                (global_orient[..., None, :, :], matrix_poses[..., 1:, :, :]),
                dim=-3)
            poses = rep_to_rep('matrix', data_rep, matrix_poses)

            # Turn the planar velocities by the same amount.
            vel_trajectory = torch.einsum("...kj,...lk->...lj",
                                          rot2d[..., :2, :2], vel_trajectory)

        poses = rep_to_rep(data_rep, self.pose_rep, poses)
        features = torch.cat(
            (root_y[..., None], vel_trajectory,
             rearrange(poses, "... joints rot -> ... (joints rot)")),
            dim=-1)
        features = self.normalize(features)
        return features

    def extract(self, features):
        """Split a feature tensor into its three components.

        Args:
            features: Tensor laid out as produced by ``forward``.

        Returns:
            Tuple ``(root_y, vel_trajectory, poses)`` where ``poses`` has its
            last axis unflattened into ``(joints, self.nfeats)``.
        """
        root_y = features[..., 0]
        vel_trajectory = features[..., 1:3]
        poses_features = features[..., 3:]
        poses = rearrange(poses_features,
                          "... (joints rot) -> ... joints rot",
                          rot=self.nfeats)
        return root_y, vel_trajectory, poses

    def inverse(self, features, last_frame=None):
        """Rebuild rotations and translation from a feature tensor.

        Args:
            features: Tensor laid out as produced by ``forward``; it is
                passed through ``self.unnormalize`` first.
            last_frame: Accepted for interface compatibility; currently
                unused.

        Returns:
            A ``RotTransDatastruct`` holding rotation matrices (``rots``) and
            the integrated translation (``trans``).
        """
        features = self.unnormalize(features)
        root_y, vel_trajectory, poses = self.extract(features)

        # Integrate the velocities to recover the planar trajectory.
        trajectory = torch.cumsum(vel_trajectory, dim=-2)

        # The first frame should be 0, but when the features are inferred it
        # is safer to enforce it.
        trajectory = trajectory - trajectory[..., [0], :]

        # Get back the translation: planar coordinates first, gravity axis last.
        trans = torch.cat([trajectory, root_y[..., None]], dim=-1)
        matrix_poses = rep_to_rep(self.pose_rep, 'matrix', poses)

        # Imported at call time, as in the original code (presumably to avoid
        # a circular import - verify).
        from ..smpl import RotTransDatastruct
        return RotTransDatastruct(rots=matrix_poses, trans=trans)