|
|
import os
|
|
|
import torch
|
|
|
import trimesh
|
|
|
import numpy as np
|
|
|
from pathlib import Path
|
|
|
import time
|
|
|
|
|
|
|
|
|
from skeleton_models.skeletongen import SkeletonGPT
|
|
|
from data_utils.save_npz import normalize_to_unit_cube
|
|
|
from utils.mesh_to_pc import MeshProcessor
|
|
|
from utils.save_utils import (
|
|
|
pred_joints_and_bones,
|
|
|
save_skeleton_to_txt,
|
|
|
merge_duplicate_joints_and_fix_bones,
|
|
|
save_skeleton_obj,
|
|
|
save_mesh
|
|
|
)
|
|
|
|
|
|
class SkeletonInferencer:
    """Wrapper class for skeleton generation inference.

    Loads a pretrained SkeletonGPT checkpoint once in the constructor and
    exposes :meth:`infer` to predict a skeleton for a single mesh file and
    write the results (rig txt, skeleton obj, normalized mesh obj) to disk.
    """

    def __init__(self, pretrained_weights, device="cuda", precision="fp16"):
        """Load the SkeletonGPT model from a checkpoint.

        Args:
            pretrained_weights: Path to a torch checkpoint whose ``"model"``
                key holds a state dict compatible with ``SkeletonGPT``.
            device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.
            precision: ``"fp16"`` to run the model in half precision (only
                honoured when ``device == "cuda"``); any other value keeps fp32.
        """
        self.device = device
        self.precision = precision

        # Minimal stand-in for the training-time CLI arguments SkeletonGPT
        # expects; these values must match the checkpoint's configuration.
        class Args:
            def __init__(self):
                self.llm = "facebook/opt-350m"
                self.pad_id = -1
                self.n_discrete_size = 128
                self.n_max_bones = 100
                self.num_beams = 1
                self.seed = 0

        self.args = Args()

        print(f"Loading model from {pretrained_weights}...")
        self.model = SkeletonGPT(self.args).to(device)

        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — only load checkpoints from trusted sources.
        pkg = torch.load(pretrained_weights, map_location=torch.device("cpu"))
        self.model.load_state_dict(pkg["model"])
        self.model.eval()

        # Record the runtime dtype so inference inputs can be cast to match
        # the model. The previous code unconditionally cast the point cloud
        # to fp16, which broke fp32 (or CPU) runs with a dtype mismatch.
        use_half = precision == "fp16" and device == "cuda"
        self.dtype = torch.float16 if use_half else torch.float32
        if use_half:
            self.model = self.model.half()

        print("Model loaded successfully!")

    @torch.no_grad()
    def infer(
        self,
        input_path,
        output_dir,
        input_pc_num=8192,
        apply_marching_cubes=False,
        octree_depth=7,
        sequence_type="spatial"
    ):
        """
        Run inference on a single mesh file.

        Args:
            input_path: Path to a mesh file loadable by trimesh.
            output_dir: Directory for the output files (created if missing).
            input_pc_num: Number of surface points to sample from the mesh.
            apply_marching_cubes: If True, re-mesh via marching cubes before
                sampling (delegated to ``MeshProcessor``).
            octree_depth: Octree depth for the marching-cubes path.
            sequence_type: ``"hierarchical"`` enables root-aware bone
                ordering; any other value is treated as spatial ordering.

        Returns:
            dict: Results including paths and statistics (output file paths,
            joint/bone counts, and elapsed wall-clock time in seconds).
        """
        start_time = time.time()

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        mesh = trimesh.load(input_path, force='mesh')

        # --- Sample an oriented point cloud: rows are xyz + normal. ---
        if apply_marching_cubes:
            pc_list = MeshProcessor.convert_meshes_to_point_clouds(
                [mesh], input_pc_num,
                apply_marching_cubes=True,
                octree_depth=octree_depth
            )
            pc_normal = pc_list[0]
        else:
            points, face_indices = trimesh.sample.sample_surface(mesh, input_pc_num)
            normals = mesh.face_normals[face_indices]
            pc_normal = np.concatenate([points, normals], axis=-1)

        # --- Normalize coordinates to the unit cube, then compute a second
        # tight-bounds transform; both are recorded so they can be inverted. ---
        pc_coor = pc_normal[:, :3]
        normals = pc_normal[:, 3:]
        pc_coor, center, scale = normalize_to_unit_cube(pc_coor, scale_factor=0.9995)

        bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
        pc_center = (bounds[0] + bounds[1]) / 2
        # Epsilon keeps the scale nonzero for degenerate (flat/point) input.
        pc_scale = (bounds[1] - bounds[0]).max() + 1e-5

        # Layout: [center(3), scale, pc_center(3), pc_scale] — used below to
        # map predictions back to the original mesh coordinate frame.
        transform_params = torch.tensor([
            center[0], center[1], center[2], scale,
            pc_center[0], pc_center[1], pc_center[2], pc_scale
        ], dtype=torch.float32)

        pc_normal_normalized = np.concatenate([pc_coor, normals], axis=-1)
        batch_data = {
            # Cast to the model's runtime dtype (fp16 only when the model was
            # actually halved) instead of hard-coding .half().
            'pc_normal': torch.from_numpy(pc_normal_normalized).to(self.dtype).unsqueeze(0).to(self.device),
            'transform_params': transform_params.unsqueeze(0),
            # np.asarray strips trimesh's ndarray subclass for torch interop.
            'vertices': torch.from_numpy(np.asarray(mesh.vertices)).unsqueeze(0),
            'faces': torch.from_numpy(np.asarray(mesh.faces)).unsqueeze(0),
            'file_name': [Path(input_path).stem]
        }

        # --- Autoregressive skeleton generation. ---
        pred_bone_coords = self.model.generate(batch_data)

        file_name = Path(input_path).stem
        skeleton = pred_bone_coords[0].cpu().numpy()
        pred_joints, pred_bones = pred_joints_and_bones(skeleton.squeeze())

        # Hierarchical sequences carry the root joint first, so its index
        # must be tracked through the duplicate-merging step (the helper
        # returns an extra root index in that mode).
        hier_order = (sequence_type == "hierarchical")
        if hier_order and len(pred_bones) > 0:
            pred_root_index = pred_bones[0][0]
            pred_joints, pred_bones, pred_root_index = merge_duplicate_joints_and_fix_bones(
                pred_joints, pred_bones, root_index=pred_root_index
            )
        else:
            pred_joints, pred_bones = merge_duplicate_joints_and_fix_bones(
                pred_joints, pred_bones
            )
            pred_root_index = None

        # --- Undo both normalizations (tight-bounds first, then unit cube)
        # so the saved rig lives in the original mesh's coordinate frame. ---
        trans = transform_params[:3].numpy()
        scale_val = transform_params[3].item()
        pc_trans = transform_params[4:7].numpy()
        pc_scale_val = transform_params[7].item()

        pred_joints_denorm = pred_joints * pc_scale_val + pc_trans
        pred_joints_denorm = pred_joints_denorm / scale_val + trans

        # --- Write outputs. ---
        pred_rig_filename = output_dir / f"{file_name}_pred.txt"
        pred_skel_filename = output_dir / f"{file_name}_skel.obj"
        mesh_filename = output_dir / f"{file_name}_mesh.obj"

        save_skeleton_to_txt(
            pred_joints_denorm, pred_bones, pred_root_index,
            hier_order, mesh.vertices, str(pred_rig_filename)
        )

        save_skeleton_obj(
            pred_joints, pred_bones, str(pred_skel_filename),
            pred_root_index if hier_order else None,
            use_cone=hier_order
        )

        # Save the mesh in the same normalized frame as the skeleton obj so
        # the two files overlay directly in a viewer.
        vertices_norm = (mesh.vertices - trans) * scale_val
        vertices_norm = (vertices_norm - pc_trans) / pc_scale_val
        save_mesh(vertices_norm, mesh.faces, str(mesh_filename))

        elapsed_time = time.time() - start_time

        return {
            'skeleton_file': str(pred_skel_filename),
            'rig_file': str(pred_rig_filename),
            'mesh_file': str(mesh_filename),
            'num_joints': len(pred_joints),
            'num_bones': len(pred_bones),
            'time': elapsed_time
        }
|
|
|
|