Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import sys | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import smplx | |
| import json | |
| import time | |
| import pickle | |
| from datetime import datetime | |
| from datetime import timedelta | |
| from . import config | |
| from .customloss import ( | |
| body_fitting_loss_3d, | |
| camera_fitting_loss_3d, | |
| ) | |
| from .prior import MaxMixturePrior | |
def guess_init_3d(model_joints, j3d, joints_category="orig"):
    """Estimate an initial translation from torso-joint offsets.

    The estimate is the mean, over four torso joints (RHip, LHip, RShoulder,
    LShoulder), of ``j3d[joint] - model_joints[joint]``.

    :param model_joints: joints of the posed body model, indexed along dim 1
        with ``config.JOINT_MAP`` indices (assumed (batch, num_joints, 3)
        torch tensor -- TODO confirm).
    :param j3d: target 3D joints, indexed along dim 1 with the map selected by
        ``joints_category`` (assumed (batch, num_joints, 3) torch tensor --
        TODO confirm).
    :param joints_category: ``"orig"`` to index ``j3d`` with
        ``config.JOINT_MAP``, or ``"AMASS"`` to use ``config.AMASS_JOINT_MAP``.
    :returns: the per-joint offsets averaged over dim 1.
    :raises ValueError: if ``joints_category`` is not ``"orig"`` or ``"AMASS"``.
    """
    # Pick the target-side joint map first so an unknown category fails with a
    # clear error instead of an unbound-variable NameError further down.
    if joints_category == "orig":
        target_map = config.JOINT_MAP
    elif joints_category == "AMASS":
        target_map = config.AMASS_JOINT_MAP
    else:
        raise ValueError(
            f"NO SUCH JOINTS CATEGORY! Got {joints_category!r}; "
            "expected 'orig' or 'AMASS'."
        )

    torso_joints = ["RHip", "LHip", "RShoulder", "LShoulder"]
    # The model side always uses the JOINT_MAP ordering; only the target
    # side depends on joints_category.
    model_ind = [config.JOINT_MAP[name] for name in torso_joints]
    target_ind = [target_map[name] for name in torso_joints]

    offsets = j3d[:, target_ind] - model_joints[:, model_ind]
    # Mean offset over the selected joints (divisor follows the joint list).
    return offsets.sum(dim=1) / len(torso_joints)
# SMPLify 3D
class SMPLify3D:
    """Implementation of SMPLify that fits the body model to 3D joints.

    Fitting runs in two stages:
      1. ten Adam steps on the global orientation and translation, with the
         body pose and betas frozen (``camera_fitting_loss_3d``);
      2. L-BFGS on body pose, global orientation, translation and, unless
         ``fix_betas`` is set, betas (``body_fitting_loss_3d`` with a GMM
         pose prior).
    """

    def __init__(
        self,
        smplxmodel,
        step_size=1e-2,
        num_iters=100,
        joints_category="orig",
        device=torch.device("cuda:0"),
        GMM_MODEL_DIR="./joint2smpl_models/",
    ):
        """Configure the fitter.

        Args:
            smplxmodel: body model; called with ``global_orient``,
                ``body_pose`` and ``betas`` and must expose ``faces_tensor``.
            step_size: learning rate for both optimization stages.
            num_iters: default L-BFGS iteration count for stage 2.
            joints_category: ``"orig"`` or ``"AMASS"``; selects which
                ``config`` index lists map model joints to target joints.
            device: device the GMM pose prior is moved to.
            GMM_MODEL_DIR: folder passed to ``MaxMixturePrior``.

        Raises:
            ValueError: if ``joints_category`` is not recognised.
        """
        # Resolve the joint correspondence first so an invalid category fails
        # before the GMM prior is loaded. An unknown category used to leave the
        # indices as None, which silently inserts a new axis when indexing.
        if joints_category == "orig":
            smpl_index = config.full_smpl_idx
            corr_index = config.full_smpl_idx
        elif joints_category == "AMASS":
            smpl_index = config.amass_smpl_idx
            corr_index = config.amass_idx
        else:
            raise ValueError(
                f"NO SUCH JOINTS CATEGORY! Got {joints_category!r}; "
                "expected 'orig' or 'AMASS'."
            )

        # Store options
        self.device = device
        self.step_size = step_size
        self.num_iters = num_iters

        # GMM pose prior
        self.pose_prior = MaxMixturePrior(
            prior_folder=GMM_MODEL_DIR, num_gaussians=8, dtype=torch.float32
        ).to(device)

        # Body model supplied by the caller; faces flattened to 1-D.
        self.smpl = smplxmodel
        self.model_faces = smplxmodel.faces_tensor.view(-1)

        # Joint correspondence: model joints [:, smpl_index] <-> targets [:, corr_index]
        self.joints_category = joints_category
        self.smpl_index = smpl_index
        self.corr_index = corr_index

    # ---- main fitting entry point ------
    def __call__(
        self,
        init_pose,
        init_betas,
        init_cam_t,
        j3d,
        conf_3d=1.0,
        fix_betas=0,
        if_simple_hmp_optimizes=False,
        num_iters=None,
    ):
        """Fit the body model to the target joints ``j3d``.

        Args:
            init_pose: initial pose; columns ``[:3]`` are used as the global
                orientation and ``[3:]`` as the body pose.
            init_betas: initial shape parameters.
            init_cam_t: initial translation. It is also passed to
                ``camera_fitting_loss_3d`` as its reference translation.
            j3d: target 3D joints, indexed along dim 1 with ``self.corr_index``.
            conf_3d: passed to ``body_fitting_loss_3d`` as ``joints3d_conf``.
            fix_betas: if truthy, betas stay at ``init_betas``.
            if_simple_hmp_optimizes: if True, the translation is not made
                trainable, so it stays at ``init_cam_t``.
            num_iters: stage-2 L-BFGS iteration count (defaults to
                ``self.num_iters``). It is used both as ``max_iter`` and as the
                number of outer ``step`` calls.

        Returns:
            ``(vertices, joints, pose, betas, camera_translation)``, all
            detached. ``pose`` is ``cat([global_orient, body_pose], dim=-1)``.
        """
        # Placeholders for the collision term; passed with use_collision=False.
        search_tree = None
        pen_distance = None
        filter_faces = None

        self.t0 = datetime.now()

        # Detached leaf copies so the optimizers can update them in place
        # without touching the caller's tensors or their autograd graphs.
        global_orient = init_pose[:, :3].detach().clone()
        body_pose = init_pose[:, 3:].detach().clone()
        betas = init_betas.detach().clone()
        camera_translation = init_cam_t.detach().clone()
        preserve_pose = init_pose[:, 3:].detach().clone()

        smpl_index = self.smpl_index
        # Target joints are loop-invariant; select them once.
        target_joints = j3d[:, self.corr_index]

        # ------------- Step 1: global orientation and translation -------------
        body_pose.requires_grad = False
        betas.requires_grad = False
        global_orient.requires_grad = True
        if not if_simple_hmp_optimizes:
            camera_translation.requires_grad = True

        camera_optimizer = torch.optim.Adam(
            [global_orient, camera_translation],
            lr=self.step_size,
            betas=(0.9, 0.999),
        )
        for _ in range(10):
            smpl_output = self.smpl(
                global_orient=global_orient, body_pose=body_pose, betas=betas
            )
            loss = camera_fitting_loss_3d(
                smpl_output.joints[:, smpl_index],
                camera_translation,
                init_cam_t,
                target_joints,
                self.joints_category,
            )
            camera_optimizer.zero_grad()
            loss.backward()
            camera_optimizer.step()

        self.t = datetime.now() - self.t0
        self.t0 = datetime.now()
        print(f"Step 0: Average time in seconds: {self.t/timedelta(seconds=1)}")

        # ------------- Step 2: body pose (+ orientation, translation, betas) -------------
        body_pose.requires_grad = True
        global_orient.requires_grad = True
        if not if_simple_hmp_optimizes:
            camera_translation.requires_grad = True
        if fix_betas:
            betas.requires_grad = False
            body_opt_params = [body_pose, global_orient, camera_translation]
        else:
            betas.requires_grad = True
            body_opt_params = [body_pose, betas, global_orient, camera_translation]

        if num_iters is None:
            num_iters = self.num_iters
        body_optimizer = torch.optim.LBFGS(
            body_opt_params,
            max_iter=num_iters,
            lr=self.step_size,
            line_search_fn="strong_wolfe",
        )

        def closure():
            # L-BFGS may re-evaluate the objective several times per step.
            body_optimizer.zero_grad()
            smpl_output = self.smpl(
                global_orient=global_orient, body_pose=body_pose, betas=betas
            )
            loss = body_fitting_loss_3d(
                body_pose,
                preserve_pose,
                betas,
                smpl_output.joints[:, smpl_index],
                camera_translation,
                target_joints,
                self.pose_prior,
                joints3d_conf=conf_3d,
                joint_loss_weight=600.0,
                pose_preserve_weight=5.0,
                use_collision=False,
                model_vertices=smpl_output.vertices,
                model_faces=self.model_faces,
                search_tree=search_tree,
                pen_distance=pen_distance,
                filter_faces=filter_faces,
            )
            loss.backward()
            return loss

        for _ in range(num_iters):
            body_optimizer.step(closure)

        self.t = datetime.now() - self.t0
        self.t0 = datetime.now()
        print(f"Step2: Average time in seconds: {self.t/timedelta(seconds=1)}")

        # The final forward pass only produces outputs; no graph is needed.
        with torch.no_grad():
            smpl_output = self.smpl(
                global_orient=global_orient, body_pose=body_pose, betas=betas
            )
            vertices = smpl_output.vertices.detach()
            joints = smpl_output.joints.detach()
            pose = torch.cat([global_orient, body_pose], dim=-1).detach()
        return vertices, joints, pose, betas.detach(), camera_translation.detach()