Spaces:
Sleeping
Sleeping
| from __future__ import absolute_import | |
| from __future__ import print_function | |
| from __future__ import division | |
| import torch | |
| from torch import nn | |
| import numpy as np | |
| from configs import constants as _C | |
| from lib.models.layers import (MotionEncoder, MotionDecoder, TrajectoryDecoder, TrajectoryRefiner, Integrator, | |
| rollout_global_motion, reset_root_velocity, compute_camera_motion) | |
| from lib.utils.transforms import axis_angle_to_matrix | |
class Network(nn.Module):
    """Motion network producing SMPL parameters and a global root trajectory.

    The forward pass runs, in order: a motion encoder, a trajectory decoder,
    an optional image-feature integrator, a motion (SMPL) decoder and an
    optional trajectory refiner.  Intermediate predictions are stored on the
    instance (``self.pred_*``, ``self.output``, ``self.b``, ``self.f``) and
    read by the helper methods, so the helpers are only valid after
    ``forward`` has registered them.
    """

    def __init__(self,
                 smpl,
                 pose_dr=0.1,
                 d_embed=512,
                 n_layers=3,
                 d_feat=2048,
                 rnn_type='LSTM',
                 **kwargs
                 ):
        """Build the sub-modules.

        Args:
            smpl: Callable body model invoked in ``forward_smpl``.
            pose_dr: Forwarded to ``MotionEncoder`` as ``pose_dr``.
            d_embed: Embedding width of the motion encoder; also the hidden
                width of the trajectory refiner.
            n_layers: Number of recurrent layers for encoder and decoders.
            d_feat: Width of the image features fed to the integrator.
            rnn_type: Recurrent cell type forwarded to the sub-modules.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()

        n_joints = _C.KEYPOINTS.NUM_JOINTS
        self.smpl = smpl
        # Per-frame input width: two values per joint plus three extra channels.
        in_dim = n_joints * 2 + 3
        d_context = d_embed + n_joints * 3

        # Learned replacement added at masked keypoints (see ``preprocess``).
        self.mask_embedding = nn.Parameter(torch.zeros(1, 1, n_joints, 2))

        # Module 1. Motion Encoder
        self.motion_encoder = MotionEncoder(in_dim=in_dim,
                                            d_embed=d_embed,
                                            pose_dr=pose_dr,
                                            rnn_type=rnn_type,
                                            n_layers=n_layers,
                                            n_joints=n_joints)

        # Module 2. Trajectory Decoder
        self.trajectory_decoder = TrajectoryDecoder(d_embed=d_context,
                                                    rnn_type=rnn_type,
                                                    n_layers=n_layers)

        # Module 3. Feature Integrator
        self.integrator = Integrator(in_channel=d_feat + d_context,
                                     out_channel=d_context)

        # Module 4. Motion Decoder
        self.motion_decoder = MotionDecoder(d_embed=d_context,
                                            rnn_type=rnn_type,
                                            n_layers=n_layers)

        # Module 5. Trajectory Refiner
        self.trajectory_refiner = TrajectoryRefiner(d_embed=d_context,
                                                    d_hidden=d_embed,
                                                    rnn_type=rnn_type,
                                                    n_layers=2)

    def compute_global_feet(self, root_world, trans):
        """Express the predicted feet in world coordinates.

        Args:
            root_world: World-frame root orientation passed to
                ``compute_camera_motion``.
            trans: World-frame root translation passed to
                ``compute_camera_motion``.

        Returns:
            Tuple ``(feet_world, cam_R)``: the transformed feet positions and
            the camera rotation returned by ``compute_camera_motion``.
        """
        # Compute world-coordinate motion
        cam_R, cam_T = compute_camera_motion(
            self.output, self.pred_pose[:, :, :6], root_world, trans, self.pred_cam)

        # Feet are stored relative to the body; add the camera-frame translation.
        feet_cam = (self.output.feet.reshape(self.b, self.f, -1, 3)
                    + self.output.full_cam.reshape(self.b, self.f, 1, 3))
        # Apply cam_R^T to (feet_cam - cam_T) for every foot point.
        feet_world = (cam_R.mT @ (feet_cam - cam_T.unsqueeze(-2)).mT).mT
        return feet_world, cam_R

    def forward_smpl(self, **kwargs):
        """Run the body model on the registered predictions and collect outputs.

        Side effects: stores the body-model result in ``self.output``, logs the
        joints and vertices, and overwrites ``joints.npy`` / ``vertices.npy``
        in the current working directory on every call.

        Args:
            **kwargs: Forwarded to ``self.smpl``.

        Returns:
            dict of predictions; the key set depends on ``self.training``.
        """
        self.output = self.smpl(self.pred_pose,
                                self.pred_shape,
                                cam=self.pred_cam,
                                return_full_pose=not self.training,
                                **kwargs,
                                )

        from loguru import logger
        logger.info(f"Output Joints: {self.output.joints}")
        logger.info(f"Output Vertices: {self.output.vertices}")

        # Save joints and vertices as .npy arrays.
        # detach() is required: Tensor.numpy() raises on tensors that require
        # grad, which is the case whenever this runs outside torch.no_grad().
        np.save('joints.npy', self.output.joints.detach().cpu().numpy())
        np.save('vertices.npy', self.output.vertices.detach().cpu().numpy())

        # Feet location in global coordinate
        root_world, trans = rollout_global_motion(self.pred_root, self.pred_vel)
        feet_world, cam_R = self.compute_global_feet(root_world, trans)

        # Return output
        output = {'feet': feet_world,
                  'contact': self.pred_contact,
                  'pose': self.pred_pose,
                  'betas': self.pred_shape,
                  'cam': self.pred_cam,
                  'poses_root_cam': self.output.global_orient,
                  'poses_root_r6d': self.pred_root,
                  'vel_root': self.pred_vel,
                  'pose_root': self.pred_root,
                  'verts_cam': self.output.vertices}

        if self.training:
            output.update({
                'kp3d': self.output.joints,
                'kp3d_nn': self.pred_kp3d,
                'full_kp2d': self.output.full_joints2d,
                'weak_kp2d': self.output.weak_joints2d,
                'R': cam_R,
            })
        else:
            # 'poses_root_r6d' is already present in the base dict above.
            output.update({
                'trans_cam': self.output.full_cam,
                'poses_body': self.output.body_pose})

        return output

    def preprocess(self, x, mask):
        """Zero out masked keypoints and add the learned mask embedding.

        Also records batch size and sequence length in ``self.b`` / ``self.f``.
        The input tensor is not modified; a new tensor is returned.

        Args:
            x: Input of shape ``(b, f, n_joints * 2 + 3)``.
            mask: Per-joint mask of shape ``(b, f, n_joints)`` whose true
                entries mark keypoints to replace, or ``None`` for no masking.

        Returns:
            Tensor with the same shape as ``x``.
        """
        self.b, self.f = x.shape[:2]

        # forward() defaults mask to None; treat that as "nothing is masked".
        if mask is None:
            return x

        # Treat masked keypoints
        mask_embedding = mask.unsqueeze(-1) * self.mask_embedding
        # Expand the per-joint mask to both keypoint channels, then pad the
        # three trailing channels, which are never masked.
        _mask = mask.unsqueeze(-1).repeat(1, 1, 1, 2).reshape(self.b, self.f, -1)
        _mask = torch.cat((_mask, torch.zeros_like(_mask[..., :3])), dim=-1)
        _mask_embedding = mask_embedding.reshape(self.b, self.f, -1)
        _mask_embedding = torch.cat(
            (_mask_embedding, torch.zeros_like(_mask_embedding[..., :3])), dim=-1)

        # Work on a copy so the caller's tensor is left untouched.
        x = x.clone()
        x[_mask] = 0.0
        x = x + _mask_embedding
        return x

    def rollout(self, output, pred_root, pred_vel, return_y_up):
        """Integrate root orientation/velocity and add the result to ``output``.

        Args:
            output: Prediction dict, updated in place and returned.
            pred_root: Root orientation passed to ``rollout_global_motion``.
            pred_vel: Root velocity passed to ``rollout_global_motion``.
            return_y_up: If true, rotate the result by pi about the x-axis
                before storing it.

        Returns:
            ``output`` with ``'poses_root_world'`` and ``'trans_world'`` set.
        """
        root_world, trans_world = rollout_global_motion(pred_root, pred_vel)

        if return_y_up:
            yup2ydown = axis_angle_to_matrix(
                torch.tensor([[np.pi, 0, 0]])).float().to(root_world.device)
            root_world = yup2ydown.mT @ root_world
            trans_world = (yup2ydown.mT @ trans_world.unsqueeze(-1)).squeeze(-1)

        output.update({
            'poses_root_world': root_world,
            'trans_world': trans_world,
        })
        return output

    def refine_trajectory(self, output, cam_angvel, return_y_up, **kwargs):
        """Refine the root trajectory and roll it out to world coordinates.

        Args:
            output: Prediction dict produced by ``forward_smpl``.
            cam_angvel: Camera angular velocity forwarded to the refiner.
            return_y_up: Forwarded to the refiner and to ``rollout``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``output`` extended by the refiner and the rollout; in training
            mode it also carries ``'feet_refined'``.
        """
        # --------- Refine trajectory --------- #
        update_vel = reset_root_velocity(
            self.smpl, self.output, self.pred_contact, self.pred_root, self.pred_vel, thr=0.5)
        output = self.trajectory_refiner(
            self.old_motion_context, update_vel, output, cam_angvel, return_y_up=return_y_up)
        # --------- #

        # Do rollout
        output = self.rollout(
            output, output['poses_root_r6d_refined'], output['vel_root_refined'], return_y_up)

        # --------- Compute refined feet --------- #
        if self.training:
            feet_world, _ = self.compute_global_feet(
                output['poses_root_world'], output['trans_world'])
            output.update({'feet_refined': feet_world})

        return output

    def forward(self, x, inits, img_features=None, mask=None, init_root=None, cam_angvel=None,
                cam_intrinsics=None, bbox=None, res=None, return_y_up=False, refine_traj=True, **kwargs):
        """Run the full pipeline.

        Args:
            x: Input sequence, see ``preprocess``.
            inits: Pair ``(init_kp, init_smpl)`` used to initialise the motion
                encoder and the motion decoder respectively.
            img_features: Optional image features; when given they are fused
                into the motion context by the integrator.
            mask: Keypoint mask, see ``preprocess``.
            init_root: Initial root state for the trajectory decoder.
            cam_angvel: Camera angular velocity for the trajectory decoder and
                the trajectory refiner.
            cam_intrinsics, bbox, res: Forwarded to the body model.
            return_y_up: Forwarded to the rollout / refiner.
            refine_traj: If true, run the trajectory refiner; otherwise only
                roll out the decoded trajectory.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            dict of predictions assembled by ``forward_smpl`` and extended by
            ``refine_trajectory`` or ``rollout``.
        """
        x = self.preprocess(x, mask)
        init_kp, init_smpl = inits

        # --------- Inference --------- #
        # Stage 1. Encode motion
        pred_kp3d, motion_context = self.motion_encoder(x, init_kp)
        # Keep a detached copy of the pre-integration context for the refiner.
        self.old_motion_context = motion_context.detach().clone()

        # Stage 2. Decode global trajectory
        pred_root, pred_vel = self.trajectory_decoder(motion_context, init_root, cam_angvel)

        # Stage 3. Integrate features
        if img_features is not None and self.integrator is not None:
            motion_context = self.integrator(motion_context, img_features)

        # Stage 4. Decode SMPL motion
        pred_pose, pred_shape, pred_cam, pred_contact = self.motion_decoder(motion_context, init_smpl)
        # --------- #

        # --------- Register predictions --------- #
        self.pred_kp3d = pred_kp3d
        self.pred_root = pred_root
        self.pred_vel = pred_vel
        self.pred_pose = pred_pose
        self.pred_shape = pred_shape
        self.pred_cam = pred_cam
        self.pred_contact = pred_contact
        # --------- #

        # --------- Build SMPL --------- #
        output = self.forward_smpl(cam_intrinsics=cam_intrinsics, bbox=bbox, res=res)
        # --------- #

        # --------- Refine trajectory --------- #
        if refine_traj:
            output = self.refine_trajectory(output, cam_angvel, return_y_up)
        else:
            output = self.rollout(output, self.pred_root, self.pred_vel, return_y_up)
        # --------- #

        return output