import os
import time
from pathlib import Path

import numpy as np
import torch
import trimesh

# Import from MagicArticulate
from skeleton_models.skeletongen import SkeletonGPT
from data_utils.save_npz import normalize_to_unit_cube
from utils.mesh_to_pc import MeshProcessor
from utils.save_utils import (
    pred_joints_and_bones,
    save_skeleton_to_txt,
    merge_duplicate_joints_and_fix_bones,
    save_skeleton_obj,
    save_mesh,
)


class SkeletonInferencer:
    """Wrapper class for skeleton generation inference with ``SkeletonGPT``.

    The model is built and its checkpoint loaded once in ``__init__``;
    :meth:`infer` then converts a single mesh file into predicted skeleton
    files written to an output directory.
    """

    def __init__(self, pretrained_weights, device="cuda", precision="fp16"):
        """Build ``SkeletonGPT`` and load its weights.

        Args:
            pretrained_weights: Path to a checkpoint loaded with ``torch.load``
                whose ``"model"`` entry is passed to ``load_state_dict``.
            device: Torch device the model is moved to (e.g. ``"cuda"``,
                ``"cuda:1"`` or ``"cpu"``).
            precision: ``"fp16"`` converts the model to half precision when
                ``device`` is a CUDA device; otherwise weights stay as built.
        """
        self.device = device
        self.precision = precision

        # Hyper-parameters handed to SkeletonGPT as an attribute-style args object.
        class Args:
            def __init__(self):
                self.llm = "facebook/opt-350m"
                self.pad_id = -1
                self.n_discrete_size = 128
                self.n_max_bones = 100
                self.num_beams = 1
                self.seed = 0

        self.args = Args()

        # Load model
        print(f"Loading model from {pretrained_weights}...")
        self.model = SkeletonGPT(self.args).to(device)
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        pkg = torch.load(pretrained_weights, map_location=torch.device("cpu"))
        self.model.load_state_dict(pkg["model"])
        self.model.eval()

        # Compare the device *type* so "cuda:0" or torch.device("cuda", 1)
        # also get half precision (a plain == "cuda" check misses them).
        if precision == "fp16" and torch.device(device).type == "cuda":
            self.model = self.model.half()

        print("Model loaded successfully!")

    @staticmethod
    def _sample_point_cloud(mesh, num_points, apply_marching_cubes, octree_depth):
        """Sample ``num_points`` surface points plus normals from ``mesh``.

        Returns a 2-D array whose first three columns are xyz and whose
        remaining columns are normals. Without marching cubes, points come from
        ``trimesh.sample.sample_surface`` and each normal is the normal of the
        face the point was sampled on.
        """
        if apply_marching_cubes:
            pc_list = MeshProcessor.convert_meshes_to_point_clouds(
                [mesh],
                num_points,
                apply_marching_cubes=True,
                octree_depth=octree_depth,
            )
            return pc_list[0]

        # Simple sampling
        points, face_indices = trimesh.sample.sample_surface(mesh, num_points)
        normals = mesh.face_normals[face_indices]
        return np.concatenate([points, normals], axis=-1)

    @staticmethod
    def _normalize_point_cloud(pc_normal):
        """Map the xyz columns of ``pc_normal`` into the model's input frame.

        Two stages are applied to the coordinates (normals pass through as-is):
        ``normalize_to_unit_cube`` (returning ``center``/``scale``), then a
        re-centering on the bounding-box center ``pc_center`` and a division by
        the largest bounding-box extent ``pc_scale``.

        Returns:
            (normalized_array, transform_params) where ``transform_params`` is
            a float32 tensor ``[center(3), scale, pc_center(3), pc_scale]``
            that :meth:`infer` uses to undo both stages.
        """
        pc_coor = pc_normal[:, :3]
        normals = pc_normal[:, 3:]
        pc_coor, center, scale = normalize_to_unit_cube(pc_coor, scale_factor=0.9995)

        # Second-stage transform parameters.
        bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
        pc_center = (bounds[0] + bounds[1]) / 2
        pc_scale = (bounds[1] - bounds[0]).max() + 1e-5  # epsilon avoids /0 on degenerate input
        # Actually apply the second stage: the joint denormalization and the
        # saved normalized mesh in infer() both invert this transform, so the
        # model input must be expressed in the same frame.
        pc_coor = (pc_coor - pc_center) / pc_scale

        transform_params = torch.tensor(
            [
                center[0], center[1], center[2], scale,
                pc_center[0], pc_center[1], pc_center[2], pc_scale,
            ],
            dtype=torch.float32,
        )
        return np.concatenate([pc_coor, normals], axis=-1), transform_params

    @torch.no_grad()
    def infer(
        self,
        input_path,
        output_dir,
        input_pc_num=8192,
        apply_marching_cubes=False,
        octree_depth=7,
        sequence_type="spatial",
    ):
        """Run inference on a single mesh file.

        Args:
            input_path: Mesh file readable by ``trimesh.load``.
            output_dir: Directory for the outputs; created if missing.
            input_pc_num: Number of points sampled from the mesh surface.
            apply_marching_cubes: Sample via ``MeshProcessor`` with marching
                cubes instead of plain surface sampling.
            octree_depth: Forwarded to ``MeshProcessor`` when marching cubes
                is enabled.
            sequence_type: ``"hierarchical"`` keeps a root joint (taken from
                the first predicted bone) through post-processing; any other
                value skips root tracking.

        Returns:
            dict: ``skeleton_file``, ``rig_file`` and ``mesh_file`` paths
            (str), ``num_joints`` and ``num_bones`` counts, and ``time``
            (elapsed seconds).

        Raises:
            ValueError: If the loaded mesh has no faces to sample from.
        """
        start_time = time.perf_counter()
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        file_name = Path(input_path).stem

        # Load mesh
        mesh = trimesh.load(input_path, force="mesh")
        if len(mesh.faces) == 0:
            raise ValueError(f"Mesh {input_path!r} has no faces; cannot sample a point cloud")

        pc_normal = self._sample_point_cloud(
            mesh, input_pc_num, apply_marching_cubes, octree_depth
        )
        pc_normal_normalized, transform_params = self._normalize_point_cloud(pc_normal)

        # Feed the point cloud in the same dtype as the model weights (fp16
        # only when the model was actually converted to half precision).
        input_dtype = next(self.model.parameters()).dtype
        batch_data = {
            "pc_normal": torch.from_numpy(pc_normal_normalized)
            .to(device=self.device, dtype=input_dtype)
            .unsqueeze(0),
            "transform_params": transform_params.unsqueeze(0),
            "vertices": torch.from_numpy(mesh.vertices).unsqueeze(0),
            "faces": torch.from_numpy(mesh.faces).unsqueeze(0),
            "file_name": [file_name],
        }

        # Generate skeleton
        pred_bone_coords = self.model.generate(batch_data)

        # Process results
        skeleton = pred_bone_coords[0].cpu().numpy()
        pred_joints, pred_bones = pred_joints_and_bones(skeleton.squeeze())

        # Post-process: merge duplicate joints, tracking the root only for
        # hierarchical sequences that produced at least one bone.
        hier_order = sequence_type == "hierarchical"
        if hier_order and len(pred_bones) > 0:
            pred_root_index = pred_bones[0][0]
            pred_joints, pred_bones, pred_root_index = merge_duplicate_joints_and_fix_bones(
                pred_joints, pred_bones, root_index=pred_root_index
            )
        else:
            pred_joints, pred_bones = merge_duplicate_joints_and_fix_bones(
                pred_joints, pred_bones
            )
            pred_root_index = None

        # Undo both normalization stages (inverse order) for the rig file.
        trans = transform_params[:3].numpy()
        scale_val = transform_params[3].item()
        pc_trans = transform_params[4:7].numpy()
        pc_scale_val = transform_params[7].item()

        pred_joints_denorm = pred_joints * pc_scale_val + pc_trans
        pred_joints_denorm = pred_joints_denorm / scale_val + trans

        # Save files
        pred_rig_filename = output_dir / f"{file_name}_pred.txt"
        pred_skel_filename = output_dir / f"{file_name}_skel.obj"
        mesh_filename = output_dir / f"{file_name}_mesh.obj"

        save_skeleton_to_txt(
            pred_joints_denorm,
            pred_bones,
            pred_root_index,
            hier_order,
            mesh.vertices,
            str(pred_rig_filename),
        )

        # The skeleton OBJ stays in the normalized frame...
        save_skeleton_obj(
            pred_joints,
            pred_bones,
            str(pred_skel_filename),
            pred_root_index if hier_order else None,
            use_cone=hier_order,
        )

        # ...so the mesh is saved in that same normalized frame to line up with it.
        vertices_norm = (mesh.vertices - trans) * scale_val
        vertices_norm = (vertices_norm - pc_trans) / pc_scale_val
        save_mesh(vertices_norm, mesh.faces, str(mesh_filename))

        elapsed_time = time.perf_counter() - start_time

        return {
            "skeleton_file": str(pred_skel_filename),
            "rig_file": str(pred_rig_filename),
            "mesh_file": str(mesh_filename),
            "num_joints": len(pred_joints),
            "num_bones": len(pred_bones),
            "time": elapsed_time,
        }