| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import sys |
| |
|
| | sys.path.append(".") |
| |
|
| | import copy |
| | import math |
| | import os |
| | import os.path as osp |
| | import pdb |
| | import pickle |
| | import time |
| | import traceback |
| |
|
| | import ipdb |
| | import numpy as np |
| | import torch |
| | import trimesh |
| | from core.models.rendering.mesh_utils import Mesh |
| | from core.models.rendering.skinnings.constant import SMPLX_JOINTS |
| | from core.models.rendering.smplx import smplx |
| | from pytorch3d.io import save_ply |
| | from tqdm import tqdm |
| |
|
| |
|
class EasySMPLX:
    """Easy forward SMPL-X model.

    Builds an SMPL-X layer for ``gender`` from ``human_model_path``. When
    called, it runs that layer with zero global orientation, hand, jaw, eye and
    expression parameters plus a fixed "neutral" body pose, and returns a
    ``trimesh.Trimesh``.

    Args:
        human_model_path: Root folder holding the ``smplx`` and ``flame`` model
            sub-folders. Files read from it: the SMPL-X model,
            ``smplx/MANO_SMPLX_vertex_ids.pkl``,
            ``smplx/SMPL-X__FLAME_vertex_ids.npy`` and
            ``flame/2019/generic_model.pkl``.
        shape_param_dim: Number of shape (beta) coefficients.
        expr_param_dim: Number of expression coefficients, used for both the
            SMPL-X layer and the FLAME layer.
        gender: SMPL-X model variant to load (also the key of ``self.layer``).
    """

    def __init__(
        self,
        human_model_path,
        shape_param_dim=10,
        expr_param_dim=100,
        gender="neutral",
    ):
        self.human_model_path = human_model_path
        self.shape_param_dim = shape_param_dim
        self.expr_param_dim = expr_param_dim
        self.gender = gender

        # Order matters: the vertex indices read self.layer and
        # self.face_vertex_idx, and the neutral pose reads self.joint_part.
        self._init_smplx_layers()
        self._load_vertex_indices()
        self._init_joint_info()
        self._init_neutral_poses()

        # Independent copy used by `to` / `__call__`; self.layer is never moved.
        self.smplx_layer = copy.deepcopy(self.layer[self.gender])

    def to(self, device):
        """Move the forward SMPL-X layer to ``device`` and return ``self``."""
        self.smplx_layer.to(device)
        return self

    def _add_cavity(self):
        """Append faces over eight hard-coded lip vertices.

        The extra faces presumably close the mouth cavity.

        Returns:
            is_cavity: float32 array of length ``self.vertex_num`` with 1.0 at
                the lip vertices and 0.0 elsewhere.
            face_new: int64 array made of ``self.face_orig`` followed by six
                triangles built from the lip vertices.
        """
        lip_vertex_idx = [2844, 2855, 8977, 1740, 1730, 1789, 8953, 2892]
        is_cavity = np.zeros((self.vertex_num), dtype=np.float32)
        is_cavity[lip_vertex_idx] = 1.0

        # Triangles expressed as positions into lip_vertex_idx.
        cavity_face = [[0, 1, 7], [1, 2, 7], [2, 3, 5], [3, 4, 5], [2, 5, 6], [2, 6, 7]]
        face_new = list(self.face_orig)
        for v1, v2, v3 in cavity_face:
            face_new.append(
                [lip_vertex_idx[v1], lip_vertex_idx[v2], lip_vertex_idx[v3]]
            )
        face_new = np.array(face_new, dtype=np.int64)

        return is_cavity, face_new

    def _get_expr_vertex_idx(self):
        """SMPLX + FLAME2019: return SMPL-X ids of expression-driven vertices.

        A FLAME vertex is selected when any of its ``expr_param_dim``
        shape-dir columns, starting at column 300, is non-zero. Vertices whose
        dominant skinning joint (argmax of ``weights``) is Neck, L_Eye or R_Eye
        are excluded. The surviving FLAME ids are mapped to SMPL-X ids via
        ``self.face_vertex_idx``.
        """
        model_file = osp.join(
            self.human_model_path, "flame", "2019", "generic_model.pkl"
        )
        # NOTE(review): pickle.load executes arbitrary code on a malicious
        # file; only point human_model_path at trusted model files.
        with open(model_file, "rb") as f:
            flame_2019 = pickle.load(f, encoding="latin1")

        # Expression directions start at column 300 (presumably after FLAME's
        # 300 shape components).
        expr_dirs = flame_2019["shapedirs"][:, :, 300 : 300 + self.expr_param_dim]
        vertex_idxs = np.where((expr_dirs != 0).sum((1, 2)) > 0)[0]

        flame_joints_name = ("Neck", "Head", "Jaw", "L_Eye", "R_Eye")
        excluded_joints = [
            flame_joints_name.index(name) for name in ("Neck", "L_Eye", "R_Eye")
        ]
        dominant_joint = np.asarray(flame_2019["weights"].argmax(1))
        is_neck_eye = np.isin(dominant_joint, excluded_joints)

        # Vectorized filter; keeps the ascending order produced by np.where and
        # stays an integer array even when nothing survives.
        expr_vertex_idx = vertex_idxs[~is_neck_eye[vertex_idxs]]
        return self.face_vertex_idx[expr_vertex_idx]

    def _init_joint_info(self):
        """Initialize joint names, root joint index and part mappings."""
        self.joint_num = 55
        self.joints_name = SMPLX_JOINTS
        self.root_joint_idx = self.joints_name.index("Pelvis")
        self._init_joint_part_mappings()

    def _load_vertex_indices(self):
        """Load faces (with cavity), hand vertex ids and expression vertex ids.

        Requires ``self.layer`` and ``self.face_vertex_idx`` (set by
        ``_init_smplx_layers``).
        """
        # Hard-coded SMPL-X template vertex count.
        self.vertex_num = 10475

        self.face_orig = self.layer[self.gender].faces.astype(np.int64)
        self.is_cavity, self.face = self._add_cavity()
        # NOTE(review): pickle file from the model folder; trusted input only.
        with open(
            osp.join(self.human_model_path, "smplx", "MANO_SMPLX_vertex_ids.pkl"), "rb"
        ) as f:
            hand_vertex_idx = pickle.load(f, encoding="latin1")

        self.rhand_vertex_idx = hand_vertex_idx["right_hand"]
        self.lhand_vertex_idx = hand_vertex_idx["left_hand"]
        self.expr_vertex_idx = self._get_expr_vertex_idx()

    def _init_joint_part_mappings(self):
        """Map part names to joint indices in ``self.joints_name``.

        ``body``, ``face``, ``lhand`` and ``rhand`` are contiguous ``range``s
        between the named first and last joints (inclusive).
        ``lower_body`` and ``upper_body`` are explicit lists.
        """
        idx = self.joints_name.index
        self.joint_part = {
            "body": range(idx("Pelvis"), idx("R_Wrist") + 1),
            "face": range(idx("Jaw"), idx("R_Eye") + 1),
            "lhand": range(idx("L_Index_1"), idx("L_Thumb_3") + 1),
            "rhand": range(idx("R_Index_1"), idx("R_Thumb_3") + 1),
            "lower_body": [
                idx("Pelvis"),
                idx("R_Hip"),
                idx("L_Hip"),
                idx("R_Knee"),
                idx("L_Knee"),
                idx("R_Ankle"),
                idx("L_Ankle"),
                idx("R_Foot"),
                idx("L_Foot"),
            ],
            "upper_body": [
                idx("Spine_1"),
                idx("Spine_2"),
                idx("Spine_3"),
                idx("L_Collar"),
                idx("R_Collar"),
                idx("L_Shoulder"),
                idx("R_Shoulder"),
                idx("L_Elbow"),
                idx("R_Elbow"),
                idx("L_Wrist"),
                idx("R_Wrist"),
            ],
        }

    def _init_smplx_layers(self):
        """Create ``self.layer`` ({gender: SMPL-X layer}) and load FLAME ids.

        The layer is built from ``self.human_model_path`` with
        ``self.shape_param_dim`` betas and ``self.expr_param_dim`` expression
        coefficients, so it agrees with the FLAME layer built in
        ``_get_expr_from_flame``.
        """
        # Per-call inputs are not created as layer parameters; __call__ passes
        # every one of them explicitly.
        layer_args = {
            "create_global_orient": False,
            "create_body_pose": False,
            "create_left_hand_pose": False,
            "create_right_hand_pose": False,
            "create_jaw_pose": False,
            "create_leye_pose": False,
            "create_reye_pose": False,
            "create_betas": False,
            "create_expression": False,
            "create_transl": False,
        }

        use_face_contour = True

        self.layer = {
            self.gender: smplx.create(
                self.human_model_path,
                "smplx",
                gender=self.gender,
                num_betas=self.shape_param_dim,
                num_expression_coeffs=self.expr_param_dim,
                use_pca=False,
                use_face_contour=use_face_contour,
                flat_hand_mean=True,
                **layer_args,
            )
        }

        # FLAME vertex id -> SMPL-X vertex id lookup.
        self.face_vertex_idx = np.load(
            osp.join(self.human_model_path, "smplx", "SMPL-X__FLAME_vertex_ids.npy")
        )

        if use_face_contour:
            print("Using FLAME expression")
            self.layer = {
                gender: self._get_expr_from_flame(layer)
                for gender, layer in self.layer.items()
            }
        else:
            print("Using basic SMPLX without FLAME expression")

    def _get_expr_from_flame(self, smplx_layer):
        """Overwrite the face rows of ``smplx_layer.expr_dirs`` with FLAME's.

        The layer is modified in place and also returned.
        """
        flame_layer = smplx.create(
            self.human_model_path,
            "flame",
            gender="neutral",
            num_betas=self.shape_param_dim,
            num_expression_coeffs=self.expr_param_dim,
        )
        smplx_layer.expr_dirs[self.face_vertex_idx, :, :] = flame_layer.expr_dirs
        return smplx_layer

    def _init_neutral_poses(self):
        """Build ``self.neutral_body_pose`` (body joints x 3 axis-angle).

        The root joint is excluded because it is passed separately as
        ``global_orient``. Entries 15 and 16 get -/+ pi/6 about the third axis;
        presumably these are the shoulders, depending on SMPLX_JOINTS order.
        """
        body_joints = len(self.joint_part["body"]) - 1
        self.neutral_body_pose = torch.zeros((body_joints, 3))

        angle = math.pi / 6
        self.neutral_body_pose[15] = torch.FloatTensor([0, 0, -angle])
        self.neutral_body_pose[16] = torch.FloatTensor([0, 0, angle])

    @torch.no_grad()
    def __call__(self, smplx_data, device="cpu"):
        """Run the layer in the neutral pose and return the mesh.

        Only ``smplx_data["betas"]`` is read. Pose, hand, jaw, eye and
        expression inputs are zeros, except the neutral body pose.

        Args:
            smplx_data: Mapping with a ``"betas"`` array/tensor whose first
                dimension is the batch size. It is converted to a float32
                tensor on ``device``.
            device: Device for the forward pass. ``self.smplx_layer`` is
                assumed to already be there (see ``to``).

        Returns:
            ``trimesh.Trimesh`` with the posed vertices and
            ``self.smplx_layer.faces`` (``process=False``).
        """
        shape_param = torch.as_tensor(
            smplx_data["betas"], dtype=torch.float32, device=device
        )
        batch_size = shape_param.shape[0]

        zero_pose = torch.zeros((batch_size, 3), dtype=torch.float32, device=device)
        neutral_body_pose = (
            self.neutral_body_pose.view(1, -1).repeat(batch_size, 1).to(device)
        )
        zero_hand_pose = torch.zeros(
            (batch_size, len(self.joint_part["lhand"]) * 3),
            dtype=torch.float32,
            device=device,
        )
        zero_expr = torch.zeros(
            (batch_size, self.expr_param_dim), dtype=torch.float32, device=device
        )
        jaw_pose = torch.zeros((batch_size, 3), dtype=torch.float32, device=device)

        output = self.smplx_layer(
            global_orient=zero_pose,
            body_pose=neutral_body_pose,
            left_hand_pose=zero_hand_pose,
            right_hand_pose=zero_hand_pose,
            jaw_pose=jaw_pose,
            leye_pose=zero_pose,
            reye_pose=zero_pose,
            expression=zero_expr,
            betas=shape_param,
            face_offset=None,
            joint_offset=None,
        )

        # NOTE(review): output.vertices is presumably (batch, num_vertices, 3).
        # squeeze(0) only removes the batch axis when batch_size == 1, which a
        # single Trimesh requires.
        vertices = output.vertices.squeeze(0)
        faces = self.smplx_layer.faces

        return trimesh.Trimesh(vertices.detach().cpu().numpy(), faces, process=False)
| |
|
| |
|
def basename(path):
    """Return the final path component of ``path``, cut at its first dot.

    Every suffix is dropped, not just the last one, so
    ``"/a/b/scan.v2.ply"`` gives ``"scan"``.
    """
    file_name = os.path.basename(path)
    stem, _, _ = file_name.partition(".")
    return stem
| |
|
| |
|
def multi_process(worker, items, **kwargs):
    """Run ``worker`` on this process's contiguous shard of ``items``.

    ``items`` is cut into buckets of ``ceil(len(items) / nodes)`` elements.
    The rank read from the ``RANK`` environment variable (default 0) handles
    its own bucket, and rank ``nodes - 1`` takes everything from its start to
    the end. ``worker`` is called as ``worker(shard, **kwargs)``, with
    ``kwargs["RANK"]`` set to the rank.

    Keyword Args:
        nodes: Total number of ranks; must be >= 1.
        dirs: Output location. It is printed here and forwarded to ``worker``.
        timesleep: Seconds that rank 0 sleeps after its shard when
            ``nodes > 1`` (default 3600).

    Returns:
        Whatever ``worker`` returned for this rank's shard.

    Raises:
        KeyError: If ``nodes`` or ``dirs`` is missing.
        ValueError: If ``nodes`` is less than 1.
    """
    nodes = kwargs["nodes"]
    dirs = kwargs["dirs"]
    if nodes < 1:
        raise ValueError(f"nodes must be >= 1, got {nodes}")

    bucket = math.ceil(len(items) / nodes)

    print("Total Nodes:", nodes)
    print("Save Path:", dirs)
    rank = int(os.environ.get("RANK", 0))
    print("Current Rank:", rank)

    kwargs["RANK"] = rank

    start = bucket * rank
    # The last rank takes everything that is left so no item is dropped.
    stop = None if rank == nodes - 1 else bucket * (rank + 1)
    result = worker(items[start:stop], **kwargs)

    # NOTE(review): rank 0 sleeps after its own shard, presumably to keep a
    # multi-node job alive while the other ranks finish; confirm intent.
    if rank == 0 and nodes > 1:
        time.sleep(kwargs.get("timesleep", 3600))

    return result
| |
|
| |
|
def run_sampling(items, **params):
    """Run point-cloud sampling for every ``.ply`` item.

    For each item ending in ``.ply``, ``stem = basename(item)`` selects two
    meshes: ``<smplx_root>/<stem>.obj`` (loaded with trimesh) and
    ``<mvs_root>/<stem>.obj`` (loaded with ``Mesh.load_obj``).
    ``sampling_counts`` points are sampled from the second mesh's surface.
    The first mesh's vertices at ``smplx_hand_vertex`` are appended, and the
    result is saved to ``<dirs>/<stem>_<sampling_counts>.ply``.

    Keyword Args:
        dirs: Output directory (created if missing).
        debug: If true, only the first 100 items are processed.
        smplx_hand_vertex: Sequence of vertex indices to append.
        smplx_root: Folder of SMPL-X meshes (default ``./exps/smplx_output``).
        mvs_root: Folder of reconstructed meshes (default ``./exps/mvs_recon``).
        sampling_counts: Number of surface samples (default 80000).
        Other keys forwarded by ``multi_process`` (e.g. ``nodes``, ``RANK``)
        are ignored.

    Returns:
        The output directory ``dirs``.
    """
    output_dir = params["dirs"]
    debug = params["debug"]
    # Convert once instead of re-converting the index list on every item.
    hand_vertex_idx = np.asarray(params["smplx_hand_vertex"], dtype=np.int64)
    smplx_root = params.get("smplx_root", "./exps/smplx_output")
    mvs_root = params.get("mvs_root", "./exps/mvs_recon")
    sampling_counts = params.get("sampling_counts", 80000)

    os.makedirs(output_dir, exist_ok=True)

    if debug:
        items = items[:100]

    for item in tqdm(items, desc="Processing..."):
        # tqdm.write prints without breaking the progress bar.
        tqdm.write(str(item))
        if not item.endswith(".ply"):
            continue
        baseid = basename(item)

        smplx_path = os.path.join(smplx_root, f"{baseid}.obj")
        mvs_recon_path = os.path.join(mvs_root, f"{baseid}.obj")

        smplx_mesh = trimesh.load_mesh(smplx_path, process=False)
        hand_vertices = torch.from_numpy(
            np.asarray(smplx_mesh.vertices[hand_vertex_idx])
        ).float()

        mesh = Mesh.load_obj(mvs_recon_path, device="cpu")
        pts = mesh.sample_surface(sampling_counts)
        sampling_pts = torch.cat([pts, hand_vertices], dim=0)

        save_ply_path = os.path.join(output_dir, f"{baseid}_{sampling_counts}.ply")
        save_ply(save_ply_path, sampling_pts)

    return output_dir
| |
|
| |
|
def get_parse(argv=None):
    """Parse the command-line options of this sampling script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``input``, ``output``, ``nodes``,
        ``debug`` and ``txt``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "Sample point clouds from reconstructed meshes and append "
            "SMPL-X hand vertices."
        )
    )
    parser.add_argument("-i", "--input", required=True, help="input path")
    parser.add_argument("-o", "--output", required=True, help="output path")
    parser.add_argument(
        "--nodes",
        default=1,
        type=int,
        help="number of ranks the item list is split across",
    )
    parser.add_argument(
        "--debug", action="store_true", help="only process the first 100 items"
    )
    parser.add_argument(
        "--txt",
        default=None,
        type=str,
        help="optional text file listing item names, one per line, "
        "relative to --input",
    )
    return parser.parse_args(argv)
| |
|
| |
|
if __name__ == "__main__":
    opt = get_parse()

    smplx_model = EasySMPLX("./pretrained_models/human_model_files", 10, 100)
    # Right-hand vertex ids followed by left-hand vertex ids.
    smplx_hand_vertex = (
        smplx_model.rhand_vertex_idx.tolist() + smplx_model.lhand_vertex_idx.tolist()
    )

    # Collect item names either from the --input folder or from the --txt
    # list, then join each name with --input exactly once.
    if opt.txt is None:
        # Sorted so every rank sees the same order and shards consistently;
        # os.listdir order is arbitrary.
        item_names = sorted(os.listdir(opt.input))
    else:
        with open(opt.txt) as reader:
            # Skip blank lines so they do not become the bare input directory.
            item_names = [line.strip() for line in reader if line.strip()]

    available_items = [os.path.join(opt.input, name) for name in item_names]

    multi_process(
        worker=run_sampling,
        items=available_items,
        dirs=opt.output,
        nodes=opt.nodes,
        debug=opt.debug,
        smplx_hand_vertex=smplx_hand_vertex,
    )
| |
|