# Author: Jialin Yang
# Initial release on Huggingface Spaces with Gradio UI (commit 352b049)
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import smplx
import json
import time
import pickle
from datetime import datetime
from datetime import timedelta
from . import config
from .customloss import (
body_fitting_loss_3d,
camera_fitting_loss_3d,
)
from .prior import MaxMixturePrior
@torch.no_grad()
def guess_init_3d(model_joints, j3d, joints_category="orig"):
    """Estimate the initial camera translation via torso-joint alignment.

    The translation is the mean offset between the four torso joints
    (hips and shoulders) of the target 3D keypoints and the corresponding
    SMPL model joints.

    :param model_joints: (batch, J, 3) joints from the SMPL model forward pass
    :param j3d: (batch, K, 3) target 3D joints (e.g. 25 Kinect joints)
    :param joints_category: layout of ``j3d`` — "orig" or "AMASS"
    :returns: (batch, 3) tensor with the estimated camera translation
    :raises ValueError: if ``joints_category`` is not recognized
    """
    # The four torso joints used for the alignment.
    torso_joints = ["RHip", "LHip", "RShoulder", "LShoulder"]
    gt_joints_ind = [config.JOINT_MAP[joint] for joint in torso_joints]

    if joints_category == "orig":
        # Same layout as the SMPL model joints.
        joints_ind_category = gt_joints_ind
    elif joints_category == "AMASS":
        joints_ind_category = [config.AMASS_JOINT_MAP[joint] for joint in torso_joints]
    else:
        # BUG FIX: previously this branch only printed a warning and then
        # crashed with UnboundLocalError on joints_ind_category below.
        # Fail fast with a clear error instead.
        raise ValueError(f"NO SUCH JOINTS CATEGORY: {joints_category!r}")

    # Mean offset over the four torso joints (sum over the joint axis / 4).
    sum_init_t = (j3d[:, joints_ind_category] - model_joints[:, gt_joints_ind]).sum(
        dim=1
    )
    init_t = sum_init_t / 4.0
    return init_t
# SMPLify 3D
class SMPLify3D:
    """Implementation of SMPLify using 3D joint supervision.

    Fits SMPL(-X) pose/shape parameters and a camera translation to target
    3D joints in two stages:
      1. camera translation + global orientation (Adam, 10 steps),
      2. full body pose (optionally also betas) with L-BFGS.
    """

    def __init__(
        self,
        smplxmodel,
        step_size=1e-2,
        num_iters=100,
        joints_category="orig",
        device=torch.device("cuda:0"),
        GMM_MODEL_DIR="./joint2smpl_models/",
    ):
        # smplxmodel: already-constructed smplx body model; its forward()
        #   and faces_tensor are used during fitting.
        # step_size: learning rate shared by both optimization stages.
        # num_iters: default iteration count for the body-fitting stage.
        # joints_category: layout of the target joints ("orig" or "AMASS").
        # GMM_MODEL_DIR: folder holding the GMM pose-prior weights.
        # Store options
        self.device = device
        self.step_size = step_size
        self.num_iters = num_iters
        # GMM pose prior (8-component max-mixture over body pose).
        self.pose_prior = MaxMixturePrior(
            prior_folder=GMM_MODEL_DIR, num_gaussians=8, dtype=torch.float32
        ).to(device)
        # reLoad SMPL-X model
        self.smpl = smplxmodel
        self.model_faces = smplxmodel.faces_tensor.view(-1)
        # Index lists aligning model joints (smpl_index) with the target
        # joints (corr_index) for the fitting losses.
        self.joints_category = joints_category
        if joints_category == "orig":
            self.smpl_index = config.full_smpl_idx
            self.corr_index = config.full_smpl_idx
        elif joints_category == "AMASS":
            self.smpl_index = config.amass_smpl_idx
            self.corr_index = config.amass_idx
        else:
            # NOTE(review): an unknown category only warns here; __call__
            # will later fail when indexing with None — consider raising.
            self.smpl_index = None
            self.corr_index = None
            print("NO SUCH JOINTS CATEGORY!")

    # ---- main fitting entry point ------
    def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0, fix_betas=0, if_simple_hmp_optimizes=False, num_iters=None):
        """Perform body fitting.
        Input:
            init_pose: (B, 72) SMPL pose estimate (global orient + body pose)
            init_betas: SMPL betas estimate
            init_cam_t: camera translation estimate
            j3d: target 3D joints (keypoints)
            conf_3d: confidence for the 3d joints
            fix_betas: if truthy, keep betas fixed during stage 2
                (useful when fitting a sequence with a shared shape)
            if_simple_hmp_optimizes: if True, skip camera optimization and
                keep camera_translation fixed
            num_iters: override for self.num_iters in stage 2 (None = default)
        Returns:
            vertices: vertices of the optimized shape
            joints: 3D joints of the optimized shape
            pose: SMPL pose parameters of the optimized shape
            betas: SMPL beta parameters of the optimized shape
            camera_translation: camera translation
        """
        # Mesh inter-penetration handling is disabled (use_collision=False
        # below), so these stay None.
        search_tree = None
        pen_distance = None
        filter_faces = None
        self.t0 = datetime.now()
        # Split SMPL pose to body pose and global orientation
        body_pose = init_pose[:, 3:].detach().clone()
        global_orient = init_pose[:, :3].detach().clone()
        betas = init_betas.detach().clone()
        camera_translation = init_cam_t.clone()
        # Reference pose for the pose-preservation term of the loss.
        preserve_pose = init_pose[:, 3:].detach().clone()
        # -------------Step 1: Optimize camera translation and body orientation--------
        # Optimize only camera translation and body orientation
        body_pose.requires_grad = False
        betas.requires_grad = False
        global_orient.requires_grad = True
        # NOTE(review): indentation reconstructed — this reading (the whole
        # camera stage skipped in "simple" mode) is implied by
        # camera_opt_params being defined only in this branch; confirm.
        if not if_simple_hmp_optimizes:
            camera_translation.requires_grad = True
            camera_opt_params = [global_orient, camera_translation]
            # (A commented-out LBFGS alternative for this stage was removed.)
            camera_optimizer = torch.optim.Adam(
                camera_opt_params, lr=self.step_size, betas=(0.9, 0.999)
            )
            for i in range(10):
                smpl_output = self.smpl(
                    global_orient=global_orient, body_pose=body_pose, betas=betas
                )
                model_joints = smpl_output.joints
                loss = camera_fitting_loss_3d(
                    model_joints[:, self.smpl_index],
                    camera_translation,
                    init_cam_t,
                    j3d[:, self.corr_index],
                    self.joints_category,
                )
                camera_optimizer.zero_grad()
                loss.backward()
                camera_optimizer.step()
        self.t = datetime.now() - self.t0
        self.t0 = datetime.now()
        print(f"Step 0: Average time in seconds: {self.t/timedelta(seconds=1)}")
        # Fix camera translation after optimizing camera
        # --------Step 2: Optimize body joints --------------------------
        # Optimize only the body pose and global orientation of the body
        body_pose.requires_grad = True
        global_orient.requires_grad = True
        if not if_simple_hmp_optimizes:
            camera_translation.requires_grad = True
        # --- if we use the sequence, fix the shape
        if not fix_betas:
            betas.requires_grad = True
            body_opt_params = [body_pose, betas, global_orient, camera_translation]
        else:
            betas.requires_grad = False
            body_opt_params = [body_pose, global_orient, camera_translation]
        num_iters = self.num_iters if num_iters is None else num_iters
        body_optimizer = torch.optim.LBFGS(
            body_opt_params,
            max_iter=num_iters,
            lr=self.step_size,
            line_search_fn="strong_wolfe",
        )
        for i in range(num_iters):
            # L-BFGS re-evaluates the loss via this closure.
            def closure():
                body_optimizer.zero_grad()
                smpl_output = self.smpl(
                    global_orient=global_orient, body_pose=body_pose, betas=betas
                )
                model_joints = smpl_output.joints
                model_vertices = smpl_output.vertices
                loss = body_fitting_loss_3d(
                    body_pose,
                    preserve_pose,
                    betas,
                    model_joints[:, self.smpl_index],
                    camera_translation,
                    j3d[:, self.corr_index],
                    self.pose_prior,
                    joints3d_conf=conf_3d,
                    joint_loss_weight=600.0,
                    pose_preserve_weight=5.0,
                    use_collision=False,
                    model_vertices=model_vertices,
                    model_faces=self.model_faces,
                    search_tree=search_tree,
                    pen_distance=pen_distance,
                    filter_faces=filter_faces,
                )
                loss.backward()
                return loss
            body_optimizer.step(closure)
        # (A commented-out Adam alternative for this stage was removed.)
        self.t = datetime.now() - self.t0
        self.t0 = datetime.now()
        print(f"Step2: Average time in seconds: {self.t/timedelta(seconds=1)}")
        # Final forward pass with the optimized parameters.
        smpl_output = self.smpl(
            global_orient=global_orient, body_pose=body_pose, betas=betas
        )
        vertices = smpl_output.vertices.detach()
        joints = smpl_output.joints.detach()
        pose = torch.cat([global_orient, body_pose], dim=-1).detach()
        betas = betas.detach()
        # NOTE(review): camera_translation is returned without .detach(),
        # unlike the other outputs — confirm whether callers rely on that.
        return vertices, joints, pose, betas, camera_translation