| |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| import genmo.utils.matrix as matrix |
| from genmo.utils.rotation_conversions import ( |
| axis_angle_to_matrix, |
| matrix_to_axis_angle, |
| matrix_to_rotation_6d, |
| rotation_6d_to_matrix, |
| ) |
| from genmo.utils.torch_transform import ( |
| angle_axis_to_quaternion, |
| get_y_heading_q, |
| quat_apply, |
| quat_conjugate, |
| quat_mul, |
| quaternion_to_angle_axis, |
| ) |
| from third_party.GVHMR.hmr4d.utils.geo.augment_noisy_pose import gaussian_augment |
| from third_party.GVHMR.hmr4d.utils.geo.hmr_global import ( |
| get_local_transl_vel, |
| get_static_joint_mask, |
| rollout_local_transl_vel, |
| ) |
| from third_party.GVHMR.hmr4d.utils.smplx_utils import make_smplx |
|
|
| from . import stats_compose |
|
|
|
|
class EnDecoder(nn.Module):
    """Encode SMPL(-X) motion into normalized per-frame feature vectors and back.

    Two feature layouts are supported (see ``FEATURE_DIMS``):
      - "gvhmr":     151-D (body pose r6d, betas, incam/gravity-view orient, local transl vel)
      - "humanml3d": 143-D (body pose r6d, betas, root data)
    Features are z-normalized with per-layout mean/std statistics looked up
    by name from ``stats_compose``.
    """

    def __init__(
        self,
        stats_name="DEFAULT_01",
        encode_type="gvhmr",
        feature_arr=None,
        stats_arr=None,
        noise_pose_k=10,
        clip_std=False,
    ):
        """
        Args:
            stats_name: name of a stats entry in ``stats_compose``; used when
                ``encode_type`` selects a single feature.
            encode_type: "gvhmr" or "humanml3d" for a single-feature setup;
                any other value keeps the explicit ``feature_arr``/``stats_arr``.
            feature_arr: list of feature names (keys of ``FEATURE_DIMS``).
            stats_arr: list of stats names, parallel to ``feature_arr``.
            noise_pose_k: noise level forwarded to ``gaussian_augment`` in
                ``get_noisyobs``.
            clip_std: if True, clamp std into [0.1, 1] to avoid extreme scaling.
        """
        super().__init__()

        # Single-feature shorthand: expand to the parallel-array form.
        if encode_type in ["gvhmr", "humanml3d"]:
            feature_arr = [encode_type]
            stats_arr = [stats_name]

        # Per-frame feature dimensionality of each supported layout.
        self.FEATURE_DIMS = {
            "gvhmr": 151,
            "humanml3d": 143,
        }

        self.stats_dict = {}
        # Fix: loop name `stat_name` must not shadow the `stats_name` parameter.
        for feature, stat_name in zip(feature_arr, stats_arr):
            stats = getattr(stats_compose, stat_name)
            mean = torch.tensor(stats["mean"]).float()
            std = torch.tensor(stats["std"]).float()

            # DEFAULT_01 is a scalar placeholder; real stats must match the
            # feature dimensionality exactly.
            feature_dim = self.FEATURE_DIMS[feature]
            if stat_name != "DEFAULT_01":
                assert mean.shape[-1] == feature_dim
                assert std.shape[-1] == feature_dim

            if clip_std:
                std = torch.clamp(std, 0.1, 1)

            self.stats_dict[feature] = {"mean": mean, "std": std}

        self.feature_arr = feature_arr
        self.stats_arr = stats_arr
        self.clip_std = clip_std

        self.noise_pose_k = noise_pose_k
        self.encode_type = encode_type
        # Populated by encode_gvhmr / encode_humanml3d when first called.
        self.obs_indices_dict = None

        # SMPL-X body model used for skeleton queries / forward kinematics;
        # only the first 22 (body) joints are used.
        self.smplx_model = make_smplx("supermotion_v437coco17")
        parents = self.smplx_model.parents[:22]
        self.register_buffer("parents_tensor", parents, persistent=False)
        self.parents = parents.tolist()

    def normalize(self, x, feature_type):
        """Z-normalize ``x`` with the stats registered for ``feature_type``."""
        stats = self.stats_dict[feature_type]
        return (x - stats["mean"].to(x)) / stats["std"].to(x)

    def denormalize(self, x_norm, feature_type):
        """Invert ``normalize`` using the stats registered for ``feature_type``."""
        stats = self.stats_dict[feature_type]
        return x_norm * stats["std"].to(x_norm) + stats["mean"].to(x_norm)

    def get_noisyobs(self, data, return_type="r6d"):
        """Build a noisy local-pose observation.

        Args:
            data (dict): must contain "body_pose": (B, L, J*3) axis-angle.
            return_type: "R", "r6d" or "aa" — which rotation format of the
                ``gaussian_augment`` output to return.
        Returns:
            noisy_body_pose: (B, L, J, 3, 3) / (B, L, J, 6) / (B, L, J, 3)
            depending on ``return_type``.
        """
        body_pose = data["body_pose"]
        B, L, _ = body_pose.shape
        body_pose = body_pose.reshape(B, L, -1, 3)

        # gaussian_augment(..., to_R=True) returns (R, r6d, aa); index the
        # requested representation. (Fixed local-name typo: bosy -> body.)
        return_id = {"R": 0, "r6d": 1, "aa": 2}[return_type]
        noisy_body_pose = gaussian_augment(body_pose, self.noise_pose_k, to_R=True)[
            return_id
        ]
        return noisy_body_pose

    def normalize_body_pose_r6d(self, body_pose_r6d):
        """body_pose_r6d: (B, L, {J*6}/{J, 6}) -> (B, L, J*6), normalized."""
        B, L = body_pose_r6d.shape[:2]
        body_pose_r6d = body_pose_r6d.reshape(B, L, -1)
        # DEFAULT_01 stats are scalar (shape (1,)): identity normalization.
        if self.stats_dict[self.encode_type]["mean"].shape[-1] == 1:
            return body_pose_r6d
        # NOTE(review): the guard above keys on `self.encode_type`, but the
        # math below always uses "gvhmr" stats — presumably only valid when
        # encode_type == "gvhmr"; confirm with callers.
        # Fix: move stats onto the input's device/dtype (matches normalize()).
        mean = self.stats_dict["gvhmr"]["mean"].to(body_pose_r6d)
        std = self.stats_dict["gvhmr"]["std"].to(body_pose_r6d)
        return (body_pose_r6d - mean) / std

    def fk_v2(
        self, body_pose, betas, global_orient=None, transl=None, get_intermediate=False
    ):
        """Forward kinematics over the 22 body joints.

        Args:
            body_pose: (B, L, 63) axis-angle.
            betas: (B, L, 10).
            global_orient: (B, L, 3) axis-angle; zeros when None.
            transl: (B, L, 3) optional root translation, added to the root bone.
            get_intermediate: also return local and global joint matrices.
        Returns:
            joints: (B, L, 22, 3); plus (mat, fk_mat) when ``get_intermediate``.
        """
        B, L = body_pose.shape[:2]
        if global_orient is None:
            global_orient = torch.zeros((B, L, 3), device=body_pose.device)
        aa = torch.cat([global_orient, body_pose], dim=-1).reshape(B, L, -1, 3)
        rotmat = axis_angle_to_matrix(aa)

        # Bone vectors: each joint relative to its parent; the root keeps its
        # absolute position.
        skeleton = self.smplx_model.get_skeleton(betas)[..., :22, :]
        local_skeleton = skeleton - skeleton[:, :, self.parents_tensor]
        local_skeleton = torch.cat(
            [skeleton[:, :, :1], local_skeleton[:, :, 1:]], dim=2
        )

        if transl is not None:
            local_skeleton[..., 0, :] += transl

        mat = matrix.get_TRS(rotmat, local_skeleton)
        fk_mat = matrix.forward_kinematics(mat, self.parents)
        joints = matrix.get_position(fk_mat)
        if not get_intermediate:
            return joints
        else:
            return joints, mat, fk_mat

    def get_local_pos(self, betas):
        """Return per-joint offsets relative to parents (root stays absolute).

        Args:
            betas: shape parameters accepted by ``smplx_model.get_skeleton``.
        Returns:
            local_skeleton: (..., 22, 3).
        """
        skeleton = self.smplx_model.get_skeleton(betas)[..., :22, :]
        local_skeleton = skeleton - skeleton[:, :, self.parents_tensor]
        local_skeleton = torch.cat(
            [skeleton[:, :, :1], local_skeleton[:, :, 1:]], dim=2
        )
        return local_skeleton

    def get_static_gt(self, inputs, vel_thr):
        """Ground-truth static mask for selected end-effector joints.

        Args:
            inputs: dict with "smpl_params_w" (kwargs for ``fk_v2``).
            vel_thr: velocity threshold passed to ``get_static_joint_mask``.
        Returns:
            static_gt: (B, L, 6) float mask over the joints below.
        """
        # SMPL joint ids — presumably ankles/toes (7,10,8,11) and hands
        # (20, 21); verify against the SMPL joint table.
        joint_ids = [7, 10, 8, 11, 20, 21]
        gt_w_j3d = self.fk_v2(**inputs["smpl_params_w"])
        static_gt = get_static_joint_mask(
            gt_w_j3d, vel_thr=vel_thr, repeat_last=True
        )
        static_gt = static_gt[:, :, joint_ids].float()
        return static_gt

    def encode(self, inputs):
        """Composite encoder: concatenate every enabled feature encoding."""
        encoded_features = []

        for feature in self.feature_arr:
            if feature == "gvhmr":
                encoded = self.encode_gvhmr(inputs)
            elif feature == "humanml3d":
                encoded = self.encode_humanml3d(inputs)
            else:
                # Fix: fail loudly instead of appending a stale/unbound value.
                raise ValueError(f"Unknown feature type: {feature}")
            encoded_features.append(encoded)

        return torch.cat(encoded_features, dim=-1)

    def encode_humanml3d(self, inputs):
        """Encode world-frame SMPL params into the 143-D humanml3d layout.

        definition: {
            body_pose_r6d,  # (B, L, (J-1)*6) -> 0:126
            betas,          # (B, L, 10)      -> 126:136
            root_data,      # (B, L, 7)       -> 136:143
        }
        """
        self.obs_indices_dict = {
            "body_pose": torch.arange(126),
            "betas": torch.arange(126, 136),
            "root_data": torch.arange(136, 143),
        }
        B, L = inputs["smpl_params_w"]["body_pose"].shape[:2]

        smpl_params_w = inputs["smpl_params_w"]
        body_pose = smpl_params_w["body_pose"].reshape(B, L, 21, 3)
        body_pose_r6d = matrix_to_rotation_6d(axis_angle_to_matrix(body_pose)).flatten(
            -2
        )
        betas = smpl_params_w["betas"]
        global_orient = smpl_params_w["global_orient"]
        trans = smpl_params_w["transl"].clone()

        # Split global orientation into Y-heading and the remaining rotation.
        root_quat = angle_axis_to_quaternion(global_orient)
        heading_quat = get_y_heading_q(root_quat)
        heading_quat_inv = quat_conjugate(heading_quat)
        root_quat_wo_heading = quat_mul(heading_quat_inv, root_quat)
        root_quat_wo_heading = quaternion_to_angle_axis(root_quat_wo_heading)

        init_heading_quat_inv = heading_quat_inv[:, [0]].repeat(1, L, 1)

        """XZ at origin"""
        root_y = trans[..., [1]]
        root_pos_init = trans[:, [0]]
        root_pose_init_xz = root_pos_init * torch.tensor([1, 0, 1]).to(root_pos_init)
        trans = trans - root_pose_init_xz

        """All initially face Z+"""
        trans = quat_apply(init_heading_quat_inv, trans)
        heading_quat = quat_mul(
            heading_quat, init_heading_quat_inv
        )
        heading_quat_inv = quat_conjugate(heading_quat)

        """Root Linear Velocity"""
        # Frame-to-frame displacement, last frame repeated to keep length L.
        velocity = trans[:, 1:] - trans[:, :-1]
        velocity = torch.cat([velocity, velocity[:, [-1]]], dim=1)
        # Express velocity in the heading-aligned frame; keep XZ only.
        velocity = quat_apply(heading_quat_inv, velocity)
        l_velocity = velocity[..., [0, 2]]

        """Root Angular Velocity"""
        # Heading angle from the (w, z) quaternion components; wrap the
        # per-frame delta into (-pi, pi].
        r_angles = torch.arctan2(heading_quat[..., 2:3], heading_quat[..., :1]) * 2
        r_velocity = r_angles[:, 1:] - r_angles[:, :-1]
        r_velocity[r_velocity > np.pi] -= 2 * np.pi
        r_velocity[r_velocity < -np.pi] += 2 * np.pi
        r_velocity = torch.cat([r_velocity, r_velocity[:, [-1]]], dim=1)

        root_data = torch.cat(
            [r_velocity, l_velocity, root_y, root_quat_wo_heading], dim=-1
        )
        x = torch.cat([body_pose_r6d, betas, root_data], dim=-1)
        return self.normalize(x, "humanml3d")

    def encode_gvhmr(self, inputs):
        """Encode camera/world SMPL params into the 151-D gvhmr layout.

        definition: {
            body_pose_r6d,        # (B, L, (J-1)*6) -> 0:126
            betas,                # (B, L, 10)      -> 126:136
            global_orient_r6d,    # (B, L, 6)       -> 136:142 incam
            global_orient_gv_r6d, # (B, L, 6)       -> 142:148 gv
            local_transl_vel,     # (B, L, 3)       -> 148:151, smpl-coord
        }
        """
        self.obs_indices_dict = {
            "body_pose": torch.arange(126),
            "betas": torch.arange(126, 136),
            "global_orient": torch.arange(136, 142),
            "global_orient_gv": torch.arange(142, 148),
            "local_transl_vel": torch.arange(148, 151),
        }

        B, L = inputs["smpl_params_c"]["body_pose"].shape[:2]

        smpl_params_c = inputs["smpl_params_c"]
        body_pose = smpl_params_c["body_pose"].reshape(B, L, 21, 3)
        body_pose_r6d = matrix_to_rotation_6d(axis_angle_to_matrix(body_pose)).flatten(
            -2
        )
        betas = smpl_params_c["betas"]
        global_orient_R = axis_angle_to_matrix(smpl_params_c["global_orient"])
        global_orient_r6d = matrix_to_rotation_6d(global_orient_R)

        # Gravity-view orientation: rotate the in-camera orientation by R_c2gv.
        R_c2gv = inputs["R_c2gv"]
        global_orient_gv_r6d = matrix_to_rotation_6d(R_c2gv @ global_orient_R)

        # World translation expressed as per-frame velocity in the SMPL root
        # frame. (Sanity check, if ever needed: rollout_local_transl_vel with
        # global_orient and transl[:, [0]] should recover transl.)
        smpl_params_w = inputs["smpl_params_w"]
        local_transl_vel = get_local_transl_vel(
            smpl_params_w["transl"], smpl_params_w["global_orient"]
        )

        x = torch.cat(
            [
                body_pose_r6d,
                betas,
                global_orient_r6d,
                global_orient_gv_r6d,
                local_transl_vel,
            ],
            dim=-1,
        )
        return self.normalize(x, "gvhmr")

    def encode_translw(self, inputs):
        """Encode only the local translation velocity, normalized with the
        last 3 dims of the gvhmr stats (the local_transl_vel slot).

        Returns:
            x_norm: (B, L, 3)
        """
        smpl_params_w = inputs["smpl_params_w"]
        local_transl_vel = get_local_transl_vel(
            smpl_params_w["transl"], smpl_params_w["global_orient"]
        )

        x = local_transl_vel
        mean = self.stats_dict["gvhmr"]["mean"][-3:].to(x)
        std = self.stats_dict["gvhmr"]["std"][-3:].to(x)
        x_norm = (x - mean) / std
        return x_norm

    def decode_translw(self, x_norm):
        """Inverse of ``encode_translw``'s normalization. x_norm: (B, L, 3)."""
        std = self.stats_dict["gvhmr"]["std"][-3:].to(x_norm)
        mean = self.stats_dict["gvhmr"]["mean"][-3:].to(x_norm)
        return x_norm * std + mean

    def decode(self, x_norm):
        """Composite decoder: split ``x_norm`` per feature and merge outputs."""
        current_idx = 0
        decoded_outputs = {}

        for feature in self.feature_arr:
            feature_size = self.FEATURE_DIMS[feature]
            feature_norm = x_norm[..., current_idx : current_idx + feature_size]

            if feature == "gvhmr":
                decoded = self.decode_gvhmr(feature_norm)
            elif feature == "humanml3d":
                decoded = self.decode_humanml3d(feature_norm)
            else:
                # Fix: fail loudly instead of reusing a stale/unbound value.
                raise ValueError(f"Unknown feature type: {feature}")

            decoded_outputs.update(decoded)
            current_idx += feature_size

        return decoded_outputs

    def decode_humanml3d(self, x_norm):
        """Decode the 143-D humanml3d layout. x_norm: (B, L, C).

        Returns a dict with body_pose (B, L, 63), betas, global_orient_w,
        transl_w, and the root offset of the skeleton.
        """
        B, L, C = x_norm.shape
        x = self.denormalize(x_norm, "humanml3d")

        body_pose_r6d = x[:, :, :126]
        betas = x[:, :, 126:136]
        root_data = x[:, :, 136:143]

        body_pose = matrix_to_axis_angle(
            rotation_6d_to_matrix(body_pose_r6d.reshape(B, L, -1, 6))
        )
        body_pose = body_pose.flatten(-2)
        offset = self.smplx_model.get_skeleton(betas)[:, :, 0]

        rot_vel = root_data[..., 0]
        # zeros_like already matches root_data's device/dtype.
        r_rot_ang = torch.zeros_like(rot_vel)
        """Get Y-axis rotation from rotation velocity"""
        r_rot_ang[..., 1:] = rot_vel[..., :-1]
        r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
        r_rot_quat = torch.zeros(root_data.shape[:-1] + (4,)).to(root_data)
        r_rot_quat[..., 0] = torch.cos(r_rot_ang / 2)
        r_rot_quat[..., 2] = torch.sin(r_rot_ang / 2)

        r_pos = torch.zeros(root_data.shape[:-1] + (3,)).to(root_data)
        r_pos[..., 1:, [0, 2]] = root_data[..., :-1, 1:3]
        """Add Y-axis rotation to root position"""
        r_pos = quat_apply(r_rot_quat, r_pos)
        r_pos = torch.cumsum(r_pos, dim=-2)
        r_pos[..., 1] = root_data[..., 3]

        # Re-attach the accumulated heading to the heading-free orientation.
        global_orient_w = quaternion_to_angle_axis(
            quat_mul(r_rot_quat, angle_axis_to_quaternion(root_data[..., 4:]))
        )

        output = {
            "body_pose": body_pose,
            "betas": betas,
            "global_orient_w": global_orient_w,
            "transl_w": r_pos,
            "offset": offset,
        }

        return output

    def decode_gvhmr(self, x_norm):
        """Decode the 151-D gvhmr layout. x_norm: (B, L, C).

        Returns a dict with body_pose (B, L, 63), betas, in-camera and
        gravity-view global orientations, local_transl_vel, and the root
        offset of the skeleton.
        """
        B, L, C = x_norm.shape
        x = self.denormalize(x_norm, "gvhmr")

        body_pose_r6d = x[:, :, :126]
        betas = x[:, :, 126:136]
        global_orient_r6d = x[:, :, 136:142]
        global_orient_gv_r6d = x[:, :, 142:148]
        local_transl_vel = x[:, :, 148:151]

        body_pose = matrix_to_axis_angle(
            rotation_6d_to_matrix(body_pose_r6d.reshape(B, L, -1, 6))
        )
        body_pose = body_pose.flatten(-2)
        global_orient_c = matrix_to_axis_angle(rotation_6d_to_matrix(global_orient_r6d))
        global_orient_gv = matrix_to_axis_angle(
            rotation_6d_to_matrix(global_orient_gv_r6d)
        )

        offset = self.smplx_model.get_skeleton(betas)[:, :, 0]
        output = {
            "body_pose": body_pose,
            "betas": betas,
            "global_orient": global_orient_c,
            "global_orient_gv": global_orient_gv,
            "local_transl_vel": local_transl_vel,
            "offset": offset,
        }

        return output

    def get_motion_dim(self):
        """Total per-frame feature dimension over all enabled features."""
        return sum(self.FEATURE_DIMS[feature] for feature in self.feature_arr)

    def get_obs_indices(self, obs):
        """Index range of observation ``obs`` in the encoded feature vector.

        Valid only after an encode_* call populated ``obs_indices_dict``.
        """
        return self.obs_indices_dict[obs]
|