# MagicArt / inference.py
# Provenance: Hugging Face upload by user ckc99u ("Upload 3 files", commit d038f33, verified)
import os
import torch
import trimesh
import numpy as np
from pathlib import Path
import time
# Import from MagicArticulate
from skeleton_models.skeletongen import SkeletonGPT
from data_utils.save_npz import normalize_to_unit_cube
from utils.mesh_to_pc import MeshProcessor
from utils.save_utils import (
pred_joints_and_bones,
save_skeleton_to_txt,
merge_duplicate_joints_and_fix_bones,
save_skeleton_obj,
save_mesh
)
class SkeletonInferencer:
    """Wrapper class for skeleton generation inference.

    Loads a pretrained SkeletonGPT checkpoint once at construction time and
    exposes :meth:`infer` to run skeleton prediction on a single mesh file,
    writing the predicted rig (.txt), skeleton (.obj) and normalized mesh
    (.obj) to an output directory.
    """

    def __init__(self, pretrained_weights, device="cuda", precision="fp16"):
        """Load the SkeletonGPT model and prepare it for inference.

        Args:
            pretrained_weights: Path to a checkpoint containing a ``"model"``
                state dict.
            device: Torch device string (e.g. ``"cuda"`` or ``"cpu"``).
            precision: ``"fp16"`` to run the model in half precision
                (only applied when ``device == "cuda"``); anything else
                keeps full precision.
        """
        self.device = device
        self.precision = precision
        # Single flag so model weights AND the input point cloud always use
        # the same dtype (the original code halved the input unconditionally,
        # which breaks fp32/CPU runs with a dtype mismatch).
        self.use_fp16 = precision == "fp16" and device == "cuda"

        # Minimal stand-in for the argparse namespace SkeletonGPT expects.
        class Args:
            def __init__(self):
                self.llm = "facebook/opt-350m"
                self.pad_id = -1
                self.n_discrete_size = 128
                self.n_max_bones = 100
                self.num_beams = 1
                self.seed = 0

        self.args = Args()

        print(f"Loading model from {pretrained_weights}...")
        self.model = SkeletonGPT(self.args).to(device)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (weights_only=True would be safer
        # if the checkpoint format permits it).
        pkg = torch.load(pretrained_weights, map_location=torch.device("cpu"))
        self.model.load_state_dict(pkg["model"])
        self.model.eval()
        if self.use_fp16:
            self.model = self.model.half()
        print("Model loaded successfully!")

    def _mesh_to_point_cloud(self, mesh, input_pc_num, apply_marching_cubes, octree_depth):
        """Sample an (N, 6) array of surface points + normals from *mesh*."""
        if apply_marching_cubes:
            # Remesh via marching cubes first, then sample (handles
            # non-watertight / messy input geometry).
            pc_list = MeshProcessor.convert_meshes_to_point_clouds(
                [mesh], input_pc_num,
                apply_marching_cubes=True,
                octree_depth=octree_depth
            )
            return pc_list[0]
        # Simple uniform surface sampling with per-face normals.
        points, face_indices = trimesh.sample.sample_surface(mesh, input_pc_num)
        normals = mesh.face_normals[face_indices]
        return np.concatenate([points, normals], axis=-1)

    @torch.no_grad()
    def infer(
        self,
        input_path,
        output_dir,
        input_pc_num=8192,
        apply_marching_cubes=False,
        octree_depth=7,
        sequence_type="spatial"
    ):
        """Run skeleton inference on a single mesh file.

        Args:
            input_path: Path to a mesh file loadable by trimesh.
            output_dir: Directory where outputs are written (created if
                missing).
            input_pc_num: Number of surface points to sample.
            apply_marching_cubes: Remesh with marching cubes before sampling.
            octree_depth: Octree depth for the marching-cubes path.
            sequence_type: ``"hierarchical"`` enables root-aware
                post-processing; any other value uses spatial ordering.

        Returns:
            dict: Output file paths plus ``num_joints``, ``num_bones`` and
            elapsed ``time`` in seconds.
        """
        start_time = time.time()
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Load mesh and convert it to a sampled point cloud with normals.
        mesh = trimesh.load(input_path, force='mesh')
        pc_normal = self._mesh_to_point_cloud(
            mesh, input_pc_num, apply_marching_cubes, octree_depth
        )

        # Normalize point-cloud coordinates into the unit cube.
        pc_coor = pc_normal[:, :3]
        normals = pc_normal[:, 3:]
        pc_coor, center, scale = normalize_to_unit_cube(pc_coor, scale_factor=0.9995)

        # Record both normalization transforms (mesh-space and pc-space) so
        # predictions can be mapped back to the original coordinates.
        bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
        pc_center = (bounds[0] + bounds[1]) / 2
        pc_scale = (bounds[1] - bounds[0]).max() + 1e-5  # epsilon guards degenerate extent
        transform_params = torch.tensor([
            center[0], center[1], center[2], scale,
            pc_center[0], pc_center[1], pc_center[2], pc_scale
        ], dtype=torch.float32)

        # Build the batch. BUGFIX: cast the point cloud to the SAME dtype as
        # the model — the original always used .half(), which mismatched
        # fp32/CPU model weights. np.asarray strips trimesh's TrackedArray
        # wrapper so torch.from_numpy gets a plain ndarray.
        pc_normal_normalized = np.concatenate([pc_coor, normals], axis=-1)
        input_dtype = torch.float16 if self.use_fp16 else torch.float32
        batch_data = {
            'pc_normal': torch.from_numpy(pc_normal_normalized).to(input_dtype).unsqueeze(0).to(self.device),
            'transform_params': transform_params.unsqueeze(0),
            'vertices': torch.from_numpy(np.asarray(mesh.vertices)).unsqueeze(0),
            'faces': torch.from_numpy(np.asarray(mesh.faces)).unsqueeze(0),
            'file_name': [Path(input_path).stem]
        }

        # Generate the skeleton token sequence and decode joints/bones.
        pred_bone_coords = self.model.generate(batch_data)
        file_name = Path(input_path).stem
        skeleton = pred_bone_coords[0].cpu().numpy()
        pred_joints, pred_bones = pred_joints_and_bones(skeleton.squeeze())

        # Post-process: in hierarchical mode the first bone's parent joint is
        # treated as the root and tracked through joint merging.
        hier_order = (sequence_type == "hierarchical")
        if hier_order and len(pred_bones) > 0:
            pred_root_index = pred_bones[0][0]
            pred_joints, pred_bones, pred_root_index = merge_duplicate_joints_and_fix_bones(
                pred_joints, pred_bones, root_index=pred_root_index
            )
        else:
            pred_joints, pred_bones = merge_duplicate_joints_and_fix_bones(
                pred_joints, pred_bones
            )
            pred_root_index = None

        # Invert both normalizations (pc-space first, then mesh-space) to
        # place joints back into original mesh coordinates.
        trans = transform_params[:3].numpy()
        scale_val = transform_params[3].item()
        pc_trans = transform_params[4:7].numpy()
        pc_scale_val = transform_params[7].item()
        pred_joints_denorm = pred_joints * pc_scale_val + pc_trans
        pred_joints_denorm = pred_joints_denorm / scale_val + trans

        # Write outputs: rig text file, skeleton obj, normalized mesh obj.
        pred_rig_filename = output_dir / f"{file_name}_pred.txt"
        pred_skel_filename = output_dir / f"{file_name}_skel.obj"
        mesh_filename = output_dir / f"{file_name}_mesh.obj"
        save_skeleton_to_txt(
            pred_joints_denorm, pred_bones, pred_root_index,
            hier_order, mesh.vertices, str(pred_rig_filename)
        )
        save_skeleton_obj(
            pred_joints, pred_bones, str(pred_skel_filename),
            pred_root_index if hier_order else None,
            use_cone=hier_order
        )
        # Save the mesh normalized with the same forward transform the
        # point cloud went through, so it aligns with the (normalized)
        # skeleton obj.
        vertices_norm = (mesh.vertices - trans) * scale_val
        vertices_norm = (vertices_norm - pc_trans) / pc_scale_val
        save_mesh(vertices_norm, mesh.faces, str(mesh_filename))

        elapsed_time = time.time() - start_time
        return {
            'skeleton_file': str(pred_skel_filename),
            'rig_file': str(pred_rig_filename),
            'mesh_file': str(mesh_filename),
            'num_joints': len(pred_joints),
            'num_bones': len(pred_bones),
            'time': elapsed_time
        }