LHMPP / scripts /mvs_render /point_cloud_sample.py
Lingteng Qiu (邱陵腾)
rm assets & wheels
434b0b0
# -*- coding: utf-8 -*-
# @Organization : Tongyi Lab, Alibaba
# @Author : Lingteng Qiu
# @Email : 220019047@link.cuhk.edu.cn
# @Time : 2025-08-31 10:02:15
# @Function : Point cloud sampling and EasySMPLX wrapper
import sys
sys.path.append(".")
import copy
import math
import os
import os.path as osp
import pdb
import pickle
import time
import traceback
import ipdb
import numpy as np
import torch
import trimesh
from core.models.rendering.mesh_utils import Mesh
from core.models.rendering.skinnings.constant import SMPLX_JOINTS
from core.models.rendering.smplx import smplx
from pytorch3d.io import save_ply
from tqdm import tqdm
class EasySMPLX:
    """Easy forward SMPL-X model.

    Builds an SMPL-X layer whose expression basis on the face vertices is
    replaced by the FLAME expression basis, and evaluates it in a fixed
    neutral pose so that only the shape parameters (betas) change the output.

    Args:
        human_model_path: Root directory containing the ``smplx`` and
            ``flame`` model folders.
        shape_param_dim: Number of shape (beta) coefficients.
        expr_param_dim: Number of FLAME expression coefficients.
        gender: Which SMPL-X layer ``__call__`` evaluates.
    """

    def __init__(
        self,
        human_model_path,
        shape_param_dim=10,  # smpl beta
        expr_param_dim=100,  # flame expression
        gender="neutral",
    ):
        self.human_model_path = human_model_path
        self.shape_param_dim = shape_param_dim
        self.expr_param_dim = expr_param_dim
        self.gender = gender
        # Order matters: each step reads attributes set by the previous ones.
        self._init_smplx_layers()
        self._load_vertex_indices()
        self._init_joint_info()
        self._init_neutral_poses()
        # Private copy of the selected layer used by ``__call__`` and ``to``.
        self.smplx_layer = copy.deepcopy(self.layer[self.gender])

    def to(self, device):
        """Move the forward SMPL-X layer to ``device`` and return ``self``."""
        self.smplx_layer.to(device)
        return self

    def _add_cavity(self):
        """Append six triangles spanning eight lip vertices to the SMPL-X faces.

        Returns:
            is_cavity: float32 array of length ``self.vertex_num``; 1.0 at the
                eight lip vertices used by the new triangles, 0.0 elsewhere.
            face_new: int64 array holding ``self.face_orig`` followed by the
                six added triangles (presumably closing the mouth opening).
        """
        lip_vertex_idx = [2844, 2855, 8977, 1740, 1730, 1789, 8953, 2892]
        is_cavity = np.zeros((self.vertex_num), dtype=np.float32)
        is_cavity[lip_vertex_idx] = 1.0
        # Triangles given as positions inside ``lip_vertex_idx``.
        cavity_face = [[0, 1, 7], [1, 2, 7], [2, 3, 5], [3, 4, 5], [2, 5, 6], [2, 6, 7]]
        face_new = list(self.face_orig)
        face_new.extend(
            [lip_vertex_idx[a], lip_vertex_idx[b], lip_vertex_idx[c]]
            for a, b, c in cavity_face
        )
        face_new = np.array(face_new, dtype=np.int64)
        return is_cavity, face_new

    def _get_expr_vertex_idx(self):
        """SMPLX + FLAME2019 version: SMPL-X ids of expression-driven face vertices.

        A FLAME vertex is selected when any of its first ``expr_param_dim``
        expression blend shapes is non-zero. It is dropped if its largest
        skinning weight belongs to the Neck, L_Eye or R_Eye joint. The
        remaining FLAME ids are mapped to SMPL-X ids via ``self.face_vertex_idx``.

        Returns:
            numpy integer array of SMPL-X vertex indices.
        """
        flame_path = osp.join(self.human_model_path, "flame", "2019", "generic_model.pkl")
        # NOTE: pickle can execute code while loading; only use trusted model files.
        with open(flame_path, "rb") as f:
            flame_2019 = pickle.load(f, encoding="latin1")

        # FLAME.SHAPE_SPACE_DIM == 300: shapedirs columns from 300 on are expressions.
        expr_dirs = flame_2019["shapedirs"][:, :, 300 : 300 + self.expr_param_dim]
        vertex_idxs = np.where((expr_dirs != 0).sum((1, 2)) > 0)[0]

        # Exclude vertices whose dominant skinning joint is the neck or an eye.
        flame_joints_name = ("Neck", "Head", "Jaw", "L_Eye", "R_Eye")
        excluded_joints = [flame_joints_name.index(name) for name in ("Neck", "L_Eye", "R_Eye")]
        dominant_joint = np.asarray(flame_2019["weights"].argmax(1))
        is_neck_eye = np.isin(dominant_joint, excluded_joints)
        expr_vertex_idx = vertex_idxs[~is_neck_eye[vertex_idxs]]
        return self.face_vertex_idx[expr_vertex_idx]

    def _init_joint_info(self):
        """Initialize joint names, root joint index and joint-part mappings."""
        self.joint_num = 55  # Body (22) + Face (3) + Hands (30)
        self.joints_name = SMPLX_JOINTS
        self.root_joint_idx = self.joints_name.index("Pelvis")
        self._init_joint_part_mappings()

    def _load_vertex_indices(self):
        """Load the face topology and the hand / expression vertex index sets.

        Sets ``vertex_num``, ``face_orig`` (faces of the neutral layer),
        ``is_cavity`` and ``face`` (from ``_add_cavity``), ``rhand_vertex_idx``
        and ``lhand_vertex_idx`` (from ``MANO_SMPLX_vertex_ids.pkl``) and
        ``expr_vertex_idx``.
        """
        self.vertex_num = 10475
        self.face_orig = self.layer["neutral"].faces.astype(np.int64)
        self.is_cavity, self.face = self._add_cavity()
        # NOTE: pickle can execute code while loading; only use trusted model files.
        with open(
            osp.join(self.human_model_path, "smplx", "MANO_SMPLX_vertex_ids.pkl"), "rb"
        ) as f:
            hand_vertex_idx = pickle.load(f, encoding="latin1")
        self.rhand_vertex_idx = hand_vertex_idx["right_hand"]
        self.lhand_vertex_idx = hand_vertex_idx["left_hand"]
        self.expr_vertex_idx = self._get_expr_vertex_idx()

    def _init_joint_part_mappings(self):
        """Map body-part names to joint indices in ``self.joints_name``.

        ``body``, ``face``, ``lhand`` and ``rhand`` are contiguous ranges
        (inclusive of both named endpoints). ``lower_body`` and ``upper_body``
        are explicit index lists.
        """
        index = self.joints_name.index
        self.joint_part = {
            "body": range(index("Pelvis"), index("R_Wrist") + 1),
            "face": range(index("Jaw"), index("R_Eye") + 1),
            "lhand": range(index("L_Index_1"), index("L_Thumb_3") + 1),
            "rhand": range(index("R_Index_1"), index("R_Thumb_3") + 1),
            "lower_body": [
                index(name)
                for name in (
                    "Pelvis", "R_Hip", "L_Hip", "R_Knee", "L_Knee",
                    "R_Ankle", "L_Ankle", "R_Foot", "L_Foot",
                )
            ],
            "upper_body": [
                index(name)
                for name in (
                    "Spine_1", "Spine_2", "Spine_3", "L_Collar", "R_Collar",
                    "L_Shoulder", "R_Shoulder", "L_Elbow", "R_Elbow",
                    "L_Wrist", "R_Wrist",
                )
            ],
        }

    def _init_smplx_layers(self):
        """Create the SMPL-X layer(s) and load the SMPL-X <-> FLAME vertex map.

        Sets ``self.layer``, a dict from gender to SMPL-X layer. It always holds
        ``"neutral"``, whose faces ``_load_vertex_indices`` reads, plus
        ``self.gender``. Also sets ``self.face_vertex_idx``, the SMPL-X vertex
        index of each FLAME vertex.
        """
        # The create_* flags are disabled; ``__call__`` passes every pose,
        # shape and expression tensor explicitly.
        layer_args = {
            "create_global_orient": False,
            "create_body_pose": False,
            "create_left_hand_pose": False,
            "create_right_hand_pose": False,
            "create_jaw_pose": False,
            "create_leye_pose": False,
            "create_reye_pose": False,
            "create_betas": False,
            "create_expression": False,
            "create_transl": False,
        }
        use_face_contour = True
        genders = ["neutral"] if self.gender == "neutral" else ["neutral", self.gender]
        # Use the configured model root and dims; the FLAME basis copied in
        # below is built with ``expr_param_dim``, so both must agree.
        self.layer = {
            gender: smplx.create(
                self.human_model_path,
                "smplx",
                gender=gender,
                num_betas=self.shape_param_dim,
                num_expression_coeffs=self.expr_param_dim,
                use_pca=False,
                use_face_contour=use_face_contour,
                flat_hand_mean=True,
                **layer_args,
            )
            for gender in genders
        }
        self.face_vertex_idx = np.load(
            osp.join(self.human_model_path, "smplx", "SMPL-X__FLAME_vertex_ids.npy")
        )
        if use_face_contour:
            print("Using FLAME expression")
            self.layer = {
                gender: self._get_expr_from_flame(layer)
                for gender, layer in self.layer.items()
            }
        else:
            print("Using basic SMPLX without FLAME expression")

    def _get_expr_from_flame(self, smplx_layer):
        """Overwrite ``smplx_layer.expr_dirs`` on the face vertices with FLAME's.

        The FLAME layer is created with ``shape_param_dim`` betas and
        ``expr_param_dim`` expression coefficients. ``smplx_layer`` must use
        the same number of expression coefficients. The layer is modified in
        place and returned.
        """
        flame_layer = smplx.create(
            self.human_model_path,
            "flame",
            gender="neutral",
            num_betas=self.shape_param_dim,
            num_expression_coeffs=self.expr_param_dim,
        )
        smplx_layer.expr_dirs[self.face_vertex_idx, :, :] = flame_layer.expr_dirs
        return smplx_layer

    def _init_neutral_poses(self):
        """Build ``self.neutral_body_pose``, the fixed body pose used by ``__call__``.

        It has one 3-vector (presumably axis-angle) per body joint except the
        root, whose orientation is passed separately as ``global_orient``. All
        rows are zero except 15 and 16, which rotate by -30 and +30 degrees
        about z.
        """
        body_joints = len(self.joint_part["body"]) - 1
        self.neutral_body_pose = torch.zeros((body_joints, 3))
        angle = math.pi / 6
        # NOTE(review): rows 15/16 are presumably L_Shoulder / R_Shoulder in
        # body_pose order (lowering the arms towards an A-pose); confirm
        # against SMPLX_JOINTS.
        self.neutral_body_pose[15] = torch.FloatTensor([0, 0, -angle])
        self.neutral_body_pose[16] = torch.FloatTensor([0, 0, angle])

    @torch.no_grad()
    def __call__(self, smplx_data, device="cpu"):
        """Return the shaped SMPL-X body in the neutral pose as a ``trimesh.Trimesh``.

        Args:
            smplx_data: Mapping with a ``"betas"`` entry whose first axis is
                the batch size (presumably ``(batch, shape_param_dim)``). All
                other entries are ignored.
            device: Device on which the inputs are built. It should match the
                device of ``self.smplx_layer`` (see ``to``).

        Returns:
            Mesh of the output vertices (a leading batch axis of size 1 is
            squeezed away) with ``self.smplx_layer.faces``. Global orientation,
            jaw, eyes, hands and expression are zero. The body uses
            ``self.neutral_body_pose``.
        """
        # Accept tensors or arrays; the layer expects float32 inputs on ``device``.
        shape_param = torch.as_tensor(smplx_data["betas"]).float().to(device)
        batch_size = shape_param.shape[0]

        zero_pose = torch.zeros((batch_size, 3)).float().to(device)
        neutral_body_pose = (
            self.neutral_body_pose.view(1, -1).repeat(batch_size, 1).to(device)
        )
        # Both hands share the same zero pose; lhand and rhand have equal joint counts.
        zero_hand_pose = (
            torch.zeros((batch_size, len(self.joint_part["lhand"]) * 3))
            .float()
            .to(device)
        )
        zero_expr = torch.zeros((batch_size, self.expr_param_dim)).float().to(device)
        jaw_pose = torch.zeros((batch_size, 3)).float().to(device)

        output = self.smplx_layer(
            global_orient=zero_pose,
            body_pose=neutral_body_pose,
            left_hand_pose=zero_hand_pose,
            right_hand_pose=zero_hand_pose,
            jaw_pose=jaw_pose,
            leye_pose=zero_pose,
            reye_pose=zero_pose,
            expression=zero_expr,
            betas=shape_param,
            face_offset=None,
            joint_offset=None,
        )
        vertices = output.vertices.squeeze(0)
        faces = self.smplx_layer.faces
        return trimesh.Trimesh(vertices.detach().cpu().numpy(), faces, process=False)
def basename(path):
    """Return the file name of ``path`` up to, but excluding, its first dot.

    ``"/data/scan.01.ply"`` gives ``"scan"``; a name without a dot is returned whole.
    """
    stem, _sep, _rest = os.path.basename(path).partition(".")
    return stem
def multi_process(worker, items, **kwargs):
    """Run ``worker`` on this process's contiguous shard of ``items``.

    The shard comes from the ``RANK`` environment variable (default 0).
    ``items`` is cut into chunks of ``ceil(len(items) / nodes)`` elements, and
    the last rank (``nodes - 1``) takes everything from its start onward.

    Keyword Args:
        nodes: Total number of shards/processes.
        dirs: Output directory (printed here and forwarded to ``worker``).
        timesleep: Seconds rank 0 sleeps after its shard when ``nodes > 1``
            (default 3600).

    Every keyword argument, plus ``RANK``, is forwarded to ``worker``.
    """
    num_nodes = kwargs["nodes"]
    save_dir = kwargs["dirs"]
    # Fixed chunk size ceil(len / nodes) avoids giving the last node too many items.
    chunk = int(np.ceil(len(items) / num_nodes))
    print("Total Nodes:", num_nodes)
    print("Save Path:", save_dir)

    rank = int(os.environ.get("RANK", 0))
    print("Current Rank:", rank)
    kwargs["RANK"] = rank

    start = chunk * rank
    stop = None if rank == num_nodes - 1 else chunk * (rank + 1)
    worker(items[start:stop], **kwargs)

    if num_nodes > 1 and rank == 0:
        # NOTE(review): rank 0 idles (default one hour) after finishing; the reason
        # is not visible here (presumably to keep the job alive while other ranks run).
        time.sleep(kwargs.get("timesleep", 3600))
def run_sampling(items, **params):
    """Sample a point cloud for every ``.ply`` item and save it as a PLY file.

    Only the item's base id (``basename(item)``) is used; the ``.ply`` itself
    is never read. For base id ``<id>`` this:

    1. loads ``<smplx_root>/<id>.obj`` and takes the vertices listed in
       ``smplx_hand_vertex``;
    2. samples ``sampling_counts`` points on ``<mvs_recon_root>/<id>.obj``;
    3. writes the sampled points followed by the hand vertices to
       ``<dirs>/<id>_<sampling_counts>.ply``.

    Items whose input meshes are missing are reported and skipped, so one
    absent file does not abort the rest of the shard.

    Keyword Args:
        dirs: Output directory (created if missing).
        smplx_hand_vertex: SMPL-X vertex indices appended to every cloud.
        debug: If true, only the first 100 items are processed (default False).
        sampling_counts: Number of surface samples (default 80000).
        smplx_root: Directory of SMPL-X meshes (default ``./exps/smplx_output``).
        mvs_recon_root: Directory of MVS meshes (default ``./exps/mvs_recon``).

    Returns:
        The output directory ``dirs``.
    """
    output_dir = params["dirs"]
    debug = params.get("debug", False)
    smplx_hand_vertex = params["smplx_hand_vertex"]
    sampling_counts = params.get("sampling_counts", 80000)
    smplx_root = params.get("smplx_root", "./exps/smplx_output")
    mvs_recon_root = params.get("mvs_recon_root", "./exps/mvs_recon")

    os.makedirs(output_dir, exist_ok=True)
    if debug:
        items = items[:100]

    for item in tqdm(items, desc="Processing..."):
        # Suffix test, not substring: "x.ply.bak" must not be picked up.
        if not item.endswith(".ply"):
            continue
        baseid = basename(item)
        smplx_path = os.path.join(smplx_root, f"{baseid}.obj")
        mvs_path = os.path.join(mvs_recon_root, f"{baseid}.obj")
        missing = [path for path in (smplx_path, mvs_path) if not os.path.isfile(path)]
        if missing:
            print(f"Skipping {item}: missing {', '.join(missing)}")
            continue

        smplx_mesh = trimesh.load_mesh(smplx_path, process=False)
        hand_vertices = torch.from_numpy(
            np.asarray(smplx_mesh.vertices[smplx_hand_vertex])
        ).float()

        mesh = Mesh.load_obj(mvs_path, device="cpu")
        pts = mesh.sample_surface(sampling_counts)
        sampling_pts = torch.cat([pts, hand_vertices], dim=0)

        save_ply_path = os.path.join(output_dir, f"{baseid}_{sampling_counts}.ply")
        save_ply(save_ply_path, sampling_pts)
    return output_dir
def get_parse():
    """Parse this script's command-line options.

    Returns:
        ``argparse.Namespace`` with ``input`` and ``output`` (required paths),
        ``nodes`` (int, default 1), ``debug`` (bool flag) and ``txt``
        (optional list file, default None).
    """
    import argparse

    cli = argparse.ArgumentParser(description="")
    cli.add_argument("-i", "--input", required=True, help="input path")
    cli.add_argument("-o", "--output", required=True, help="output path")
    cli.add_argument("--nodes", type=int, default=1, help="how many workload?")
    cli.add_argument("--debug", action="store_true", help="debug tag")
    cli.add_argument("--txt", type=str, default=None)
    return cli.parse_args()
if __name__ == "__main__":
opt = get_parse()
smplx_model = EasySMPLX("./pretrained_models/human_model_files", 10, 100)
rhand_vertex_id = smplx_model.rhand_vertex_idx.tolist()
lhand_vertex_id = smplx_model.lhand_vertex_idx.tolist()
smplx_hand_vertex = rhand_vertex_id + lhand_vertex_id
# catch avaliable items
if opt.txt == None:
available_items = os.listdir(opt.input)
available_items = [os.path.join(opt.input, item) for item in available_items]
else:
available_items = []
with open(opt.txt) as reader:
for line in reader:
available_items.append(line.strip())
available_items = [
os.path.join(opt.input, item) for item in available_items
]
multi_process(
worker=run_sampling,
items=available_items,
dirs=opt.output,
nodes=opt.nodes,
debug=opt.debug,
smplx_hand_vertex=smplx_hand_vertex,
)