| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import contextlib |
| from typing import Optional |
|
|
| import torch |
| from einops import rearrange |
| from torch import Tensor |
| from mGPT.utils.joints import smplh_to_mmm_scaling_factor |
| from mGPT.utils.joints import smplh2mmm_indexes |
| from .base import Rots2Joints |
|
|
|
|
def slice_or_none(data, cslice):
    """Index ``data`` with ``cslice``, passing ``None`` through untouched.

    Args:
        data: Any object supporting ``data[cslice]``, or ``None``.
        cslice: The index/slice to apply.

    Returns:
        ``None`` when ``data`` is ``None``; otherwise ``data[cslice]``.
    """
    return None if data is None else data[cslice]
|
|
|
|
class SMPLH(Rots2Joints):
    """Rotations-to-joints layer backed by an ``smplx`` ``SMPLHLayer``.

    The wrapped body model is put in eval mode, ``requires_grad`` is
    switched off on every parameter of this module, and :meth:`train` is
    overridden to do nothing, so the layer stays frozen.
    """

    def __init__(self,
                 path: str,
                 jointstype: str = "mmm",
                 input_pose_rep: str = "matrix",
                 batch_size: int = 512,
                 gender="neutral",
                 **kwargs) -> None:
        """Load the SMPL-H body model and freeze it.

        Args:
            path: Forwarded to ``SMPLHLayer`` as its first argument (loaded
                with ``ext="pkl"``).
            jointstype: Default output convention used by :meth:`forward`.
            input_pose_rep: Default input pose representation used by
                :meth:`forward`.
            batch_size: Default number of poses sent to the body model per
                call in :meth:`forward`.
            gender: Forwarded to ``SMPLHLayer``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(path=None, normalization=False)
        self.batch_size = batch_size
        self.input_pose_rep = input_pose_rep
        self.jointstype = jointstype
        self.training = False

        # Local import: smplx is only needed once a layer is instantiated.
        from smplx.body_models import SMPLHLayer

        # Suppress anything the body-model constructor prints to stdout.
        with contextlib.redirect_stdout(None):
            self.smplh = SMPLHLayer(path, ext="pkl", gender=gender).eval()

        self.faces = self.smplh.faces
        for p in self.parameters():
            p.requires_grad = False

    def train(self, *args, **kwargs):
        """Ignore mode changes and return ``self`` unchanged."""
        return self

    def forward(self,
                smpl_data: dict,
                jointstype: Optional[str] = None,
                input_pose_rep: Optional[str] = None,
                batch_size: Optional[int] = None) -> Tensor:
        """Run the body model on ``smpl_data`` and convert the result.

        Args:
            smpl_data: Object read through the attributes ``rots`` and
                ``trans`` (attribute access, despite the ``dict``
                annotation). ``rots`` must have 52 or 22 entries along its
                third-to-last axis (with or without hands); presumably the
                last two axes hold 3x3 rotation matrices -- TODO confirm.
                ``trans`` may be ``None``, in which case zeros are used.
            jointstype: Output convention; defaults to ``self.jointstype``.
                ``"vertices"`` returns mesh vertices instead of joints.
            input_pose_rep: Input representation; defaults to
                ``self.input_pose_rep``. Only ``"xyz"`` is rejected; any
                other value leaves the poses untouched.
            batch_size: Number of poses per body-model call; defaults to
                ``self.batch_size``.

        Returns:
            The body-model output with the leading dimensions of ``rots``
            restored, post-processed by ``smplh_to``.

        Raises:
            NotImplementedError: If ``input_pose_rep`` is ``"xyz"`` or the
                number of joints in ``rots`` is neither 52 nor 22.
            ValueError: If the effective ``batch_size`` is smaller than 1.
        """
        from functools import reduce
        import operator

        jointstype = self.jointstype if jointstype is None else jointstype
        batch_size = self.batch_size if batch_size is None else batch_size
        input_pose_rep = self.input_pose_rep if input_pose_rep is None else input_pose_rep

        if input_pose_rep == "xyz":
            raise NotImplementedError(
                "You should use identity pose2joints instead")

        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}.")

        poses = smpl_data.rots
        trans = smpl_data.trans

        # Everything in front of the (joints, ., .) axes is flattened into a
        # single batch axis for the body model and restored afterwards.
        save_shape_bs_len = poses.shape[:-3]
        nposes = reduce(operator.mul, save_shape_bs_len, 1)

        if poses.shape[-3] == 52:
            nohands = False
        elif poses.shape[-3] == 22:
            nohands = True
        else:
            raise NotImplementedError("Could not parse the poses.")

        matrix_poses = poses.reshape((nposes, *poses.shape[-3:]))
        global_orient = matrix_poses[:, 0]

        if trans is None:
            trans = torch.zeros((*save_shape_bs_len, 3),
                                dtype=poses.dtype,
                                device=poses.device)

        trans_all = trans.reshape((nposes, *trans.shape[-1:]))

        body_pose = matrix_poses[:, 1:22]
        if nohands:
            left_hand_pose = None
            right_hand_pose = None
        else:
            hand_pose = matrix_poses[:, 22:]
            left_hand_pose = hand_pose[:, :15]
            right_hand_pose = hand_pose[:, 15:]

        outputs = []
        # Integer stepping gives ceil(nposes / batch_size) chunks without
        # going through floats. max(..., 1) keeps a single (empty) call when
        # there are no poses so that torch.cat always receives one chunk.
        for start in range(0, max(nposes, 1), batch_size):
            chunk_slice = slice(start, start + batch_size)
            smpl_output = self.smplh(
                global_orient=slice_or_none(global_orient, chunk_slice),
                body_pose=slice_or_none(body_pose, chunk_slice),
                left_hand_pose=slice_or_none(left_hand_pose, chunk_slice),
                right_hand_pose=slice_or_none(right_hand_pose, chunk_slice),
                transl=slice_or_none(trans_all, chunk_slice))

            if jointstype == "vertices":
                output_chunk = smpl_output.vertices
            else:
                output_chunk = smpl_output.joints

            outputs.append(output_chunk)

        outputs = torch.cat(outputs)
        outputs = outputs.reshape((*save_shape_bs_len, *outputs.shape[1:]))

        # Select / reorder joints for the requested convention and re-anchor
        # the result on the translation.
        outputs = smplh_to(jointstype, outputs, trans)

        return outputs

    def inverse(self, joints: Tensor) -> Tensor:
        """Not supported: always raises ``NotImplementedError``."""
        raise NotImplementedError("Cannot inverse SMPLH layer.")
|
|
|
|
def smplh_to(jointstype, data, trans):
    """Convert SMPL-H layer output to the requested joints convention.

    Args:
        jointstype: One of ``"mmm"``, ``"mmmns"``, ``"smplmmm"``,
            ``"smplnh"``, ``"smplh"`` or ``"vertices"``.
        data: Array/tensor indexed as ``data[..., joint, coord]``; the axis
            in front of the joints axis is indexed at 0 for re-anchoring
            (presumably the time/frame axis -- TODO confirm).
        trans: Translation indexed as ``trans[..., 0, :]``. Not used when
            ``jointstype`` is ``"vertices"``.

    Returns:
        The converted data. The caller's ``data`` is not modified by the
        final re-anchoring step. ``"vertices"`` input is returned as-is.

    Raises:
        NotImplementedError: If ``jointstype`` is not handled.
    """
    if "mmm" in jointstype:
        from mGPT.utils.joints import smplh2mmm_indexes
        indexes = smplh2mmm_indexes
        data = data[..., indexes, :]

        if jointstype == "mmm":
            from mGPT.utils.joints import smplh_to_mmm_scaling_factor
            data *= smplh_to_mmm_scaling_factor

        if jointstype == "smplmmm":
            pass
        elif jointstype in ["mmm", "mmmns"]:
            # Reorder the coordinate components to (1, 2, 0) ...
            data = data[..., [1, 2, 0]]
            # ... then negate the (new) last component.
            data[..., 2] = -data[..., 2]

    elif jointstype == "smplnh":
        from mGPT.utils.joints import smplh2smplnh_indexes
        indexes = smplh2smplnh_indexes
        data = data[..., indexes, :]
    elif jointstype == "smplh":
        pass
    elif jointstype == "vertices":
        pass
    else:
        raise NotImplementedError(f"SMPLH to {jointstype} is not implemented.")

    if jointstype != "vertices":
        # Only this branch needs the root-joint lookup, so import it here.
        from mGPT.utils.joints import get_root_idx

        # Shift everything so that the root joint at index 0 of the axis
        # preceding the joints axis coincides with trans at that index.
        root_joint_idx = get_root_idx(jointstype)
        shift = trans[..., 0, :] - data[..., 0, root_joint_idx, :]
        # Out-of-place on purpose: on paths where data was not re-indexed
        # above (e.g. "smplh") an in-place add would overwrite the caller's
        # tensor, and in-place ops fail on leaf tensors that require grad.
        data = data + shift[..., None, None, :]

    return data
|
|