Upload 33 files
- .gitattributes +40 -35
- README.md +13 -13
- app.py +326 -0
- assets/MagicArticulate_teaser.gif +3 -0
- assets/ar_demo.gif +3 -0
- assets/articulation-xl2.0.png +3 -0
- assets/data_statistics.png +0 -0
- assets/sequence_ordering_demo.gif +3 -0
- assets/skeleton_compare.png +0 -0
- data_utils/README.md +43 -0
- data_utils/clean_skin_in_npz.py +95 -0
- data_utils/convert_npz_to_mesh_rig.py +107 -0
- data_utils/data_loader.py +121 -0
- data_utils/examples/0a59c5ffa4a1476bac6d540b79947f31.obj +0 -0
- data_utils/examples/0a59c5ffa4a1476bac6d540b79947f31.txt +0 -0
- data_utils/examples/0a59c5ffa4a1476bac6d540b79947f31_render_results.png +3 -0
- data_utils/issue_data_list.txt +123 -0
- data_utils/pyrender_wrapper.py +135 -0
- data_utils/read_npz.py +43 -0
- data_utils/read_rig_mesh_from_glb.py +198 -0
- data_utils/render_data.py +61 -0
- data_utils/save_npz.py +256 -0
- data_utils/update_npz_rm_issue_data.py +59 -0
- demo.py +214 -0
- demo.sh +4 -0
- download.py +19 -0
- requirements.txt +37 -0
- skeleton_models/shape_opt.py +406 -0
- skeleton_models/skeletongen.py +198 -0
- utils/eval_utils.py +57 -0
- utils/mesh_to_pc.py +84 -0
- utils/save_utils.py +578 -0
- utils/skeleton_data_loader.py +97 -0
.gitattributes
CHANGED
@@ -1,35 +1,40 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/ar_demo.gif filter=lfs diff=lfs merge=lfs -text
+assets/articulation-xl2.0.png filter=lfs diff=lfs merge=lfs -text
+assets/MagicArticulate_teaser.gif filter=lfs diff=lfs merge=lfs -text
+assets/sequence_ordering_demo.gif filter=lfs diff=lfs merge=lfs -text
+data_utils/examples/0a59c5ffa4a1476bac6d540b79947f31_render_results.png filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,13 +1,13 @@
----
-title: MagicArt
-emoji: 🏆
-colorFrom: blue
-colorTo: pink
-sdk: gradio
-sdk_version: 6.0.1
-app_file: app.py
-pinned: false
-short_description: obj to rig test
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: MagicArt
+emoji: 🏆
+colorFrom: blue
+colorTo: pink
+sdk: gradio
+sdk_version: 6.0.1
+app_file: app.py
+pinned: false
+short_description: obj to rig test
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,326 @@
import os
import torch
import trimesh
import numpy as np
import gradio as gr
from pathlib import Path
import tempfile
import shutil

from skeleton_models.skeletongen import SkeletonGPT
from data_utils.save_npz import normalize_to_unit_cube
from utils.mesh_to_pc import MeshProcessor
from utils.save_utils import (
    pred_joints_and_bones,
    save_skeleton_to_txt,
    merge_duplicate_joints_and_fix_bones,
    save_skeleton_obj,
    save_mesh
)

# Global model variable
model = None
args_config = None

def initialize_model():
    """Initialize the model once at startup"""
    global model, args_config

    if model is not None:
        return

    print("Initializing MagicArticulate model...")

    # Create a simple args object with default parameters
    class Args:
        def __init__(self):
            self.input_pc_num = 8192
            self.num_beams = 1
            self.llm = "facebook/opt-350m"
            self.pad_id = -1
            self.n_discrete_size = 128
            self.n_max_bones = 100
            self.seed = 0
            self.precision = "fp16"
            self.pretrained_weights = "checkpoints/checkpoint_trainonv2_hier.pt"  # Default checkpoint
            self.hier_order = False

    args_config = Args()

    # Load model
    model = SkeletonGPT(args_config).cuda()

    # Load pretrained weights
    if os.path.exists(args_config.pretrained_weights):
        pkg = torch.load(args_config.pretrained_weights, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        model.load_state_dict(pkg["model"])
        model.eval()
        print("Model loaded successfully!")
    else:
        print(f"Warning: Pretrained weights not found at {args_config.pretrained_weights}")
        raise FileNotFoundError("Model checkpoint not found. Please ensure checkpoints are downloaded.")

def process_mesh(
    input_file,
    apply_marching_cubes,
    hier_order,
    octree_depth
):
    """
    Process the input mesh and generate rigging prediction

    Args:
        input_file: Uploaded mesh file (.obj, .ply, or .stl)
        apply_marching_cubes: Whether to apply marching cubes
        hier_order: Whether to use hierarchical ordering
        octree_depth: Depth for octree (if using marching cubes)

    Returns:
        Tuple of (skeleton obj file, rig txt file, normalized mesh file, status message)
    """
    try:
        # Initialize model if not already done
        if model is None:
            initialize_model()

        # Create temporary output directory
        output_dir = tempfile.mkdtemp()

        # Get file information
        file_name = Path(input_file).stem
        file_ext = Path(input_file).suffix.lower()

        # Check file type
        if file_ext not in ['.obj', '.ply', '.stl']:
            return None, None, None, f"Error: Unsupported file type {file_ext}. Please upload .obj, .ply, or .stl file."

        # Load mesh
        mesh = trimesh.load(input_file, force='mesh')

        # Convert mesh to point cloud
        print(f"Converting mesh to point cloud (apply_marching_cubes={apply_marching_cubes})...")
        pc_list = MeshProcessor.convert_meshes_to_point_clouds(
            [mesh],
            args_config.input_pc_num,
            apply_marching_cubes=apply_marching_cubes,
            octree_depth=octree_depth
        )
        pc_normal = pc_list[0]

        # Normalize point cloud
        pc_coor = pc_normal[:, :3]
        normals = pc_normal[:, 3:]
        pc_coor, center, scale = normalize_to_unit_cube(pc_coor, scale_factor=0.9995)

        pc_coor = pc_coor.astype(np.float32)
        normals = normals.astype(np.float32)

        # Calculate transform parameters
        bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
        pc_center = (bounds[0] + bounds[1])[None, :] / 2
        pc_scale = ((bounds[1] - bounds[0]).max() + 1e-5)

        transform_params = torch.tensor([
            center[0], center[1], center[2],
            scale,
            pc_center[0][0], pc_center[0][1], pc_center[0][2],
            pc_scale
        ], dtype=torch.float32)

        # Prepare batch data
        pc_normal_tensor = torch.from_numpy(
            np.concatenate([pc_coor, normals], axis=-1).astype(np.float16)
        ).unsqueeze(0).cuda()

        batch_data = {
            'pc_normal': pc_normal_tensor,
            'file_name': [file_name],
            'transform_params': transform_params.unsqueeze(0).cuda(),
            'vertices': torch.from_numpy(mesh.vertices).unsqueeze(0).cuda(),
            'faces': torch.from_numpy(mesh.faces).unsqueeze(0).cuda()
        }

        # Generate skeleton
        print("Generating skeleton...")
        with torch.no_grad():
            pred_bone_coords = model.generate(batch_data)

        # Process predictions
        skeleton = pred_bone_coords[0].cpu().numpy()
        pred_joints, pred_bones = pred_joints_and_bones(skeleton.squeeze())

        # Post-process: merge duplicate joints
        if hier_order:
            pred_root_index = pred_bones[0][0]
            pred_joints, pred_bones, pred_root_index = merge_duplicate_joints_and_fix_bones(
                pred_joints, pred_bones, root_index=pred_root_index
            )
        else:
            pred_joints, pred_bones = merge_duplicate_joints_and_fix_bones(pred_joints, pred_bones)
            pred_root_index = None

        # Denormalize joints for rig file
        transform_params_np = transform_params.cpu().numpy()
        trans = transform_params_np[:3]
        scale_val = transform_params_np[3]
        pc_trans = transform_params_np[4:7]
        pc_scale_val = transform_params_np[7]

        pred_joints_denorm = pred_joints * pc_scale_val + pc_trans
        pred_joints_denorm = pred_joints_denorm / scale_val + trans

        # Save outputs
        skel_obj_path = os.path.join(output_dir, f'{file_name}_skel.obj')
        rig_txt_path = os.path.join(output_dir, f'{file_name}_pred.txt')
        mesh_obj_path = os.path.join(output_dir, f'{file_name}_mesh.obj')

        # Save skeleton
        save_skeleton_obj(
            pred_joints,
            pred_bones,
            skel_obj_path,
            pred_root_index if hier_order else None,
            use_cone=hier_order
        )

        # Save rig
        vertices_np = mesh.vertices
        save_skeleton_to_txt(
            pred_joints_denorm,
            pred_bones,
            pred_root_index,
            hier_order,
            vertices_np,
            rig_txt_path
        )

        # Save normalized mesh
        vertices_norm = (vertices_np - trans) * scale_val
        vertices_norm = (vertices_norm - pc_trans) / pc_scale_val
        save_mesh(vertices_norm, mesh.faces, mesh_obj_path)

        status_msg = f"✅ Success! Generated skeleton with {len(pred_joints)} joints and {len(pred_bones)} bones."

        return skel_obj_path, rig_txt_path, mesh_obj_path, status_msg

    except Exception as e:
        import traceback
        error_msg = f"❌ Error processing mesh: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, None, None, error_msg

# Create Gradio interface
def create_interface():
    """Create the Gradio interface"""

    with gr.Blocks(title="MagicArticulate - 3D Model Rigging") as demo:
        gr.Markdown("""
        # 🪄 MagicArticulate: Make Your 3D Models Articulation-Ready

        Upload a 3D mesh (.obj, .ply, or .stl) to automatically generate skeletal rigging.

        **Paper**: [CVPR 2025] MagicArticulate ([Project Page](https://chaoyuesong.github.io/MagicArticulate/))
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Input")
                input_file = gr.File(
                    label="Upload 3D Mesh",
                    file_types=[".obj", ".ply", ".stl"],
                    type="filepath"
                )

                gr.Markdown("### Options")
                apply_marching_cubes = gr.Checkbox(
                    label="Apply Marching Cubes",
                    value=False,
                    info="Apply marching cubes for mesh processing (slower but more accurate)"
                )

                hier_order = gr.Checkbox(
                    label="Hierarchical Ordering",
                    value=False,
                    info="Use hierarchical sequence ordering for skeleton generation"
                )

                octree_depth = gr.Slider(
                    minimum=5,
                    maximum=9,
                    value=7,
                    step=1,
                    label="Octree Depth",
                    info="Depth for octree (only used if Marching Cubes is enabled)"
                )

                generate_btn = gr.Button("🚀 Generate Rigging", variant="primary", size="lg")

            with gr.Column(scale=1):
                gr.Markdown("### Output")
                status_text = gr.Textbox(
                    label="Status",
                    lines=3,
                    interactive=False
                )

                skel_output = gr.File(
                    label="📥 Skeleton (.obj)",
                    interactive=False
                )

                rig_output = gr.File(
                    label="📥 Rig Prediction (.txt)",
                    interactive=False
                )

                mesh_output = gr.File(
                    label="📥 Normalized Mesh (.obj)",
                    interactive=False
                )

        gr.Markdown("""
        ### About
        MagicArticulate automatically generates skeletal structures for 3D models, making them ready for animation.
        The system predicts joint positions and bone connections using a transformer-based approach.

        **Outputs**:
        - **Skeleton (.obj)**: 3D visualization of the generated skeleton
        - **Rig Prediction (.txt)**: Detailed rigging information (joints, bones, hierarchy)
        - **Normalized Mesh (.obj)**: The input mesh normalized to unit cube

        **Citation**:
        ```
        @inproceedings{song2025magicarticulate,
          title={MagicArticulate: Make Your 3D Models Articulation-Ready},
          author={Song, Chaoyue and others},
          booktitle={CVPR},
          year={2025}
        }
        ```
        """)

        # Connect the button to the processing function
        generate_btn.click(
            fn=process_mesh,
            inputs=[input_file, apply_marching_cubes, hier_order, octree_depth],
            outputs=[skel_output, rig_output, mesh_output, status_text]
        )

    return demo

if __name__ == "__main__":
    # Initialize model at startup
    try:
        initialize_model()
    except Exception as e:
        print(f"Warning: Could not initialize model at startup: {e}")
        print("Model will be initialized on first request.")

    # Launch Gradio app
    demo = create_interface()
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
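The app above is plain Python, so the same pipeline can be driven without the Gradio UI. A minimal sketch, assuming `app.py` is importable, the checkpoint is in place, and a CUDA device is available; the mesh path points at the example shipped under `data_utils/examples/`:

```python
# Minimal sketch: call the Space's pipeline directly (no UI).
from app import initialize_model, process_mesh

initialize_model()  # loads SkeletonGPT and the pretrained weights once
skel_path, rig_path, mesh_path, status = process_mesh(
    "data_utils/examples/0a59c5ffa4a1476bac6d540b79947f31.obj",
    apply_marching_cubes=False,  # keep the original surface
    hier_order=False,            # spatial ordering, as in the default checkpoint
    octree_depth=7,              # only used when marching cubes is enabled
)
print(status)
```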
assets/MagicArticulate_teaser.gif
ADDED
Git LFS Details

assets/ar_demo.gif
ADDED
Git LFS Details

assets/articulation-xl2.0.png
ADDED
Git LFS Details

assets/data_statistics.png
ADDED

assets/sequence_ordering_demo.gif
ADDED
Git LFS Details

assets/skeleton_compare.png
ADDED
data_utils/README.md
ADDED
@@ -0,0 +1,43 @@
## Preprocessed data
We provide the preprocessed data saved in NPZ files, which contain the following information:
```
'vertices', 'faces', 'normals', 'joints', 'bones', 'root_index', 'uuid', 'pc_w_norm', 'joint_names', 'skinning_weights_value', 'skinning_weights_row', 'skinning_weights_col', 'skinning_weights_shape'
```
You can check `read_npz.py` for how to read the NPZ files and `save_npz.py` for how we save them.

Before saving them into NPZ files, we extract the mesh (.obj) and rig (.txt) from 3D models downloaded from Objaverse-XL using Blender. The rig file follows the format used in [RigNet](https://github.com/zhan-xu/RigNet), which includes the following entries:
```
joints [joint_name] [x] [y] [z]
root [root_joint_name]
skin [vertex_index] [joint_name1] [skinning_weight1] [joint_name2] [skinning_weight2] ...
hier [parent_joint_name] [child_joint_name]
```
For an example, please see `examples/0a59c5ffa4a1476bac6d540b79947f31.txt`.

To convert an NPZ file back to OBJ and TXT files, we give an example; run:
```
python convert_npz_to_mesh_rig.py
```

## Visualization
We provide a method for visualizing 3D models with their skeletons using [Pyrender](https://github.com/mmatl/pyrender), modified from [Lab4D](https://github.com/lab4d-org/lab4d/tree/ppr/). This visualization also serves as input to the VLM for skeleton quality rating. Make sure you have installed the following packages before running the visualization:
```
pip install trimesh opencv-python pyrender
```

We provide an example to demonstrate the process. For this example, we prepare an OBJ file along with a TXT file containing the rigging information. Then run:
```
python render_data.py
```
You will obtain the following outputs:

<p align="center">
<img width="80%" src="examples/0a59c5ffa4a1476bac6d540b79947f31_render_results.png"/>
</p>

### Reading rig and mesh from GLBs
We provide the script we use for extracting the rig (.txt) and mesh (.obj) from GLB files. You can run:
```
python read_rig_mesh_from_glb.py
```
Remember to download Blender (we use 4.2.0) and install bpy in your conda environment.
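The rig TXT format above is line-oriented, so a parser is only a few lines of Python. The helper below is a hypothetical sketch (not a script in this repo) that collects the four record types:

```python
# Hypothetical parser for the RigNet-style rig TXT described above.
def parse_rig_txt(path):
    joints, hier, skin, root = {}, [], {}, None
    with open(path) as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            if parts[0] == "joints":
                joints[parts[1]] = tuple(map(float, parts[2:5]))
            elif parts[0] == "root":
                root = parts[1]
            elif parts[0] == "skin":
                vidx = int(parts[1])
                pairs = parts[2:]
                skin[vidx] = {pairs[i]: float(pairs[i + 1]) for i in range(0, len(pairs), 2)}
            elif parts[0] == "hier":
                hier.append((parts[1], parts[2]))
    return joints, root, skin, hier
```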
data_utils/clean_skin_in_npz.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import scipy.sparse as sp
import os

def check_and_clean_skinning_weights(file_path, output_path, tolerance=0.1):
    """
    Check if all rows in pc_skinning_weights sum to 1 for each item in the NPZ file.
    Remove invalid items and save a cleaned version.

    Args:
        file_path: Path to the input NPZ file
        output_path: Path for the cleaned NPZ file
        tolerance: Tolerance for floating point comparison

    Returns:
        tuple: (cleaned_data_list, removed_indices)
    """
    data_list = np.load(file_path, allow_pickle=True)['arr_0']

    invalid_indices = []
    valid_data_list = []

    for idx, data in enumerate(data_list):
        is_valid = True

        weights_data = data['skinning_weights_value']
        weights_row = data['skinning_weights_row']
        weights_col = data['skinning_weights_col']
        weights_shape = data['skinning_weights_shape']

        skinning_sparse = sp.coo_matrix(
            (weights_data, (weights_row, weights_col)),
            shape=weights_shape
        )

        skinning_csr = skinning_sparse.tocsr()
        row_sums = np.array(skinning_csr.sum(axis=1)).flatten()

        invalid_rows = np.where(np.abs(row_sums - 1.0) > tolerance)[0]

        if len(invalid_rows) > 0:
            min_sum = np.min(row_sums)
            max_sum = np.max(row_sums)
            invalid_indices.append((data['uuid'], f"{len(invalid_rows)} rows, range: [{min_sum:.6f}, {max_sum:.6f}]"))
            is_valid = False

        if is_valid:
            valid_data_list.append(data)

    # Save the cleaned data (np.savez_compressed pickles object arrays on its own;
    # it takes no allow_pickle keyword, so none is passed here)
    if valid_data_list:
        np.savez_compressed(output_path, valid_data_list)
        print(f"Saved {len(valid_data_list)} valid items to {output_path}")

    return valid_data_list, invalid_indices

def main():
    # File paths
    file_path = "articulation_xlv2_train.npz"  # "articulation_xlv2_test.npz"
    log_file = "invalid_skinning_weights_intrain.txt"  # "invalid_skinning_weights_intest.txt"
    output_path = "articulation_xlv2_train_updated.npz"  # "articulation_xlv2_test_updated.npz"

    # Clean the data
    valid_data, invalid_indices = check_and_clean_skinning_weights(file_path, output_path)

    # Log the results
    with open(log_file, "w") as f:
        f.write(f"Original file: {file_path}\n")
        f.write(f"Cleaned file: {output_path}\n")
        f.write(f"Total items: {len(np.load(file_path, allow_pickle=True)['arr_0'])}\n")
        f.write(f"Valid items: {len(valid_data)}\n")
        f.write(f"Removed items: {len(invalid_indices)}\n\n")

        if invalid_indices:
            f.write("Details of removed items:\n")
            for idx, details in invalid_indices:
                f.write(f"  Index {idx}: {details}\n")

    print(f"Cleaning complete. Results written to {log_file}")

if __name__ == "__main__":
    main()
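The script above discards any model whose skinning rows fail the sum-to-one check. An alternative for rows that are only slightly off is to renormalize them instead of dropping the whole item; a sketch of that option (not what the script does):

```python
import numpy as np
import scipy.sparse as sp

def renormalize_rows(csr: sp.csr_matrix) -> sp.csr_matrix:
    """Rescale each row of a sparse skinning matrix to sum to 1."""
    sums = np.asarray(csr.sum(axis=1)).flatten()
    sums[sums == 0] = 1.0  # leave all-zero rows unchanged
    return (sp.diags(1.0 / sums) @ csr).tocsr()
```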
data_utils/convert_npz_to_mesh_rig.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
You can convert an npz file back to obj (mesh) and txt (rig) files using this python script.
"""
import os
import numpy as np
import scipy.sparse as sp

def export_obj(vertices, faces, normals, output_path):
    with open(output_path, 'w') as f:
        for v in vertices:
            f.write(f"v {v[0]} {v[1]} {v[2]}\n")
        for n in normals:
            f.write(f"vn {n[0]} {n[1]} {n[2]}\n")
        for i, face in enumerate(faces):
            # OBJ format is 1-based, so we add 1 to all indices
            f.write(f"f {face[0]+1}//{face[0]+1} {face[1]+1}//{face[1]+1} {face[2]+1}//{face[2]+1}\n")

def export_rig_txt(joints, bones, root_index, joint_names, skinning_weights, output_path):
    """
    joints [joint_name] [x] [y] [z]
    root [root_joint_name]
    skin [vertex_index] [joint_name1] [weight1] [joint_name2] [weight2] ...
    hier [parent_joint_name] [child_joint_name]
    """
    n_joints = len(joints)
    n_verts = skinning_weights.shape[0]  # (n_vertex, n_joints)

    with open(output_path, 'w') as f:
        # 1) joints
        for i in range(n_joints):
            x, y, z = joints[i]
            jn = joint_names[i]
            f.write(f"joints {jn} {x} {y} {z}\n")

        # 2) root
        root_name = joint_names[root_index]
        f.write(f"root {root_name}\n")

        # 3) skin
        for vidx in range(n_verts):
            row_weights = skinning_weights[vidx]
            non_zero_indices = np.where(row_weights != 0)[0]
            if len(non_zero_indices) == 0:
                continue

            line_parts = [f"skin {vidx}"]  # vertex_idx
            for jidx in non_zero_indices:
                w = row_weights[jidx]
                jn = joint_names[jidx]
                line_parts.append(jn)
                line_parts.append(str(w))

            f.write(" ".join(line_parts) + "\n")

        # 4) hier
        for p_idx, c_idx in bones:
            p_name = joint_names[p_idx]
            c_name = joint_names[c_idx]
            f.write(f"hier {p_name} {c_name}\n")

if __name__ == "__main__":

    data = np.load('articulation_xlv2_test.npz', allow_pickle=True)
    data_list = data['arr_0']

    print(f"Loaded {len(data_list)} data entries")

    model_data = data_list[0]
    print("Data keys:", model_data.keys())
    # 'vertices', 'faces', 'normals', 'joints', 'bones', 'root_index', 'uuid', 'joint_names',
    # 'skinning_weights_value', 'skinning_weights_row', 'skinning_weights_col', 'skinning_weights_shape'

    vertices = model_data['vertices']        # (n_vertex, 3)
    faces = model_data['faces']              # (n_faces, 3)
    normals = model_data['normals']          # (n_vertex, 3)
    joints = model_data['joints']            # (n_joints, 3)
    bones = model_data['bones']              # (n_bones, 2)
    root_index = model_data['root_index']    # int
    joint_names = model_data['joint_names']  # list of str
    uuid_str = model_data['uuid']

    skin_val = model_data['skinning_weights_value']
    skin_row = model_data['skinning_weights_row']
    skin_col = model_data['skinning_weights_col']
    skin_shape = model_data['skinning_weights_shape']
    skin_sparse = sp.coo_matrix((skin_val, (skin_row, skin_col)), shape=skin_shape)
    skinning_weights = skin_sparse.toarray()  # (n_vertex, n_joints)

    obj_path = f"{uuid_str}.obj"
    export_obj(vertices, faces, normals, obj_path)
    rig_txt_path = f"{uuid_str}.txt"
    export_rig_txt(joints, bones, root_index, joint_names, skinning_weights, rig_txt_path)

    print("Done!")
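A quick way to sanity-check the round trip is to reload the exported OBJ and compare element counts; a sketch assuming trimesh is installed, with `uuid_str`, `vertices`, and `faces` as in the script above:

```python
import trimesh

# Reload the exported OBJ and confirm the geometry survived the export.
mesh = trimesh.load(f"{uuid_str}.obj", process=False)
assert len(mesh.vertices) == len(vertices), "vertex count changed in export"
assert len(mesh.faces) == len(faces), "face count changed in export"
```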
data_utils/data_loader.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import glob
import numpy as np
import trimesh

class DataLoader:
    def __init__(self):
        self.joint_name_to_idx = {}
        self.root_name = None  # set by load_rig_data when a 'root' line is present

    def load_rig_data(self, rig_path):
        joints = []
        joints_names = []
        bones = []

        with open(rig_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if not parts:  # skip blank lines
                    continue
                if parts[0] == 'joints':
                    joint_name = parts[1]
                    joint_pos = [float(parts[2]), float(parts[3]), float(parts[4])]
                    self.joint_name_to_idx[joint_name] = len(joints)
                    joints.append(joint_pos)
                    joints_names.append(joint_name)
                elif parts[0] == 'root':
                    self.root_name = parts[1]
                elif parts[0] == 'hier':
                    parent_joint = self.joint_name_to_idx[parts[1]]
                    child_joint = self.joint_name_to_idx[parts[2]]
                    bones.append([parent_joint, child_joint])

        self.joints = np.array(joints)
        self.bones = np.array(bones)
        self.joints_names = joints_names
        self.root_idx = None
        if self.root_name is not None:
            self.root_idx = self.joint_name_to_idx[self.root_name]

    def load_mesh(self, mesh_path):
        mesh = trimesh.load(mesh_path, process=False)
        mesh.visual.vertex_colors[:, 3] = 100  # set transparency
        self.mesh = mesh

        # Compute the centroid normal of the mesh
        v = self.mesh.vertices
        xmin, ymin, zmin = v.min(axis=0)
        xmax, ymax, zmax = v.max(axis=0)
        self.bbox_center = np.array([(xmax + xmin)/2, (ymax + ymin)/2, (zmax + zmin)/2])
        self.bbox_size = np.array([xmax - xmin, ymax - ymin, zmax - zmin])
        self.bbox_scale = max(xmax - xmin, ymax - ymin, zmax - zmin)

        normal = mesh.center_mass - self.bbox_center
        normal = normal / (np.linalg.norm(normal)+1e-5)

        # Choose axis order based on normal direction
        if abs(normal[1]) > abs(normal[2]):  # if Y component is dominant
            self.axis_order = [0, 1, 2]  # keep default order
        else:
            self.axis_order = [0, 2, 1]  # swap Y and Z

        self.mesh.vertices = self.mesh.vertices[:, self.axis_order]
        self.joints = self.joints[:, self.axis_order]
        self.normalize_coordinates()

    def normalize_coordinates(self):

        # Compute scale and offset
        scale = 1.0 / (self.bbox_scale+1e-5)
        offset = -self.bbox_center

        self.mesh.vertices = (self.mesh.vertices + offset) * scale
        self.joints = (self.joints + offset) * scale

        # Fixed radii that work at the normalized (unit) scale
        self.joint_radius = 0.01
        self.bone_radius = 0.005

    def query_mesh_rig(self):

        input_dict = {"shape": self.mesh}

        # Create joints as spheres
        joint_meshes = []
        for i, joint in enumerate(self.joints):

            sphere = trimesh.creation.icosphere(
                radius=self.joint_radius, subdivisions=2
            )
            sphere.apply_translation(joint)
            if i == self.root_idx:
                sphere.visual.vertex_colors = [0, 255, 0, 255]  # root joint in green
            else:
                sphere.visual.vertex_colors = [0, 0, 255, 255]  # other joints in blue

            joint_meshes.append(sphere)
        input_dict["joint_meshes"] = trimesh.util.concatenate(joint_meshes)

        # Create bones as cylinders
        bone_meshes = []
        for bone in self.bones:
            start, end = self.joints[bone[0]], self.joints[bone[1]]
            cylinder = trimesh.creation.cylinder(radius=self.bone_radius, segment=np.array([[0, 0, 0], end - start]))
            cylinder.apply_translation(start)
            cylinder.visual.vertex_colors = [255, 0, 0, 255]  # bones in red
            bone_meshes.append(cylinder)
        input_dict["bone_meshes"] = trimesh.util.concatenate(bone_meshes)

        return input_dict
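Typical usage pairs `DataLoader` with the renderer in `pyrender_wrapper.py` (as `render_data.py` does); a minimal sketch against the bundled example assets, assuming the files exist at these paths:

```python
# Sketch: build the render dict for the bundled example.
loader = DataLoader()
loader.load_rig_data("examples/0a59c5ffa4a1476bac6d540b79947f31.txt")
loader.load_mesh("examples/0a59c5ffa4a1476bac6d540b79947f31.obj")
scene = loader.query_mesh_rig()  # {"shape", "joint_meshes", "bone_meshes"}
```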
data_utils/examples/0a59c5ffa4a1476bac6d540b79947f31.obj
ADDED
The diff for this file is too large to render. See raw diff.

data_utils/examples/0a59c5ffa4a1476bac6d540b79947f31.txt
ADDED
The diff for this file is too large to render. See raw diff.

data_utils/examples/0a59c5ffa4a1476bac6d540b79947f31_render_results.png
ADDED
Git LFS Details
data_utils/issue_data_list.txt
ADDED
@@ -0,0 +1,123 @@
0b1f1ccb-db41-5689-b363-fd8ca0145041
d4705a2d-2dbf-5175-9fd0-b0cc538b9c4d
12b3d88d-2845-57b7-b483-d3a766beeb0e
778505b7-63da-5c08-bad7-6935fcd73cec
35ed271f-e9d7-528f-b165-e25004ef802b
0096279cc46c4d1d8e8611e611e2418b
00ea25ccad8344cbaedc89d70bb75a49
08b617be44b6466584ba9624f857222c
0998722861ba489695ad8bd4456e76e6
0bd786e936774176ac474694b0f6f876
0c1a7657bea0421dadef56e2080f0297
1073c44309524810b6cd4cef2d6e8008
10b9c6e9bf214dc39476161dfe2eaa8a
147df2ee69df488eb6cb2f88f2f703bb
18ff6fa66b0d483a8758e4602e5b70b0
1cf88736c59a43c88ba7dac44c929dab
1e9544eea98d417db87347dcc16cb69e
21a4bc038cbd415b8e09566148c87c46
2809e172066d4140b1ddc9356490191a
28483d55555f433d8fde4ba141ad5271
31829af6c72146519d348a6d4d2bcc8b
32202338cd5c40beace31deeacd598e5
37fe21828c37413986a07a1bf8c75c93
3857965c400c47c9a846c01eb1f36ed5
404e622bdfd14ab693640ff86c131973
44f8486a0b2c4f9489fc3912b2dcf880
49580a36b07d47808aa91db6e2b9fcdd
4db51555e8fd48a0905ecee93730f863
57a9d6f9fec7430bae67d7d7a9bfdd2c
593eeb44d67c49499d3580d908b9f5cd
5a571bea2d0c4ad5b2cc912c3dc37a59
5cd1f275bdb34d939ffaa07a641a2eef
60ab9787fde64199ab59b728276b5cd8
63453d744e3844d48bc9a7bedfe586a7
6caf784e33084b1389fdea4043560d3f
725ce5eae96b4602a3b8a30f73dcbc4c
7f9c3d9ccbd949449f25f3711780c1e7
80ff2e88de2144bbb21d231db5a02000
835174fcce4a4969851ca1846b92036a
85b73c92393e453faf0f7ec82d40720e
860911c447744c0396b618db994c535e
86d6d90704ff4e9c8fc0f0751bd837a2
934b27da5e4249978bfa9c190ec01f9a
968aecc8c38246f8af3d0d7fa169ca8f
9fc1cb45c8404517aa8cee3bb47c14fd
a65a935fd54b4159a2687bffef7cbf81
af2f7b1678ea4194a9b8235e7dfd23b3
b4cd213509ec4dcba41a280b4b013e63
be7a64227e1f4f13b86389edc4926dfa
bff3cd47d0574f73980b3af9f7790c58
c8ac24a9bf2647fb9e7565eaf3a28558
cc1f905b148c4378ad46a40da72e839f
ce50fe2e6a654a3bafab950c0f101e59
d270505df059467e8fa17974f075f3cf
d476d6bfc0364001a6cc73877a59ca65
d9a5b67b5c9142e984f76b1afec1939b
da9cb8ac53274b9bbd9467b7d83c85fb
dc48f3ab2b2844eba788898509a52806
e1817fcc5d614723bcb1f49491fe3ed0
f1fbc33234374c3a911148a453399186
faab16de19484746a4716cb00b738f8e
fdb767e69a0748c6bcdfe8764772c0d4
ff8ec56b0c664b438d36e84882b304f4
03ea3bf9d47e4e5789d027279e6edbbb
064a05ca3df84e3fbf900f9a1df75577
0ada42e959504b47ba58ca331a8d8549
112ae8160af54eeea6b2483b903634f4
156d6ab3d495476c997887c092aff781
1c92543b1e9245e0a2c1e3770a0e3d11
1e041df547e64db9aaa8d79218d880a8
1e34fd79cbb24db4952db6e9642881d3
1ec08e1e74d04354ac7085c004b01c2c
20dd7f7bdc9a4c36aef491f12afa14d8
242e99d9fe2f4eec91841fd3e8b01021
27dbf22159a5464687f4ed9b347257d3
28647ae054d74d2e9cac4a3dda31bb55
29ff70f5772747f89b0db4aae9c0ade6
2b03620bba824c1ea67945abd5c043f2
314d74658df6431ea50bede8512882cc
38f052a2027346e2943b4c76d2572415
3dbaadb244e44f59b5a6b7490aac6883
400dbd97e4e6429cab24fab8b5a3d845
41790f8edba642ffa281a0660f318db4
4c60ff4ebef241deae699ec8d2de86b5
5de63c02a4374605acb69691450e6653
65df530434624400b030da4579baa4b6
66c66c960e1c4b3aab5f2792f5e71add
6abf66991f584f1ba45d7297f3a128d4
6dd6b05e20604f478d9fd868528b275f
6f76008a68074d2bb59a0189f558ae34
8bb433dfbef3479cbaa3bcdf63b5b6a2
9338c7dbf4054c608c17353358cdb7c6
9544bb7b09874f13a5ecd0429379cbd8
95d2df27650f4beb8d208a21db7366d9
96d50c0f7f6a40ad9e5ae39537d1062e
9e7e71c08e5b4ff9b510afbfb2067152
a6cce2749dfb4b4d89c0dc3460ea9d3b
ab7e81a8a26d43ecb3131729a999ddcd
adae06ba4b7a4cbeab892957bc40331b
ba46772fa0234625832da0582c2f615c
c4f57ce4bc2b4c46a32414515ba991e9
cf09886dc98f4666bed77d6b51a4ef67
cfde2bfa5c634a788c2c4c4480f53ba7
d0008363ca6c4ea9976494eff45e90bb
d403eef8a45d485e905b968cc0a1670a
dc8d45c7ae7f453e9f861c79a40d9265
eb8e71b3a22f4e719d8157831c408a6e
ed896088728f4779b2fd9aa7f527e880
f06a196aea294b0fa05dee4be971a12c
f3e1bd29da234c8e89e0f208487fe31c
f84ffc38cbb9400ca31be98fe89abb01
fa31faff8ec04fa49e72e6266dc14cc4
fb6bd558e5ff4d3b8709a39d6280460b
808f9ffa-c14a-5d78-b8bf-197bc1f0b29c
e1740d44-9be4-58cf-a3e6-f8208b9cdfc6
4acf0253-00b8-5cca-be94-1f2af5bd72ba
0c94fe68-2983-52db-822e-6ea63bd54f65
ff9b4de9-a702-5221-bc26-f0c7ec8c4c51
b927ce627b6841a688067331853302d6
ccfad91e-e66d-5cc3-aff8-99f5b3a824fd
25434b7c-4ab4-58cd-900f-aa1bfcf53233
23d9764b-5035-5025-aae1-2788c1942a7c
ecbc08ea-5f9d-5d2f-a496-77ec128bd3fe
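This list is presumably what `update_npz_rm_issue_data.py` consumes; in any case, filtering amounts to a set lookup on each entry's `uuid`. A sketch, assuming `data_list` was loaded as in `read_npz.py`:

```python
# Sketch: drop the problematic UUIDs listed above from a loaded data list.
with open("data_utils/issue_data_list.txt") as f:
    issue_ids = {line.strip() for line in f if line.strip()}

clean_list = [d for d in data_list if d["uuid"] not in issue_ids]
```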
data_utils/pyrender_wrapper.py
ADDED
@@ -0,0 +1,135 @@
# Modified from https://github.com/lab4d-org/lab4d

import os
import numpy as np
import cv2
import pyrender
import trimesh
from pyrender import (
    IntrinsicsCamera,
    Mesh,
    Node,
    Scene,
    OffscreenRenderer,
    MetallicRoughnessMaterial,
    RenderFlags
)

os.environ["PYOPENGL_PLATFORM"] = "egl"

def look_at(eye, center, up):
    """Create a look-at (view) matrix."""
    f = np.array(center, dtype=np.float32) - np.array(eye, dtype=np.float32)
    f /= np.linalg.norm(f)

    u = np.array(up, dtype=np.float32)
    u /= np.linalg.norm(u)

    s = np.cross(f, u)
    u = np.cross(s, f)

    m = np.identity(4, dtype=np.float32)
    m[0, :3] = s
    m[1, :3] = u
    m[2, :3] = -f
    m[:3, 3] = -np.matmul(m[:3, :3], np.array(eye, dtype=np.float32))

    return m

class PyRenderWrapper:
    def __init__(self, image_size=(1024, 1024)) -> None:
        # renderer
        self.image_size = image_size
        render_size = max(image_size)
        self.r = OffscreenRenderer(render_size, render_size)
        self.intrinsics = IntrinsicsCamera(
            render_size, render_size, render_size / 2, render_size / 2
        )
        # light
        self.light_pose = np.eye(4)
        self.set_light_topdown()
        self.direc_l = pyrender.DirectionalLight(color=np.ones(3), intensity=5.0)
        self.material = MetallicRoughnessMaterial(
            roughnessFactor=0.75, metallicFactor=0.75, alphaMode="BLEND"
        )
        self.init_camera()

    def init_camera(self):
        self.flip_pose = np.eye(4)
        self.set_camera(np.eye(4))

    def set_camera(self, scene_to_cam):
        # object to camera transforms
        self.scene_to_cam = self.flip_pose @ scene_to_cam

    def set_light_topdown(self, gl=False):
        # top down light, slightly closer to the camera
        if gl:
            rot = cv2.Rodrigues(np.asarray([-np.pi / 2, 0, 0]))[0]
        else:
            rot = cv2.Rodrigues(np.asarray([np.pi / 2, 0, 0]))[0]
        self.light_pose[:3, :3] = rot

    def align_light_to_camera(self):
        self.light_pose = np.linalg.inv(self.scene_to_cam)

    def set_intrinsics(self, intrinsics):
        """
        Args:
            intrinsics: (4,) fx,fy,px,py
        """
        self.intrinsics = IntrinsicsCamera(
            intrinsics[0], intrinsics[1], intrinsics[2], intrinsics[3]
        )

    def get_cam_to_scene(self):
        cam_to_scene = np.eye(4)
        cam_to_scene[:3, :3] = self.scene_to_cam[:3, :3].T
        cam_to_scene[:3, 3] = -self.scene_to_cam[:3, :3].T @ self.scene_to_cam[:3, 3]
        return cam_to_scene

    def set_camera_view(self, angle, bbox_center, distance=2.0):
        # Calculate camera position based on angle and distance from bounding box center
        camera_position = bbox_center + distance * np.array([np.sin(angle), 0, np.cos(angle)], dtype=np.float32)
        look_at_matrix = look_at(camera_position, bbox_center, [0, 1, 0])
        self.scene_to_cam = look_at_matrix @ self.flip_pose

    def render(self, input_dict):
        # Create separate scenes for transparent objects (mesh) and solid objects (joints and bones)
        scene_transparent = Scene(ambient_light=np.array([1.0, 1.0, 1.0, 1.0]) * 0.1)
        scene_solid = Scene(ambient_light=np.array([1.0, 1.0, 1.0, 1.0]) * 0.1)

        mesh_pyrender = Mesh.from_trimesh(input_dict["shape"], smooth=False)
        mesh_pyrender.primitives[0].material = self.material
        scene_transparent.add(mesh_pyrender, pose=np.eye(4), name="shape")

        if "joint_meshes" in input_dict:
            joints_pyrender = Mesh.from_trimesh(input_dict["joint_meshes"], smooth=False)
            joints_pyrender.primitives[0].material = self.material
            scene_solid.add(joints_pyrender, pose=np.eye(4), name="joints")

        if "bone_meshes" in input_dict:
            bones_pyrender = Mesh.from_trimesh(input_dict["bone_meshes"], smooth=False)
            bones_pyrender.primitives[0].material = self.material
            scene_solid.add(bones_pyrender, pose=np.eye(4), name="bones")

        # Camera for both scenes
        scene_transparent.add(self.intrinsics, pose=self.get_cam_to_scene())
        scene_solid.add(self.intrinsics, pose=self.get_cam_to_scene())

        # Light for both scenes
        scene_transparent.add(self.direc_l, pose=self.light_pose)
        scene_solid.add(self.direc_l, pose=self.light_pose)

        # Render transparent scene first
        color_transparent, depth_transparent = self.r.render(scene_transparent)

        # Render solid scene on top
        color_solid, depth_solid = self.r.render(scene_solid)

        # Combine the two scenes
        color_combined = np.where(depth_solid[..., np.newaxis] == 0, color_transparent, color_solid)

        return color_combined, depth_solid

    def delete(self):
        self.r.delete()
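A short usage sketch for the wrapper, assuming `scene` is the dict returned by `DataLoader.query_mesh_rig()`; it renders four turntable views and writes them with OpenCV (channels flipped to BGR for `imwrite`):

```python
import numpy as np
import cv2

renderer = PyRenderWrapper(image_size=(1024, 1024))
for i, angle in enumerate(np.linspace(0, 2 * np.pi, 4, endpoint=False)):
    renderer.set_camera_view(angle, bbox_center=np.zeros(3), distance=2.0)
    renderer.align_light_to_camera()  # keep the model lit from the view direction
    color, depth = renderer.render(scene)
    cv2.imwrite(f"view_{i:02d}.png", color[..., ::-1])
renderer.delete()
```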
data_utils/read_npz.py
ADDED
@@ -0,0 +1,43 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import scipy.sparse as sp

# Load the NPZ file
data = np.load('articulation_xlv2_test.npz', allow_pickle=True)
data_list = data['arr_0']

print(f"Loaded {len(data_list)} data entries")
print(f"Data keys: {data_list[0].keys()}")
# 'vertices', 'faces', 'normals', 'joints', 'bones', 'root_index', 'uuid', 'pc_w_norm', 'joint_names', 'skinning_weights_value',
# 'skinning_weights_row', 'skinning_weights_col', 'skinning_weights_shape'

data = data_list[0]  # check the first data entry

vertices = data['vertices']    # (n_vertex, 3)
faces = data['faces']          # (n_faces, 3)
normals = data['normals']      # (n_vertex, 3)
joints = data['joints']        # (n_joints, 3)
bones = data['bones']          # (n_bones, 2)
pc_w_norm = data['pc_w_norm']  # (8192, 6)

# Extract the sparse skinning weights components
skinning_data = data['skinning_weights_value']
skinning_rows = data['skinning_weights_row']
skinning_cols = data['skinning_weights_col']
skinning_shape = data['skinning_weights_shape']

skinning_sparse = sp.coo_matrix((skinning_data, (skinning_rows, skinning_cols)), shape=skinning_shape)
skinning_weights = skinning_sparse.toarray()  # (n_vertex, n_joints)
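Once `skinning_sparse` is built, per-vertex queries are cheap in CSR form; a sketch that lists the joints influencing vertex 0:

```python
# Sketch: inspect which joints influence a single vertex.
csr = skinning_sparse.tocsr()
row = csr.getrow(0)
for j, w in zip(row.indices, row.data):
    print(f"vertex 0 <- {data['joint_names'][j]} (joint {j}): weight {w:.3f}")
```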
data_utils/read_rig_mesh_from_glb.py
ADDED
@@ -0,0 +1,198 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Blender script for extracting rig (.txt) and mesh (.obj) from glbs.
This code currently supports GLB files only, but it can be easily modified to load other formats (e.g., FBX, DAE) with minimal changes.
"""

import bpy
import os
import re
import json
import pickle

def get_hierarchy_root_joint(joint):
    """
    Function to find the top parent joint node from the given
    'joint' Blender node (armature bone).
    """
    root_joint = joint
    while root_joint.parent is not None:
        root_joint = root_joint.parent
    return root_joint

def get_meshes_and_armatures():
    """
    Function to get all meshes and armatures in the scene
    """
    default_objects = ['Cube', 'Light', 'Camera', 'Icosphere']
    for obj_name in default_objects:
        if obj_name in bpy.data.objects:
            bpy.data.objects.remove(bpy.data.objects[obj_name], do_unlink=True)

    meshes = [obj for obj in bpy.context.scene.objects if obj.type == 'MESH']
    armatures = [obj for obj in bpy.context.scene.objects if obj.type == 'ARMATURE']
    return meshes, armatures

def get_joint_dict(root):
    """
    Function to create a dictionary of joints from the root joint
    """
    joint_pos = {}
    def traverse_bone(bone):
        joint_pos[bone.name] = {
            'pos': bone.head_local,
            'pa': bone.parent.name if bone.parent else 'None',
            'ch': [child.name for child in bone.children]
        }
        for child in bone.children:
            traverse_bone(child)

    traverse_bone(root)
    return joint_pos

def record_info(root, joint_dict, meshes, mesh_vert_offsets, file_info):
    """
    - root: root joint
    - joint_dict
    - meshes
    - mesh_vert_offsets: for multi-geometry
    - file_info
    """
    skin_records = {}

    def replace_special_characters(name):
        return re.sub(r'\W+', '_', name)

    for key, val in joint_dict.items():
        modified_key = replace_special_characters(key)
        file_info.write(f'joints {modified_key} {val["pos"][0]:.8f} {val["pos"][1]:.8f} {val["pos"][2]:.8f}\n')
    file_info.write(f'root {replace_special_characters(root.name)}\n')

    for mesh_index, mesh in enumerate(meshes):
        vert_offset = mesh_vert_offsets[mesh_index]
        if mesh.type == 'MESH':
            for vtx in mesh.data.vertices:
                weights = {}
                for group in vtx.groups:
                    bone_name = replace_special_characters(mesh.vertex_groups[group.group].name)
                    weights[bone_name] = group.weight

                global_vertex_index = vert_offset + vtx.index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Blender script for extracting rigs (.txt) and meshes (.obj) from GLBs.
+This code currently supports GLB files only, but it can be easily modified to load other formats (e.g., FBX, DAE) with minimal changes.
+"""
+
+import bpy
+import os
+import re
+import json
+import pickle
+
+def get_hierarchy_root_joint(joint):
+    """
+    Find the top-most parent joint node of the given
+    'joint' Blender node (armature bone).
+    """
+    root_joint = joint
+    while root_joint.parent is not None:
+        root_joint = root_joint.parent
+    return root_joint
+
+def get_meshes_and_armatures():
+    """
+    Get all meshes and armatures in the scene.
+    """
+    default_objects = ['Cube', 'Light', 'Camera', 'Icosphere']
+    for obj_name in default_objects:
+        if obj_name in bpy.data.objects:
+            bpy.data.objects.remove(bpy.data.objects[obj_name], do_unlink=True)
+
+    meshes = [obj for obj in bpy.context.scene.objects if obj.type == 'MESH']
+    armatures = [obj for obj in bpy.context.scene.objects if obj.type == 'ARMATURE']
+    return meshes, armatures
+
+def get_joint_dict(root):
+    """
+    Create a dictionary of joints from the root joint.
+    """
+    joint_pos = {}
+    def traverse_bone(bone):
+        joint_pos[bone.name] = {
+            'pos': bone.head_local,
+            'pa': bone.parent.name if bone.parent else 'None',
+            'ch': [child.name for child in bone.children]
+        }
+        for child in bone.children:
+            traverse_bone(child)
+
+    traverse_bone(root)
+    return joint_pos
+
+def record_info(root, joint_dict, meshes, mesh_vert_offsets, file_info):
+    """
+    - root: root joint (bone)
+    - joint_dict: joint positions and hierarchy from get_joint_dict
+    - meshes: all mesh objects in the scene
+    - mesh_vert_offsets: per-mesh vertex offsets, for multi-geometry files
+    - file_info: output rig .txt file handle
+    """
+    skin_records = {}
+
+    def replace_special_characters(name):
+        return re.sub(r'\W+', '_', name)
+
+    for key, val in joint_dict.items():
+        modified_key = replace_special_characters(key)
+        file_info.write(f'joints {modified_key} {val["pos"][0]:.8f} {val["pos"][1]:.8f} {val["pos"][2]:.8f}\n')
+    file_info.write(f'root {replace_special_characters(root.name)}\n')
+
+    for mesh_index, mesh in enumerate(meshes):
+        vert_offset = mesh_vert_offsets[mesh_index]
+        if mesh.type == 'MESH':
+            for vtx in mesh.data.vertices:
+                weights = {}
+                for group in vtx.groups:
+                    bone_name = replace_special_characters(mesh.vertex_groups[group.group].name)
+                    weights[bone_name] = group.weight
+
+                global_vertex_index = vert_offset + vtx.index
+
+                skin_record = f"skin {global_vertex_index} " + " ".join(f"{bone} {weight:.4f}" for bone, weight in weights.items())
+
+                # write each global vertex's skin record only once
+                if global_vertex_index not in skin_records:
+                    skin_records[global_vertex_index] = skin_record
+                    file_info.write(skin_record + "\n")
+
+    for key, val in joint_dict.items():
+        if val['pa'] != 'None':
+            parent_name = replace_special_characters(val['pa'])
+            child_name = replace_special_characters(key)
+            file_info.write(f'hier {parent_name} {child_name}\n')
+
+
+def record_obj(meshes, file_obj):
+    vert_offset = 0
+    norm_offset = 0
+    mesh_vert_offsets = []
+
+    for mesh in meshes:
+        mesh_vert_offsets.append(vert_offset)
+        bpy.context.view_layer.objects.active = mesh
+        bpy.ops.object.mode_set(mode='OBJECT')
+
+        # vertices
+        for v in mesh.data.vertices:
+            file_obj.write(f"v {v.co[0]} {v.co[1]} {v.co[2]}\n")
+        file_obj.write("\n")
+
+        # normals
+        for vn in mesh.data.vertices:
+            normal = vn.normal
+            file_obj.write(f"vn {normal[0]} {normal[1]} {normal[2]}\n")
+        file_obj.write("\n")
+
+        # faces (OBJ indices are 1-based and offset across meshes)
+        for poly in mesh.data.polygons:
+            verts = [v + 1 + vert_offset for v in poly.vertices]
+            file_obj.write(f"f {verts[0]}//{verts[0]} {verts[1]}//{verts[1]} {verts[2]}//{verts[2]}\n")
+
+        vert_count = len(mesh.data.vertices)
+        vert_offset += vert_count
+        norm_offset += vert_count
+
+    return mesh_vert_offsets
+
+def process_glb(glb_path, rigs_dir, meshes_dir):
+    base_name = os.path.splitext(os.path.basename(glb_path))[0]
+
+    obj_name = os.path.join(meshes_dir, f'{base_name}.obj')
+    info_name = os.path.join(rigs_dir, f'{base_name}.txt')
+
+    # Skip processing if the rig info file already exists
+    if os.path.exists(info_name):
+        print(f"{info_name} already exists. Skipping...")
+        return
+
+    if os.path.exists(obj_name):
+        print(f"{obj_name} already exists. Skipping...")
+        return
+
+    bpy.ops.wm.read_factory_settings(use_empty=True)
+    bpy.ops.import_scene.gltf(filepath=glb_path)
+
+    meshes, armatures = get_meshes_and_armatures()
+
+    if not armatures:
+        print(f"No armatures found in {glb_path}. Skipping...")
+        return
+
+    root = armatures[0].data.bones[0]
+    root_name = get_hierarchy_root_joint(root)
+    joint_dict = get_joint_dict(root_name)
+
+    # save meshes
+    with open(obj_name, 'w') as file_obj:
+        mesh_vert_offsets = record_obj(meshes, file_obj)
+
+    # save rigs
+    with open(info_name, 'w') as file_info:
+        record_info(root_name, joint_dict, meshes, mesh_vert_offsets, file_info)
+
+    print(f"Processed {glb_path}")
+
+if __name__ == '__main__':
+
+    src_dir = 'glbs'
+    rigs_dir = 'rigs'
+    meshes_dir = 'meshes'
+    # Ensure the output directories exist
+    if not os.path.exists(rigs_dir):
+        os.makedirs(rigs_dir)
+    if not os.path.exists(meshes_dir):
+        os.makedirs(meshes_dir)
+
+    glb_paths = [os.path.join(src_dir, file) for file in os.listdir(src_dir) if file.endswith('.glb')]
+
+    print(len(glb_paths))
+
+    for glb_path in glb_paths:
+        try:
+            process_glb(glb_path, rigs_dir, meshes_dir)
+        except Exception as e:
+            with open('error.txt', 'a') as error_file:
+                error_file.write(f"{glb_path}: {str(e)}\n")
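Since the script above drives bpy, it is meant to run with Blender's bundled Python rather than a standalone interpreter, e.g. blender --background --python data_utils/read_rig_mesh_from_glb.py; the glbs/, rigs/, and meshes/ folders are resolved relative to the working directory, and any file that fails to convert is appended to error.txt instead of aborting the batch.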
data_utils/render_data.py
ADDED
@@ -0,0 +1,61 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import numpy as np
+import cv2
+
+from pyrender_wrapper import PyRenderWrapper
+from data_loader import DataLoader
+
+def main():
+    loader = DataLoader()
+
+    raw_size = (960, 960)
+    renderer = PyRenderWrapper(raw_size)
+
+    output_dir = 'render_results'
+    os.makedirs(output_dir, exist_ok=True)
+
+    rig_path = 'examples/0a59c5ffa4a1476bac6d540b79947f31.txt'
+    mesh_path = rig_path.replace('.txt', '.obj')
+
+    filename = os.path.splitext(os.path.basename(rig_path))[0]
+
+    loader.load_rig_data(rig_path)
+    loader.load_mesh(mesh_path)
+    input_dict = loader.query_mesh_rig()
+
+    # four views, 90 degrees apart around the object
+    angles = [0, np.pi/2, np.pi, 3*np.pi/2]
+
+    bbox_center = loader.mesh.bounding_box.centroid
+    bbox_size = loader.mesh.bounding_box.extents
+    distance = np.max(bbox_size) * 2
+
+    subfolder_path = os.path.join(output_dir, filename)
+
+    os.makedirs(subfolder_path, exist_ok=True)
+
+    for i, angle in enumerate(angles):
+        print(f"Rendering view at {np.degrees(angle)} degrees")
+
+        renderer.set_camera_view(angle, bbox_center, distance)
+        renderer.align_light_to_camera()
+
+        color = renderer.render(input_dict)[0]
+
+        output_filename = f"{filename}_view{i+1}.png"
+        output_filepath = os.path.join(subfolder_path, output_filename)
+        cv2.imwrite(output_filepath, color)
+
+if __name__ == "__main__":
+    main()
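Renders are written to render_results/<name>/<name>_view1.png through _view4.png, one per quarter turn around the bounding-box center; to render another asset, point rig_path at a different rig .txt with a matching .obj beside it.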
data_utils/save_npz.py
ADDED
@@ -0,0 +1,256 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script shows how we process the meshes and rigs from the input folders and save them into a compressed NPZ file.
+"""
+import os
+import numpy as np
+import glob
+import pickle
+from concurrent.futures import ProcessPoolExecutor
+import skimage.measure
+import trimesh
+import mesh2sdf.core
+import scipy.sparse as sp
+
+def read_obj_file(file_path):
+    vertices = []
+    faces = []
+    normals = []  # Added normals list
+
+    with open(file_path, 'r') as file:
+        for line in file:
+            if line.startswith('v '):
+                parts = line.split()[1:]
+                vertices.append([float(parts[0]), float(parts[1]), float(parts[2])])
+            elif line.startswith('vn '):  # Added reading normals
+                parts = line.split()[1:]
+                normals.append([float(parts[0]), float(parts[1]), float(parts[2])])
+            elif line.startswith('f '):
+                parts = line.split()[1:]
+                # OBJ indices are 1-based; we need 0-based for the npz
+                face = [int(part.split('//')[0]) - 1 for part in parts]
+                faces.append(face)
+
+    return np.array(vertices), np.array(faces), np.array(normals)
+
+def read_rig_file(file_path):
+    """
+    Read a rig from a txt file; our format is the same as RigNet:
+        joints joint_name x y z
+        root root_joint_name
+        skin vertex_idx joint_name weight joint_name weight ...
+        hier parent_joint_name child_joint_name
+    """
+    joints = []
+    bones = []
+    joint_names = []
+
+    joint_mapping = {}
+    joint_index = 0
+
+    skinning_data = {}  # Dictionary to store vertex index -> [(joint_idx, weight), ...]
+
+    with open(file_path, 'r') as file:
+        lines = file.readlines()
+
+    for line in lines:
+        parts = line.split()
+        if line.startswith('joints'):
+            name = parts[1]
+            position = [float(parts[2]), float(parts[3]), float(parts[4])]
+            joints.append(position)
+            joint_names.append(name)
+            joint_mapping[name] = joint_index
+            joint_index += 1
+        elif line.startswith('hier'):
+            parent_joint = joint_mapping[parts[1]]
+            child_joint = joint_mapping[parts[2]]
+            bones.append([parent_joint, child_joint])
+        elif line.startswith('root'):
+            root = joint_mapping[parts[1]]
+        elif line.startswith('skin'):
+            vertex_idx = int(parts[1])
+
+            if vertex_idx not in skinning_data:
+                skinning_data[vertex_idx] = []
+
+            for i in range(2, len(parts), 2):
+                if i+1 < len(parts):
+                    joint_name = parts[i]
+                    weight = float(parts[i+1])
+
+                    if joint_name in joint_mapping:
+                        joint_idx = joint_mapping[joint_name]
+                        skinning_data[vertex_idx].append((joint_idx, weight))
+
+    return np.array(joints), np.array(bones), root, joint_names, skinning_data
+
+def convert_to_sparse_skinning(skinning_data, num_vertices, num_joints):
+    """Convert skinning weights to sparse matrix format."""
+    rows = []
+    cols = []
+    data = []
+
+    for vertex_idx, weights in skinning_data.items():
+        for joint_idx, weight in weights:
+            rows.append(vertex_idx)
+            cols.append(joint_idx)
+            data.append(weight)
+
+    sparse_skinning = sp.coo_matrix((data, (rows, cols)), shape=(num_vertices, num_joints))
+
+    # Return as a tuple of arrays which can be serialized
+    return (sparse_skinning.data, sparse_skinning.row, sparse_skinning.col, sparse_skinning.shape)
+
+def normalize_to_unit_cube(vertices, normals=None, scale_factor=1.0):
+    min_coords = vertices.min(axis=0)
+    max_coords = vertices.max(axis=0)
+    center = (max_coords + min_coords) / 2.0
+
+    vertices -= center
+    scale = 1.0 / np.abs(vertices).max() * scale_factor
+    vertices *= scale
+
+    if normals is not None:
+        # Normalize each normal vector to unit length
+        norms = np.linalg.norm(normals, axis=1, keepdims=True)
+        normals = normals / (norms + 1e-8)
+
+        return vertices, normals, center, scale
+    else:
+        return vertices, center, scale
+
+def normalize_vertices(vertices, scale=0.9):
+    bbmin, bbmax = vertices.min(0), vertices.max(0)
+    center = (bbmin + bbmax) * 0.5
+    scale = 2.0 * scale / (bbmax - bbmin).max()
+    vertices = (vertices - center) * scale
+    return vertices, center, scale
+
+def export_to_watertight(normalized_mesh, octree_depth: int = 7):
+    """
+    Convert a non-watertight mesh to a watertight one.
+
+    Args:
+        normalized_mesh (trimesh.Trimesh): input mesh
+        octree_depth (int): the SDF grid resolution is 2 ** octree_depth
+
+    Returns:
+        mesh (trimesh.Trimesh): watertight mesh
+
+    """
+    size = 2 ** octree_depth
+    level = 2 / size
+
+    scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices)
+
+    sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size)
+
+    vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level)
+
+    # watertight mesh
+    vertices = vertices / size * 2 - 1  # -1 to 1
+    vertices = vertices / to_orig_scale + to_orig_center
+    mesh = trimesh.Trimesh(vertices, faces, normals=normals)
+
+    return mesh
+
+def process_mesh_to_pc(mesh, marching_cubes=True, sample_num=8192):
+    if marching_cubes:
+        mesh = export_to_watertight(mesh)
+    return_mesh = mesh
+    points, face_idx = mesh.sample(sample_num, return_index=True)
+    points, _, _ = normalize_to_unit_cube(points, scale_factor=0.9995)
+    normals = mesh.face_normals[face_idx]
+
+    pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
+    return pc_normal, return_mesh
+
+def process_single_file(args):
+    mesh_file, rig_file = args
+    mesh_name = os.path.basename(mesh_file).split('.')[0]
+    rig_name = os.path.basename(rig_file).split('.')[0]
+
+    if mesh_name != rig_name:
+        print(f"Skipping files {mesh_file} and {rig_file} because their names do not match.")
+        return None
+
+    vertices, faces, normals = read_obj_file(mesh_file)
+
+    joints, bones, root, joint_names, skinning_data = read_rig_file(rig_file)
+
+    # Normalize the mesh to the unit cube centered at the origin
+    vertices, normals, center, scale = normalize_to_unit_cube(vertices, normals, scale_factor=0.5)
+
+    # Apply the same transformation to joints
+    joints -= center
+    joints *= scale
+
+    # Create a trimesh object for processing
+    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
+
+    # Process into a point cloud with normals
+    pc_normal, _ = process_mesh_to_pc(mesh)
+
+    # Convert skinning data to sparse format
+    sparse_skinning = convert_to_sparse_skinning(skinning_data, len(vertices), len(joints))
+
+    return {
+        'vertices': vertices,
+        'faces': faces,
+        'normals': normals,
+        'joints': joints,
+        'bones': bones,
+        'root_index': root,
+        'uuid': mesh_name,
+        'pc_w_norm': pc_normal,
+        'joint_names': joint_names,
+        'skinning_weights_value': sparse_skinning[0],  # values
+        'skinning_weights_row': sparse_skinning[1],    # row indices (key names match read_npz.py)
+        'skinning_weights_col': sparse_skinning[2],    # column indices
+        'skinning_weights_shape': sparse_skinning[3]   # shape of matrix
+    }
+
+def process_files(mesh_folder, rig_folder, output_file, num_workers=8):
+    file_pairs = []
+
+    for root, _, files in os.walk(rig_folder):
+        for file in files:
+            if file.endswith('.txt'):
+                rig_file = os.path.join(root, file)
+                obj_base_name = os.path.splitext(file)[0]
+                mesh_file = os.path.join(mesh_folder, obj_base_name + '.obj')
+                if os.path.exists(mesh_file):
+                    file_pairs.append((mesh_file, rig_file))
+                else:
+                    print(f"Mesh file not found: {mesh_file}")
+
+    with ProcessPoolExecutor(max_workers=num_workers) as executor:
+        data_list = list(executor.map(process_single_file, file_pairs))
+
+    data_list = [data for data in data_list if data is not None]
+
+    # np.savez_compressed pickles object arrays automatically (it has no
+    # allow_pickle argument); read_npz.py loads this back via data['arr_0']
+    np.savez_compressed(output_file, data_list)
+
+def main():
+    # Example usage
+    mesh_folder = 'meshes/'
+    rig_folder = 'rigs/'
+    output_file = 'results.npz'
+
+    process_files(mesh_folder, rig_folder, output_file)
+
+if __name__ == "__main__":
+    main()
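For reference, read_rig_file above parses RigNet-style text like this minimal, hypothetical two-joint rig (names and coordinates are made up for illustration):

joints pelvis 0.00000000 0.00000000 0.00000000
joints spine 0.00000000 0.25000000 0.00000000
root pelvis
skin 0 pelvis 0.7500 spine 0.2500
hier pelvis spine

Each skin line carries one vertex index followed by (joint name, weight) pairs, which convert_to_sparse_skinning then flattens into the COO triplets stored in the NPZ.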
data_utils/update_npz_rm_issue_data.py
ADDED
@@ -0,0 +1,59 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import os
+
+def filter_npz_by_filenames(npz_path, txt_path, output_path):
+
+    data_list = np.load(npz_path, allow_pickle=True)['arr_0']
+
+    with open(txt_path, 'r') as f:
+        exclude_filenames = set(line.strip() for line in f if line.strip())
+
+    # Filter the data list
+    filtered_data = []
+    excluded_count = 0
+
+    for item in data_list:
+
+        filename = item['uuid']
+
+        if filename in exclude_filenames:
+            excluded_count += 1
+            print(filename)
+        else:
+            filtered_data.append(item)
+
+    # Save the filtered data
+    kept_count = len(filtered_data)
+    total_count = len(data_list)
+    print(f"Original items: {total_count}")
+    print(f"Kept items: {kept_count}")
+    print(f"Removed items: {excluded_count}")
+
+    print("Saving filtered data")
+    np.savez_compressed(output_path, filtered_data)  # object arrays are pickled automatically
+
+def main():
+    issue_list = "data_utils/issue_data_list.txt"   # Change this to your text file path
+    npz_path_train = "articulation_xlv2_train.npz"  # Change this to your NPZ file path
+    output_path_train = "articulation_xlv2_train_update.npz"
+    npz_path_test = "articulation_xlv2_test.npz"    # Change this to your NPZ file path
+    output_path_test = "articulation_xlv2_test_update.npz"
+
+    filter_npz_by_filenames(npz_path_train, issue_list, output_path_train)
+    filter_npz_by_filenames(npz_path_test, issue_list, output_path_test)
+
+if __name__ == "__main__":
+    main()
demo.py
ADDED
@@ -0,0 +1,214 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import torch
+import trimesh
+import argparse
+import numpy as np
+
+from tqdm import tqdm
+from trimesh import Scene
+
+from accelerate import Accelerator
+from accelerate.utils import set_seed
+from accelerate.utils import DistributedDataParallelKwargs
+
+from skeleton_models.skeletongen import SkeletonGPT
+from data_utils.save_npz import normalize_to_unit_cube
+from utils.mesh_to_pc import MeshProcessor
+from utils.save_utils import save_mesh, pred_joints_and_bones, save_skeleton_to_txt, save_args, \
+    merge_duplicate_joints_and_fix_bones, save_skeleton_obj, render_mesh_with_skeleton
+
+class Dataset:
+    def __init__(self, input_list, input_pc_num=8192, apply_marching_cubes=True, octree_depth=7, output_dir=None):
+        super().__init__()
+        self.data = []
+        self.output_dir = output_dir
+
+        mesh_list = []
+        for input_path in input_list:
+            ext = os.path.splitext(input_path)[1].lower()
+            if ext in ['.ply', '.stl', '.obj']:
+                cur_data = trimesh.load(input_path, force='mesh')
+                mesh_list.append(cur_data)
+            else:
+                print(f"Unsupported file type: {ext}")
+        if apply_marching_cubes:
+            print("Applying Marching Cubes before sampling the point clouds; this takes some time...")
+        pc_list = MeshProcessor.convert_meshes_to_point_clouds(mesh_list, input_pc_num, apply_marching_cubes=apply_marching_cubes, octree_depth=octree_depth)
+        for input_path, cur_data, mesh in zip(input_list, pc_list, mesh_list):
+            self.data.append({'pc_normal': cur_data, 'faces': mesh.faces, 'vertices': mesh.vertices, 'file_name': os.path.splitext(os.path.basename(input_path))[0]})
+        print(f"dataset total data samples: {len(self.data)}")
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        data_dict = {}
+        data_dict['pc_normal'] = self.data[idx]['pc_normal']
+        # normalize point-cloud coordinates
+        pc_coor = data_dict['pc_normal'][:, :3]
+        normals = data_dict['pc_normal'][:, 3:]
+        pc_coor, center, scale = normalize_to_unit_cube(pc_coor, scale_factor=0.9995)
+
+        data_dict['file_name'] = self.data[idx]['file_name']
+        pc_coor = pc_coor.astype(np.float32)
+        normals = normals.astype(np.float32)
+
+        point_cloud = trimesh.PointCloud(pc_coor)
+        point_cloud.metadata['normals'] = normals
+
+        try:
+            point_cloud.export(os.path.join(self.output_dir, f"{data_dict['file_name']}.ply"))
+        except Exception as e:
+            print(f"Failed to save point cloud: {e}")
+
+        assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), "normals should be unit vectors, something wrong"
+        data_dict['pc_normal'] = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
+
+        vertices = self.data[idx]['vertices']
+        faces = self.data[idx]['faces']
+        bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
+        pc_center = (bounds[0] + bounds[1])[None, :] / 2
+        pc_scale = ((bounds[1] - bounds[0]).max() + 1e-5)
+        data_dict['transform_params'] = torch.tensor([
+            center[0], center[1], center[2],
+            scale,
+            pc_center[0][0], pc_center[0][1], pc_center[0][2],
+            pc_scale
+        ], dtype=torch.float32)
+        data_dict['vertices'] = vertices
+        data_dict['faces'] = faces
+        return data_dict
+
+def get_args():
+    parser = argparse.ArgumentParser("SkeletonGPT", add_help=False)
+
+    parser.add_argument("--input_pc_num", default=8192, type=int)
+    parser.add_argument("--num_beams", default=1, type=int)
+    parser.add_argument('--input_dir', default=None, type=str, help="input mesh directory")
+    parser.add_argument('--input_path', default=None, type=str, help="input mesh path")
+    parser.add_argument("--output_dir", default="outputs", type=str)
+    parser.add_argument('--llm', default="facebook/opt-350m", type=str, help="The LLM backend")
+    parser.add_argument("--pad_id", default=-1, type=int, help="padding id")
+    parser.add_argument("--n_discrete_size", default=128, type=int, help="discretized 3D space")
+    parser.add_argument("--n_max_bones", default=100, type=int, help="max number of bones")
+    parser.add_argument('--dataset_path', default="combine_256_updated", type=str, help="data path")
+    parser.add_argument("--seed", default=0, type=int)
+    parser.add_argument("--precision", default="fp16", type=str)
+    parser.add_argument("--batchsize_per_gpu", default=1, type=int)
+    parser.add_argument('--pretrained_weights', default=None, type=str)
+    parser.add_argument('--save_name', default="infer_results", type=str)
+    parser.add_argument("--save_render", default=False, action="store_true", help="save rendering results of mesh with skel")
+    parser.add_argument("--apply_marching_cubes", default=False, action="store_true")
+    parser.add_argument("--octree_depth", default=7, type=int)
+    parser.add_argument("--hier_order", default=False, action="store_true")
+
+    args = parser.parse_args()
+    return args
+
+if __name__ == "__main__":
+    args = get_args()
+
+    output_dir = f'{args.output_dir}/{args.save_name}'
+    os.makedirs(output_dir, exist_ok=True)
+    save_args(args, output_dir)
+
+    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+    accelerator = Accelerator(
+        kwargs_handlers=[kwargs],
+        mixed_precision=args.precision,
+    )
+
+    model = SkeletonGPT(args).cuda()
+
+    if args.pretrained_weights is not None:
+        pkg = torch.load(args.pretrained_weights, map_location=torch.device("cpu"))
+        model.load_state_dict(pkg["model"])
+    else:
+        raise ValueError("Pretrained weights must be provided.")
+    model.eval()
+    set_seed(args.seed)
+
+    # create dataset
+    if args.input_dir is not None:
+        input_list = sorted(os.listdir(args.input_dir))
+        input_list = [os.path.join(args.input_dir, x) for x in input_list if x.endswith('.ply') or x.endswith('.obj') or x.endswith('.stl')]
+        dataset = Dataset(input_list, args.input_pc_num, args.apply_marching_cubes, args.octree_depth, output_dir)
+    elif args.input_path is not None:
+        dataset = Dataset([args.input_path], args.input_pc_num, args.apply_marching_cubes, args.octree_depth, output_dir)
+    else:
+        raise ValueError("input_dir or input_path must be provided.")
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=1,
+        drop_last=False,
+        shuffle=False,
+    )
+
+    dataloader, model = accelerator.prepare(dataloader, model)
+
+    for curr_iter, batch_data_label in tqdm(enumerate(dataloader), total=len(dataloader)):
+        with accelerator.autocast():
+            pred_bone_coords = model.generate(batch_data_label)
+
+        # determine the output file names
+        file_name = os.path.basename(batch_data_label['file_name'][0])
+        pred_skel_filename = os.path.join(output_dir, f'{file_name}_skel.obj')
+        pred_rig_filename = os.path.join(output_dir, f"{file_name}_pred.txt")
+        mesh_filename = os.path.join(output_dir, f"{file_name}_mesh.obj")
+
+        transform_params = batch_data_label['transform_params'][0].cpu().numpy()
+        trans = transform_params[:3]
+        scale = transform_params[3]
+        pc_trans = transform_params[4:7]
+        pc_scale = transform_params[7]
+        vertices = batch_data_label['vertices'][0].cpu().numpy()
+        faces = batch_data_label['faces'][0].cpu().numpy()
+
+        skeleton = pred_bone_coords[0].cpu().numpy()
+        pred_joints, pred_bones = pred_joints_and_bones(skeleton.squeeze())
+
+        # Post-process: merge duplicate or nearby joints and deduplicate bones.
+        if args.hier_order:
+            pred_root_index = pred_bones[0][0]
+            pred_joints, pred_bones, pred_root_index = merge_duplicate_joints_and_fix_bones(pred_joints, pred_bones, root_index=pred_root_index)
+        else:
+            pred_joints, pred_bones = merge_duplicate_joints_and_fix_bones(pred_joints, pred_bones)
+            pred_root_index = None
+
+        # when saving the rig to txt, denormalize the skeleton back to the scale of the input mesh
+        pred_joints_denorm = pred_joints * pc_scale + pc_trans   # first align with the point cloud
+        pred_joints_denorm = pred_joints_denorm / scale + trans  # then align with the original mesh
+
+        save_skeleton_to_txt(pred_joints_denorm, pred_bones, pred_root_index, args.hier_order, vertices, pred_rig_filename)
+
+        # save skeletons
+        if args.hier_order:
+            save_skeleton_obj(pred_joints, pred_bones, pred_skel_filename, pred_root_index, use_cone=True)
+        else:
+            save_skeleton_obj(pred_joints, pred_bones, pred_skel_filename, use_cone=False)
+
+        # when saving the mesh and rendering, use normalized vertices in (-0.5, 0.5)
+        vertices_norm = (vertices - trans) * scale
+        vertices_norm = (vertices_norm - pc_trans) / pc_scale
+        save_mesh(vertices_norm, faces, mesh_filename)
+
+        # render mesh with skeleton
+        if args.save_render:
+            if args.hier_order:
+                render_mesh_with_skeleton(pred_joints, pred_bones, vertices_norm, faces, output_dir, file_name, prefix='pred', root_idx=pred_root_index)
+            else:
+                render_mesh_with_skeleton(pred_joints, pred_bones, vertices_norm, faces, output_dir, file_name, prefix='pred')
demo.sh
ADDED
@@ -0,0 +1,4 @@
+CUDA_VISIBLE_DEVICES=0 python demo.py --input_dir ./examples \
+    --pretrained_weights skeleton_ckpt/checkpoint_trainonv2_hier.pth \
+    --save_name infer_results_demo_hier --input_pc_num 8192 \
+    --save_render --apply_marching_cubes --hier_order
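The script expects the checkpoint to already sit under skeleton_ckpt/ (see download.py below); to use the spatial-ordering model instead, point --pretrained_weights at skeleton_ckpt/checkpoint_trainonv2_spatial.pth and drop --hier_order.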
download.py
ADDED
@@ -0,0 +1,19 @@
+from huggingface_hub import hf_hub_download
+
+file_path = hf_hub_download(
+    repo_id="Maikou/Michelangelo",
+    filename="checkpoints/aligned_shape_latents/shapevae-256.ckpt",
+    local_dir="third_partys/Michelangelo"
+)
+
+file_path = hf_hub_download(
+    repo_id="Seed3D/MagicArticulate",
+    filename="skeleton_ckpt/checkpoint_trainonv2_hier.pth",
+    local_dir=""
+)
+
+file_path = hf_hub_download(
+    repo_id="Seed3D/MagicArticulate",
+    filename="skeleton_ckpt/checkpoint_trainonv2_spatial.pth",
+    local_dir=""
+)
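With local_dir="" the two skeleton checkpoints should land under skeleton_ckpt/ in the current working directory, matching the path demo.sh passes to --pretrained_weights; the Michelangelo shape VAE is likewise placed under third_partys/Michelangelo/checkpoints/aligned_shape_latents/.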
requirements.txt
ADDED
@@ -0,0 +1,37 @@
+#trimesh==4.2.3
+#accelerate==0.28.0
+#mesh2sdf==1.1.0
+#transformers==4.39.3
+#numpy==1.26.4
+#pyrender==0.1.45
+#tqdm
+#opencv-python==4.9.0.80
+#omegaconf==2.3.0
+#einops==0.7.0
+##======= HF===================
+
+# MagicArticulate Requirements for Gradio Demo
+# Compatible with CUDA 11.8 and Python 3.10
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch==2.1.1
+torchvision==0.16.1
+torchaudio==2.1.1
+
+# Gradio for web interface
+gradio==4.44.0
+
+# 3D mesh processing
+trimesh==4.4.3
+accelerate==0.28.0
+mesh2sdf==1.1.0
+transformers==4.39.3
+numpy==1.26.4
+pyrender==0.1.45
+tqdm
+opencv-python==4.9.0.80
+omegaconf==2.3.0
+einops==0.7.0
+
+flash-attn==2.6.3
+huggingface_hub
+gradio-client>=1.0.0
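The commented block at the top preserves the earlier, pre-demo pins; the flash-attn pin backs the flash_attention_2 attention path that skeleton_models/shape_opt.py (below) selects when config._attn_implementation == "flash_attention_2".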
skeleton_models/shape_opt.py
ADDED
@@ -0,0 +1,406 @@
+# Modified from https://github.com/buaacyw/MeshAnything
+from transformers import AutoModelForCausalLM, AutoConfig, OPTConfig
+from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTModel, OPTDecoder, OPTLearnedPositionalEmbedding, OPTDecoderLayer
+from typing import List, Optional, Tuple, Union
+from transformers.modeling_outputs import (
+    CausalLMOutputWithPast,
+)
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.utils import replace_return_docstrings
+from transformers.modeling_outputs import BaseModelOutputWithPast
+
+class ShapeOPTConfig(OPTConfig):
+    model_type = "shape_opt"
+
+class ShapeOPT(OPTForCausalLM):
+    config_class = ShapeOPTConfig
+    def __init__(self, config: ShapeOPTConfig):
+        super(OPTForCausalLM, self).__init__(config)
+        self.model = ShapeOPTModel(config)
+        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="OPTConfig")
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        bone_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, OPTForCausalLM
+
+        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids = input_ids,
+            bone_ids = bone_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        logits = self.lm_head(outputs[0]).contiguous()
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+class ShapeOPTModel(OPTModel):
+    config_class = ShapeOPTConfig
+    def __init__(self, config: ShapeOPTConfig):
+        super(OPTModel, self).__init__(config)
+        self.decoder = ShapeOPTDecoder(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+class ShapeOPTDecoder(OPTDecoder):
+    config_class = ShapeOPTConfig
+    def __init__(self, config: ShapeOPTConfig):
+        super(OPTDecoder, self).__init__(config)
+        self.config = config
+        self.dropout = config.dropout
+        self.layerdrop = config.layerdrop
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        assert config.word_embed_proj_dim == config.hidden_size
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
+        self.hidden_size = config.hidden_size
+        self.word_embed_proj_dim = config.word_embed_proj_dim
+        self.n_discrete_size = config.n_discrete_size
+
+        self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
+        self.token_embed_positions = OPTBonePositionalEmbedding(config.bone_per_token+3, config.word_embed_proj_dim)
+
+        self.bone_per_token = config.bone_per_token
+        self.cond_length = config.cond_length
+        self.cond_embed = nn.Embedding(2, config.word_embed_proj_dim)
+        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # see https://github.com/facebookresearch/metaseq/pull/164
+        if config.do_layer_norm_before and not config._remove_final_layer_norm:
+            self.final_layer_norm = nn.LayerNorm(
+                config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
+            )
+        else:
+            self.final_layer_norm = None
+
+        self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        bone_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        # OPT Decoder
+        # print("used my Trans")
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        # Transformer Decoder
+        if input_ids is not None and inputs_embeds is not None:  # when training
+            pass
+        elif input_ids is not None:  # when inference
+            assert not self.training
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            inputs_embeds = self.embed_tokens(input_ids)
+            bone_embeds = self.token_embed_positions(attention_mask[:, self.cond_length:], bone_ids, input_ids,
+                                                     self.bone_per_token)
+            inputs_embeds += bone_embeds
+            cond_embed_query = torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), device=inputs_embeds.device,
+                                          dtype=inputs_embeds.dtype).long()
+            inputs_embeds = inputs_embeds + self.cond_embed(cond_embed_query)
+
+        elif inputs_embeds is not None:  # when generating the first skeleton token
+            assert not self.training
+            total_length = inputs_embeds.shape[1]
+            cond_embed_query = torch.zeros((inputs_embeds.shape[0], total_length), device=inputs_embeds.device,
+                                           dtype=inputs_embeds.dtype).long()
+            inputs_embeds = inputs_embeds + self.cond_embed(cond_embed_query)
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        # embed positions
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            assert attention_mask is not None
+            causal_attention_mask = attention_mask if 0 in attention_mask else None
|
| 304 |
+
else:
|
| 305 |
+
raise ValueError("Only flash_attention_2 is supported")
|
| 306 |
+
|
| 307 |
+
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
|
| 308 |
+
|
| 309 |
+
hidden_states = inputs_embeds + pos_embeds
|
| 310 |
+
|
| 311 |
+
# decoder layers
|
| 312 |
+
all_hidden_states = () if output_hidden_states else None
|
| 313 |
+
all_self_attns = () if output_attentions else None
|
| 314 |
+
next_decoder_cache = () if use_cache else None
|
| 315 |
+
|
| 316 |
+
# check if head_mask has a correct number of layers specified if desired
|
| 317 |
+
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
|
| 318 |
+
if attn_mask is not None:
|
| 319 |
+
if attn_mask.size()[0] != (len(self.layers)):
|
| 320 |
+
raise ValueError(
|
| 321 |
+
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
| 322 |
+
f" {head_mask.size()[0]}."
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
for idx, decoder_layer in enumerate(self.layers):
|
| 326 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
| 327 |
+
if output_hidden_states:
|
| 328 |
+
all_hidden_states += (hidden_states,)
|
| 329 |
+
|
| 330 |
+
if self.training:
|
| 331 |
+
dropout_probability = torch.rand([])
|
| 332 |
+
if dropout_probability < self.layerdrop:
|
| 333 |
+
continue
|
| 334 |
+
|
| 335 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
| 336 |
+
|
| 337 |
+
if self.gradient_checkpointing and self.training:
|
| 338 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 339 |
+
decoder_layer.__call__,
|
| 340 |
+
hidden_states,
|
| 341 |
+
causal_attention_mask,
|
| 342 |
+
head_mask[idx] if head_mask is not None else None,
|
| 343 |
+
None,
|
| 344 |
+
output_attentions,
|
| 345 |
+
use_cache,
|
| 346 |
+
)
|
| 347 |
+
else:
|
| 348 |
+
layer_outputs = decoder_layer(
|
| 349 |
+
hidden_states,
|
| 350 |
+
attention_mask=causal_attention_mask,
|
| 351 |
+
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
| 352 |
+
past_key_value=past_key_value,
|
| 353 |
+
output_attentions=output_attentions,
|
| 354 |
+
use_cache=use_cache,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
hidden_states = layer_outputs[0]
|
| 358 |
+
|
| 359 |
+
if use_cache:
|
| 360 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
| 361 |
+
|
| 362 |
+
if output_attentions:
|
| 363 |
+
all_self_attns += (layer_outputs[1],)
|
| 364 |
+
|
| 365 |
+
if self.final_layer_norm is not None:
|
| 366 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 367 |
+
|
| 368 |
+
# add hidden states from the last decoder layer
|
| 369 |
+
if output_hidden_states:
|
| 370 |
+
all_hidden_states += (hidden_states,)
|
| 371 |
+
|
| 372 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 373 |
+
if not return_dict:
|
| 374 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 375 |
+
return BaseModelOutputWithPast(
|
| 376 |
+
last_hidden_state=hidden_states,
|
| 377 |
+
past_key_values=next_cache,
|
| 378 |
+
hidden_states=all_hidden_states,
|
| 379 |
+
attentions=all_self_attns,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
class OPTBonePositionalEmbedding(nn.Embedding):
|
| 383 |
+
"""
|
| 384 |
+
This module learns positional embeddings up to a fixed maximum size.
|
| 385 |
+
"""
|
| 386 |
+
|
| 387 |
+
def __init__(self, num_embeddings: int, embedding_dim: int):
|
| 388 |
+
super().__init__(num_embeddings, embedding_dim)
|
| 389 |
+
|
| 390 |
+
def forward(self, attention_mask=None, bone_ids = None, input_ids = None, bone_per_token = None):
|
| 391 |
+
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
| 392 |
+
if bone_ids is not None:
|
| 393 |
+
return super().forward(bone_ids)
|
| 394 |
+
|
| 395 |
+
assert input_ids.shape[1] == 1
|
| 396 |
+
idx_in_extra = torch.isin(input_ids, torch.LongTensor([0, 1, 2]).to(input_ids.device))
|
| 397 |
+
cur_ids = input_ids.clone().detach()
|
| 398 |
+
|
| 399 |
+
cur_index = (attention_mask.sum(dim=1, keepdim=True) - 2) % bone_per_token + 3
|
| 400 |
+
cur_ids[~idx_in_extra]=cur_index[~idx_in_extra]
|
| 401 |
+
|
| 402 |
+
return super().forward(cur_ids)
|
| 403 |
+
|
| 404 |
+
AutoConfig.register("shape_opt", ShapeOPTConfig)
|
| 405 |
+
AutoModelForCausalLM.register(ShapeOPTConfig, ShapeOPT)
|
| 406 |
+
|
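The index arithmetic in `OPTBonePositionalEmbedding.forward` above is easy to misread: during incremental decoding, the embedding index of a coordinate token is derived from how many tokens the (conditioning-stripped) attention mask already covers, cycling through indices 3 to bone_per_token + 2, since indices 0-2 are reserved for the bos/eos/pad tokens. A minimal sketch of that cycle, assuming `bone_per_token = 6` as configured in `skeletongen.py` below; this is illustration code, not part of the repo:

import torch

# Reproduce the cur_index formula from OPTBonePositionalEmbedding for a batch of 1.
bone_per_token = 6
for n_tokens in range(2, 16):  # tokens covered by the mask so far (bos + coordinates)
    attention_mask = torch.ones(1, n_tokens)
    cur_index = (attention_mask.sum(dim=1, keepdim=True) - 2) % bone_per_token + 3
    print(n_tokens, int(cur_index))  # prints 3, 4, 5, 6, 7, 8, 3, 4, ...
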
skeleton_models/skeletongen.py
ADDED
@@ -0,0 +1,198 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from torch import nn
+from transformers import AutoModelForCausalLM
+from third_partys.Michelangelo.encode import load_model
+from skeleton_models.shape_opt import ShapeOPTConfig
+
+def undiscretize(t, low, high, num_discrete):
+    assert (t >= 0).all() and (t <= num_discrete - 1).all()
+    assert high > low
+    t = t.float()
+    t /= num_discrete
+    t = t * (high - low) + low
+    assert (t < high).all() and (t >= low).all()
+    return t
+
+class SkeletonGPT(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+
+        self.args = args
+        self.point_encoder = load_model()
+
+        self.cond_length = 257
+        self.cond_dim = 768
+
+        self.n_discrete_size = args.n_discrete_size
+
+        self.bone_per_token = 6  # 6 tokens per bone (2 joints x 3 coordinates)
+        self.max_length = int(args.n_max_bones * self.bone_per_token + 2 + self.cond_length)
+        self.pad_id = -1
+
+        self.coor_continuous_range = (-0.5, 0.5)
+
+        vocab_size = self.n_discrete_size + 3  # 3 for bos, eos, pad
+        self.config = ShapeOPTConfig.from_pretrained(
+            args.llm,
+            n_positions=self.max_length,
+            max_position_embeddings=self.max_length,
+            vocab_size=vocab_size,
+            _attn_implementation="flash_attention_2"
+        )
+
+        self.bos_token_id = 0
+        self.eos_token_id = 1
+        self.pad_token_id = 2
+
+        self.config.bos_token_id = self.bos_token_id
+        self.config.eos_token_id = self.eos_token_id
+        self.config.pad_token_id = self.pad_token_id
+        self.config._attn_implementation = "flash_attention_2"
+        self.config.n_discrete_size = self.n_discrete_size
+        self.config.bone_per_token = self.bone_per_token
+        self.config.cond_length = self.cond_length
+
+        self.config.word_embed_proj_dim = self.config.hidden_size  # 1024
+
+        self.transformer = AutoModelForCausalLM.from_config(
+            config=self.config, attn_implementation="flash_attention_2")
+
+        self.cond_head_proj = nn.Linear(self.cond_dim, self.config.word_embed_proj_dim)
+        self.cond_proj = nn.Linear(self.cond_dim, self.config.word_embed_proj_dim)
+
+        self.eval()
+
+    def detokenize(self, input_ids):
+        # input_ids: torch.Tensor of shape (batch_size, seq_length)
+        batch_size = input_ids.size(0)
+
+        continuous_coors_list = []
+        num_bones_list = []
+
+        for i in range(batch_size):
+            cur_ids = input_ids[i]  # Shape: (seq_length,)
+
+            # Remove padding tokens
+            cur_ids = cur_ids[cur_ids != self.pad_id]  # Shape: (effective_seq_length,)
+
+            # Check that the length is a multiple of 6 (2 joints * 3 coordinates)
+            if cur_ids.numel() % 6 != 0:
+                return None
+                # raise ValueError(f"Invalid length of input_ids in sample {i}. It should be a multiple of 6.")
+
+            num_bones = cur_ids.numel() // 6
+            num_bones_list.append(num_bones)
+
+            # Reshape into (num_bones, 6)
+            bone_coords = cur_ids.view(num_bones, 6)  # Shape: (num_bones, 6)
+
+            # Undiscretize the coordinates
+            bones_coors = torch.zeros((num_bones, 2, 3), dtype=torch.float16, device=cur_ids.device)
+
+            for j in range(num_bones):
+                bone_coord = bone_coords[j]  # Shape: (6,)
+
+                # Split into two joints
+                joint1_ids = bone_coord[:3]
+                joint2_ids = bone_coord[3:]
+
+                # Undiscretize joint coordinates
+                joint1_coords = undiscretize(joint1_ids, self.coor_continuous_range[0], self.coor_continuous_range[1], self.n_discrete_size)
+                joint2_coords = undiscretize(joint2_ids, self.coor_continuous_range[0], self.coor_continuous_range[1], self.n_discrete_size)
+
+                # Assign to bones_coors
+                bones_coors[j, 0, :] = joint1_coords
+                bones_coors[j, 1, :] = joint2_coords
+
+            continuous_coors_list.append(bones_coors)
+
+        max_num_bones = max(num_bones_list)
+
+        # Initialize the continuous_coors tensor with NaNs
+        continuous_coors = torch.full(
+            (batch_size, max_num_bones, 2, 3),
+            float('nan'),
+            dtype=torch.float16,
+            device=input_ids.device
+        )
+
+        # Place the bones_coors into continuous_coors
+        for i in range(batch_size):
+            num_bones = num_bones_list[i]
+            continuous_coors[i, :num_bones, :, :] = continuous_coors_list[i]
+
+        return continuous_coors  # Shape: (batch_size, max_num_bones, 2, 3)
+
+    # def forward(self, data_dict: dict, is_eval: bool = False) -> dict:
+    #     return self.generate(data_dict)
+
+    def process_point_feature(self, point_feature):
+        encode_feature = torch.zeros(self.args.batchsize_per_gpu, self.cond_length, self.config.word_embed_proj_dim,
+                                     device=self.cond_head_proj.weight.device, dtype=self.cond_head_proj.weight.dtype)
+        encode_feature[:, 0] = self.cond_head_proj(point_feature[:, 0])
+        shape_latents = self.point_encoder.to_shape_latents(point_feature[:, 1:])
+        encode_feature[:, 1:] = self.cond_proj(shape_latents)
+
+        return encode_feature
+
+    @torch.no_grad()
+    def generate(self, data_dict) -> dict:
+        point_feature = self.point_encoder.encode_latents(data_dict["pc_normal"])
+        processed_point_feature = self.process_point_feature(point_feature=point_feature)
+        generate_length = self.max_length - self.cond_length
+        net_device = next(self.parameters()).device
+        outputs = torch.ones(self.args.batchsize_per_gpu, generate_length).long().to(net_device) * self.eos_token_id
+        # batch x ntokens
+        if self.args.num_beams is not None and "pc_normal" in data_dict:
+            results = self.transformer.generate(
+                inputs_embeds=processed_point_feature,
+                max_new_tokens=generate_length,  # all bone tokens plus two (bos/eos)
+                num_beams=self.args.num_beams,
+                bos_token_id=self.bos_token_id,
+                eos_token_id=self.eos_token_id,
+                pad_token_id=self.pad_token_id,
+            )
+        else:
+            results = self.transformer.generate(
+                inputs_embeds=processed_point_feature,
+                max_new_tokens=generate_length,  # all bone tokens plus two (bos/eos)
+                do_sample=True,
+                top_k=50,
+                top_p=0.95,
+                bos_token_id=self.bos_token_id,
+                eos_token_id=self.eos_token_id,
+                pad_token_id=self.pad_token_id,
+            )
+        assert results.shape[1] <= generate_length  # B x ntokens; bos is not included since it's predicted
+        outputs[:, :results.shape[1]] = results
+        # batch x ntokens ====> batch x ntokens x D
+        outputs = outputs[:, 1:-1]  # bos and eos removed
+
+        outputs[outputs == self.bos_token_id] = self.pad_id
+        outputs[outputs == self.eos_token_id] = self.pad_id
+        outputs[outputs == self.pad_token_id] = self.pad_id
+
+        outputs[outputs != self.pad_id] -= 3
+
+        gen_joints = self.detokenize(outputs)
+
+        return gen_joints
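`undiscretize` maps integer coordinate tokens back into `coor_continuous_range`; its inverse, the quantizer applied when skeletons are tokenized, is not part of this file. A minimal round-trip sketch, assuming 128 bins over (-0.5, 0.5); the `discretize` helper here is hypothetical, written only to mirror `undiscretize`:

import torch

def discretize(t, low=-0.5, high=0.5, num_discrete=128):
    # Hypothetical inverse of undiscretize: continuous coords in [low, high) -> integer bins.
    t = (t - low) / (high - low)  # map to [0, 1)
    return (t * num_discrete).long().clamp(0, num_discrete - 1)

coords = torch.tensor([-0.5, -0.1, 0.0, 0.25, 0.499])
tokens = discretize(coords)
recovered = tokens.float() / 128 * (0.5 - (-0.5)) + (-0.5)  # same formula as undiscretize
print(tokens.tolist())     # [0, 51, 64, 96, 127]
print(recovered.tolist())  # bin lower edges; the round-trip error is at most 1/128
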
utils/eval_utils.py
ADDED
@@ -0,0 +1,57 @@
+# Modified from https://github.com/zhan-xu/RigNet
+
+import numpy as np
+
+##### for quantitative calculation
+def chamfer_dist(pt1, pt2):
+    pt1 = pt1[np.newaxis, :, :]
+    pt2 = pt2[:, np.newaxis, :]
+    dist = np.sqrt(np.sum((pt1 - pt2) ** 2, axis=2))
+    min_left = np.mean(np.min(dist, axis=0))
+    min_right = np.mean(np.min(dist, axis=1))
+    return (min_left + min_right) / 2
+
+def oneway_chamfer(pt_src, pt_dst):
+    pt1 = pt_src[np.newaxis, :, :]
+    pt2 = pt_dst[:, np.newaxis, :]
+    dist = np.sqrt(np.sum((pt1 - pt2) ** 2, axis=2))
+    avg_dist = np.mean(np.min(dist, axis=0))
+    return avg_dist
+
+def joint2bone_chamfer_dist(joints1, bones1, joints2, bones2):
+    bone_sample_1 = sample_skel(joints1, bones1)
+    bone_sample_2 = sample_skel(joints2, bones2)
+    dist1 = oneway_chamfer(joints1, bone_sample_2)
+    dist2 = oneway_chamfer(joints2, bone_sample_1)
+    return (dist1 + dist2) / 2
+
+def bone2bone_chamfer_dist(joints1, bones1, joints2, bones2):
+    bone_sample_1 = sample_skel(joints1, bones1)
+    bone_sample_2 = sample_skel(joints2, bones2)
+    return chamfer_dist(bone_sample_1, bone_sample_2)
+
+def sample_bone(p_pos, ch_pos):
+    ray = ch_pos - p_pos
+
+    bone_length = np.linalg.norm(p_pos - ch_pos)
+    num_step = np.round(bone_length / 0.005).astype(int)
+    i_step = np.arange(0, num_step + 1)
+    unit_step = ray / (num_step + 1e-30)
+    unit_step = np.repeat(unit_step[np.newaxis, :], num_step + 1, axis=0)
+    res = p_pos + unit_step * i_step[:, np.newaxis]
+    return res
+
+def sample_skel(joints, bones):
+    bone_sample = []
+    for parent_idx, child_idx in bones:
+        p_pos = joints[parent_idx]
+        ch_pos = joints[child_idx]
+        res = sample_bone(p_pos, ch_pos)
+        bone_sample.append(res)
+
+    if bone_sample:
+        bone_sample = np.concatenate(bone_sample, axis=0)
+    else:
+        bone_sample = np.empty((0, 3))
+
+    return bone_sample
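For reference, a small usage sketch of the metrics above on two toy skeletons (CD-J2J and CD-J2B in the paper's terminology; the coordinates are made up for illustration):

import numpy as np
from utils.eval_utils import chamfer_dist, joint2bone_chamfer_dist

# joints: (J, 3) arrays; bones: (B, 2) parent/child index pairs
joints_a = np.array([[0.0, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.2, 0.0]])
bones_a = np.array([[0, 1], [1, 2]])
joints_b = joints_a + np.array([0.01, 0.0, 0.0])  # the same skeleton shifted along x

print("CD-J2J:", chamfer_dist(joints_a, joints_b))
print("CD-J2B:", joint2bone_chamfer_dist(joints_a, bones_a, joints_b, bones_a))
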
utils/mesh_to_pc.py
ADDED
@@ -0,0 +1,84 @@
+# Modified from https://github.com/buaacyw/MeshAnything
+import mesh2sdf.core
+import numpy as np
+import skimage.measure
+import trimesh
+import time
+from typing import List, Tuple
+
+class MeshProcessor:
+    """A class to handle mesh normalization, watertight conversion and point cloud sampling."""
+
+    @staticmethod
+    def normalize_mesh_vertices(vertices: np.ndarray, scaling_factor: float = 0.95) -> Tuple[np.ndarray, np.ndarray, float]:
+        """
+        Normalize mesh vertices to be centered at the origin and scaled appropriately.
+        """
+        min_bounds = vertices.min(axis=0)
+        max_bounds = vertices.max(axis=0)
+
+        center = (min_bounds + max_bounds) * 0.5
+        max_dimension = (max_bounds - min_bounds).max()
+        scale = 2.0 * scaling_factor / max_dimension
+
+        normalized_vertices = (vertices - center) * scale
+        return normalized_vertices, center, scale
+
+    @staticmethod
+    def convert_to_watertight(mesh: trimesh.Trimesh, octree_depth: int = 7) -> trimesh.Trimesh:
+        """
+        Convert to a watertight mesh using mesh2sdf and marching cubes.
+        """
+        grid_size = 2 ** octree_depth
+        iso_level = 2 / grid_size
+
+        # Normalize vertices for SDF computation
+        normalized_vertices, original_center, original_scale = MeshProcessor.normalize_mesh_vertices(mesh.vertices)
+
+        # Compute signed distance field
+        sdf = mesh2sdf.core.compute(normalized_vertices, mesh.faces, size=grid_size)
+
+        # Run marching cubes algorithm
+        vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), iso_level)
+
+        # Transform vertices back to the original coordinate system
+        vertices = vertices / grid_size * 2 - 1  # Map to [-1, 1] range
+        vertices = vertices / original_scale + original_center
+
+        # Create new watertight mesh
+        watertight_mesh = trimesh.Trimesh(vertices, faces, normals=normals)
+        return watertight_mesh
+
+    @staticmethod
+    def convert_meshes_to_point_clouds(
+        meshes: List[trimesh.Trimesh],
+        points_per_mesh: int = 8192,
+        apply_marching_cubes: bool = False,
+        octree_depth: int = 7
+    ) -> List[np.ndarray]:
+        """
+        Process a list of meshes into point clouds with normals.
+        """
+        point_clouds_with_normals = []
+        processed_meshes = []
+
+        for mesh in meshes:
+            # Optionally convert to a watertight mesh first
+            if apply_marching_cubes:
+                start_time = time.time()
+                mesh = MeshProcessor.convert_to_watertight(mesh, octree_depth=octree_depth)
+                processing_time = time.time() - start_time
+                print(f"Marching cubes complete! Time: {processing_time:.2f}s")
+
+            # Store processed mesh
+            processed_meshes.append(mesh)
+
+            # Sample points and get corresponding face normals
+            points, face_indices = mesh.sample(points_per_mesh, return_index=True)
+            point_normals = mesh.face_normals[face_indices]
+
+            # Combine points and normals
+            points_with_normals = np.concatenate([points, point_normals], axis=-1, dtype=np.float16)
+            point_clouds_with_normals.append(points_with_normals)
+
+        return point_clouds_with_normals
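A short usage sketch of `MeshProcessor` (the box primitive stands in for a loaded mesh; setting `apply_marching_cubes=True` would run the watertight conversion above first):

import trimesh
from utils.mesh_to_pc import MeshProcessor

mesh = trimesh.creation.box(extents=(1.0, 1.0, 1.0))  # stand-in for a loaded mesh
clouds = MeshProcessor.convert_meshes_to_point_clouds(
    [mesh], points_per_mesh=8192, apply_marching_cubes=False
)
pc = clouds[0]
print(pc.shape, pc.dtype)  # (8192, 6) float16: xyz plus the sampled face normal
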
utils/save_utils.py
ADDED
@@ -0,0 +1,578 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import numpy as np
+import cv2
+import json
+import trimesh
+
+from collections import deque, defaultdict
+from scipy.cluster.hierarchy import linkage, fcluster
+from scipy.spatial.distance import cdist
+
+from data_utils.pyrender_wrapper import PyRenderWrapper
+from data_utils.data_loader import DataLoader
+
+def save_mesh(vertices, faces, filename):
+    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
+    mesh.export(filename, file_type='obj')
+
+def pred_joints_and_bones(bone_coor):
+    """
+    Get joints (j, 3) and bones (b, 2) from (b, 2, 3), preserving the parent-child relationship.
+    """
+    parent_coords = bone_coor[:, 0, :]  # (b, 3)
+    child_coords = bone_coor[:, 1, :]  # (b, 3)
+
+    all_coords = np.vstack([parent_coords, child_coords])  # (2b, 3)
+    pred_joints, indices = np.unique(all_coords, axis=0, return_inverse=True)
+
+    b = bone_coor.shape[0]
+    parent_indices = indices[:b]
+    child_indices = indices[b:]
+
+    pred_bones = np.column_stack([parent_indices, child_indices])
+
+    valid_bones = pred_bones[parent_indices != child_indices]
+
+    return pred_joints, valid_bones
+
+def find_connected_components(joints, bones):
+    """Find connected components in the skeleton graph."""
+    n_joints = len(joints)
+    graph = defaultdict(list)
+
+    # Build adjacency list
+    for parent, child in bones:
+        graph[parent].append(child)
+        graph[child].append(parent)
+
+    visited = [False] * n_joints
+    components = []
+
+    for i in range(n_joints):
+        if not visited[i]:
+            component = []
+            queue = deque([i])
+            visited[i] = True
+
+            while queue:
+                node = queue.popleft()
+                component.append(node)
+
+                for neighbor in graph[node]:
+                    if not visited[neighbor]:
+                        visited[neighbor] = True
+                        queue.append(neighbor)
+
+            components.append(component)
+
+    return components
+
+def ensure_skeleton_connectivity(joints, bones, root_index=None, merge_distance_threshold=0.01):
+    """
+    Ensure the skeleton is fully connected.
+    - If distance < merge_distance_threshold: merge the joints
+    - If distance >= merge_distance_threshold: connect them with a bone
+    """
+    current_joints = joints.copy()
+    current_bones = list(bones)
+    current_root = root_index
+
+    iteration = 0
+    while True:
+        components = find_connected_components(current_joints, current_bones)
+        if len(components) == 1:
+            # print("Successfully ensured skeleton connectivity")
+            break
+
+        # Find the globally closest pair of components
+        min_distance = float('inf')
+        best_pair = None
+
+        for i in range(len(components)):
+            for j in range(i + 1, len(components)):
+                comp1_joints = current_joints[components[i]]
+                comp2_joints = current_joints[components[j]]
+
+                distances = cdist(comp1_joints, comp2_joints)
+                min_idx = np.unravel_index(np.argmin(distances), distances.shape)
+                distance = distances[min_idx]
+
+                if distance < min_distance:
+                    min_distance = distance
+                    best_pair = (i, j, components[i][min_idx[0]], components[j][min_idx[1]], min_idx)
+
+        if best_pair is None:
+            print("Warning: Could not find a valid component pair to connect")
+            break
+
+        comp1_idx, comp2_idx, joint1_idx, joint2_idx, min_idx = best_pair
+
+        if min_distance < merge_distance_threshold:
+            # Merge the joints
+            # print(f"Iteration {iteration + 1}: Merging closest joints {joint1_idx} and {joint2_idx} "
+            #       f"(distance: {min_distance:.4f})")
+
+            # Always merge joint2 into joint1
+            merge_map = {joint2_idx: joint1_idx}
+
+            # Update bones
+            updated_bones = []
+            for parent, child in current_bones:
+                new_parent = merge_map.get(parent, parent)
+                new_child = merge_map.get(child, child)
+                if new_parent != new_child:  # Remove self-loops
+                    updated_bones.append([new_parent, new_child])
+
+            # Update root
+            if current_root == joint2_idx:
+                current_root = joint1_idx
+
+            # Remove the merged joint and update indices
+            joint_to_remove = joint2_idx
+            mask = np.ones(len(current_joints), dtype=bool)
+            mask[joint_to_remove] = False
+            current_joints = current_joints[mask]
+
+            # Create index mapping for remaining joints
+            old_to_new = {}
+            new_idx = 0
+            for old_idx in range(len(mask)):
+                if mask[old_idx]:
+                    old_to_new[old_idx] = new_idx
+                    new_idx += 1
+
+            # Update bone indices
+            current_bones = [[old_to_new[parent], old_to_new[child]]
+                             for parent, child in updated_bones
+                             if parent in old_to_new and child in old_to_new]
+
+            # Update root index
+            if current_root is not None and current_root in old_to_new:
+                current_root = old_to_new[current_root]
+
+        else:
+            # Connect with a bone
+            # print(f"Iteration {iteration + 1}: Connecting closest components with bone {joint1_idx} -> {joint2_idx} "
+            #       f"(distance: {min_distance:.4f})")
+            current_bones.append([joint1_idx, joint2_idx])
+
+        iteration += 1
+
+        # prevent infinite loops
+        if iteration > len(joints):
+            print(f"Warning: Maximum iterations reached ({iteration}), stopping")
+            break
+
+    current_bones = np.array(current_bones) if len(current_bones) > 0 else np.array([]).reshape(0, 2)
+
+    # Final connectivity verification
+    final_components = find_connected_components(current_joints, current_bones)
+    if len(final_components) == 1:
+        pass
+    else:
+        print(f"Warning: Still have {len(final_components)} disconnected components after {iteration} iterations")
+
+    return current_joints, current_bones, current_root
+
+def merge_duplicate_joints_and_fix_bones(joints, bones, tolerance=0.0025, root_index=None):
+    """
+    Merge duplicate joints that are within a given tolerance distance, and fix bones to maintain connectivity.
+    Also merge bones that become duplicates after joint merging.
+    """
+    n_joints = len(joints)
+
+    # find joint groups to merge
+    merge_groups = []
+    used = [False] * n_joints
+
+    for i in range(n_joints):
+        if used[i]:
+            continue
+
+        # find all joints within tolerance distance of joint i
+        group = [i]
+        for j in range(i + 1, n_joints):
+            if not used[j]:
+                dist = np.linalg.norm(joints[i] - joints[j])
+                if dist < tolerance:
+                    group.append(j)
+                    used[j] = True
+
+        used[i] = True
+        merge_groups.append(group)
+
+        # if len(group) > 1:
+        #     print(f"found duplicate joints group: {group}")
+
+    # build merge map: choose a representative joint per group
+    merge_map = {}
+    for group in merge_groups:
+        if root_index is not None and root_index in group:
+            representative = root_index
+        else:
+            representative = group[0]  # else choose the first one as representative
+        for joint_idx in group:
+            merge_map[joint_idx] = representative
+
+    # track root joint change
+    intermediate_root_index = None
+    if root_index is not None:
+        intermediate_root_index = merge_map.get(root_index, root_index)
+        # if intermediate_root_index != root_index:
+        #     print(f"root joint index changed from {root_index} to {intermediate_root_index}")
+
+    # update bones: remove self-loop bones, and merge duplicate bones
+    updated_bones = []
+
+    for parent, child in bones:
+        new_parent = merge_map.get(parent, parent)
+        new_child = merge_map.get(child, child)
+
+        if new_parent != new_child:  # remove self-loop bones
+            updated_bones.append([new_parent, new_child])
+
+    # remove duplicate bones
+    unique_bones = []
+    seen_bones = set()
+
+    for bone in updated_bones:
+        bone_key = tuple(bone)  # keep the order of [parent, child]
+        if bone_key not in seen_bones:
+            seen_bones.add(bone_key)
+            unique_bones.append(bone)
+
+    # re-index joints to remove unused joints
+    used_joint_indices = set()
+    for parent, child in unique_bones:
+        used_joint_indices.add(parent)
+        used_joint_indices.add(child)
+    if intermediate_root_index is not None:
+        used_joint_indices.add(intermediate_root_index)
+
+    used_joint_indices = sorted(list(used_joint_indices))
+
+    # new indices for used joints
+    old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(used_joint_indices)}
+
+    final_joints = joints[used_joint_indices]
+    final_bones = np.array([[old_to_new[parent], old_to_new[child]]
+                            for parent, child in unique_bones])
+
+    final_root_index = None
+    if intermediate_root_index is not None:
+        final_root_index = old_to_new[intermediate_root_index]
+        if root_index is not None and final_root_index != root_index:
+            print(f"final root index: {root_index} -> {final_root_index}")
+
+    removed_joints = n_joints - len(final_joints)
+    removed_bones = len(bones) - len(final_bones)
+
+    # if removed_joints > 0 or removed_bones > 0:
+    #     print(f"merge results:")
+    #     print(f"  joint number: {n_joints} -> {len(final_joints)} (removed {removed_joints})")
+    #     print(f"  bone number: {len(bones)} -> {len(final_bones)} (removed {removed_bones})")
+
+    # Ensure skeleton connectivity with a relaxed threshold
+    final_joints, final_bones, final_root_index = ensure_skeleton_connectivity(
+        final_joints, final_bones, final_root_index,
+        merge_distance_threshold=tolerance * 8  # more relaxed threshold for connectivity
+    )
+
+    if root_index is not None:
+        return final_joints, final_bones, final_root_index
+    else:
+        return final_joints, final_bones
+
+def save_skeleton_to_txt(pred_joints, pred_bones, pred_root_index, hier_order, vertices, filename='skeleton.txt'):
+    """
+    Save the skeleton to a txt file; the format follows RigNet (joints, root, hier).
+
+    If hier_order: `pred_root_index` gives the root joint, and the parent-child relationship is taken directly
+    from the ordered bones. Otherwise: the joint nearest to the mesh center is set as the root joint, and the
+    hierarchy is built by BFS starting from the root.
+    """
+    num_joints = pred_joints.shape[0]
+
+    # assign joint names
+    joint_names = [f'joint{i}' for i in range(num_joints)]
+
+    adjacency = defaultdict(list)
+    for bone in pred_bones:
+        idx_a, idx_b = bone
+        adjacency[idx_a].append(idx_b)
+        adjacency[idx_b].append(idx_a)
+
+    # find the root joint
+    if hier_order:
+        root_idx = pred_root_index
+    else:
+        centroid = np.mean(vertices, axis=0)
+        distances = np.linalg.norm(pred_joints - centroid, axis=1)
+        root_idx = np.argmin(distances)
+
+    root_name = joint_names[root_idx]
+
+    # build hierarchy
+    parent_map = {}
+
+    if hier_order:
+        visited = set()
+
+        for parent_idx, child_idx in pred_bones:
+            if child_idx not in parent_map:
+                parent_map[child_idx] = parent_idx
+                visited.add(child_idx)
+                visited.add(parent_idx)
+
+        parent_map[root_idx] = None
+
+    else:
+        visited = set([root_idx])
+        queue = deque([root_idx])
+        parent_map[root_idx] = None
+
+        while queue:
+            current_idx = queue.popleft()
+            for neighbor_idx in adjacency[current_idx]:
+                if neighbor_idx not in visited:
+                    parent_map[neighbor_idx] = current_idx
+                    visited.add(neighbor_idx)
+                    queue.append(neighbor_idx)
+
+    if len(visited) != num_joints:
+        print(f"bones are not fully connected, leaving {num_joints - len(visited)} joints unconnected.")
+
+    # save joints
+    joints_lines = []
+    for idx, coord in enumerate(pred_joints):
+        name = joint_names[idx]
+        joints_line = f'joints {name} {coord[0]:.8f} {coord[1]:.8f} {coord[2]:.8f}'
+        joints_lines.append(joints_line)
+
+    # save root name
+    root_line = f'root {root_name}'
+
+    # save hierarchy
+    hier_lines = []
+    for child_idx, parent_idx in parent_map.items():
+        if parent_idx is not None:
+            parent_name = joint_names[parent_idx]
+            child_name = joint_names[child_idx]
+            hier_line = f'hier {parent_name} {child_name}'
+            hier_lines.append(hier_line)
+
+    with open(filename, 'w') as file:
+        for line in joints_lines:
+            file.write(line + '\n')
+
+        file.write(root_line + '\n')
+
+        for line in hier_lines:
+            file.write(line + '\n')
+
+def save_skeleton_obj(joints, bones, save_path, root_index=None, radius_sphere=0.01,
+                      radius_bone=0.005, segments=16, stacks=16, use_cone=False):
+    """
+    Save skeletons to an obj file; each connection consists of two red spheres (joints) and one blue cylinder (bone).
+    If the root index is known, the root sphere is colored green.
+    """
+    all_vertices = []
+    all_colors = []
+    all_faces = []
+    vertex_offset = 0
+
+    # create spheres for joints
+    for i, joint in enumerate(joints):
+        # define color
+        if root_index is not None and i == root_index:
+            color = (0, 1, 0)  # green for the root joint
+        else:
+            color = (1, 0, 0)  # red for other joints
+
+        # create joint sphere
+        sphere_vertices, sphere_faces = create_sphere(joint, radius=radius_sphere, segments=segments, stacks=stacks)
+        all_vertices.extend(sphere_vertices)
+        all_colors.extend([color] * len(sphere_vertices))
+
+        # adjust face indices
+        adjusted_sphere_faces = [(v1 + vertex_offset, v2 + vertex_offset, v3 + vertex_offset) for (v1, v2, v3) in sphere_faces]
+        all_faces.extend(adjusted_sphere_faces)
+        vertex_offset += len(sphere_vertices)
+
+    # create bones
+    for bone in bones:
+        parent_idx, child_idx = bone
+        parent = joints[parent_idx]
+        child = joints[child_idx]
+
+        try:
+            bone_vertices, bone_faces = create_bone(parent, child, radius=radius_bone, segments=segments, use_cone=use_cone)
+        except ValueError as e:
+            print(f"Skipping connection {parent_idx}-{child_idx}, reason: {e}")
+            continue
+
+        all_vertices.extend(bone_vertices)
+        all_colors.extend([(0, 0, 1)] * len(bone_vertices))  # blue
+
+        # adjust face indices
+        adjusted_bone_faces = [(v1 + vertex_offset, v2 + vertex_offset, v3 + vertex_offset) for (v1, v2, v3) in bone_faces]
+        all_faces.extend(adjusted_bone_faces)
+        vertex_offset += len(bone_vertices)
+
+    # save to obj
+    obj_lines = []
+    for v, c in zip(all_vertices, all_colors):
+        obj_lines.append(f"v {v[0]} {v[1]} {v[2]} {c[0]} {c[1]} {c[2]}")
+    obj_lines.append("")
+
+    for face in all_faces:
+        obj_lines.append(f"f {face[0]} {face[1]} {face[2]}")
+
+    with open(save_path, 'w') as obj_file:
+        obj_file.write("\n".join(obj_lines))
+
+def create_sphere(center, radius=0.01, segments=16, stacks=16):
+    vertices = []
+    faces = []
+    for i in range(stacks + 1):
+        lat = np.pi / 2 - i * np.pi / stacks
+        xy = radius * np.cos(lat)
+        z = radius * np.sin(lat)
+        for j in range(segments):
+            lon = j * 2 * np.pi / segments
+            x = xy * np.cos(lon) + center[0]
+            y = xy * np.sin(lon) + center[1]
+            vertices.append((x, y, z + center[2]))
+    for i in range(stacks):
+        for j in range(segments):
+            first = i * segments + j
+            second = first + segments
+            third = first + 1 if (j + 1) < segments else i * segments
+            fourth = second + 1 if (j + 1) < segments else (i + 1) * segments
+            faces.append((first + 1, second + 1, fourth + 1))
+            faces.append((first + 1, fourth + 1, third + 1))
+    return vertices, faces
+
+def create_bone(start, end, radius=0.005, segments=16, use_cone=False):
+    dir_vector = np.array(end) - np.array(start)
+    height = np.linalg.norm(dir_vector)
+    if height == 0:
+        raise ValueError("Start and end points cannot be the same for a cone.")
+    dir_vector = dir_vector / height
+
+    z = np.array([0, 0, 1])
+    if np.allclose(dir_vector, z):
+        R = np.identity(3)
+    elif np.allclose(dir_vector, -z):
+        R = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]])
+    else:
+        v = np.cross(z, dir_vector)
+        s = np.linalg.norm(v)
+        c = np.dot(z, dir_vector)
+        kmat = np.array([[0, -v[2], v[1]],
+                         [v[2], 0, -v[0]],
+                         [-v[1], v[0], 0]])
+        R = np.identity(3) + kmat + np.matmul(kmat, kmat) * ((1 - c) / (s**2))
+
+    theta = np.linspace(0, 2 * np.pi, segments, endpoint=False)
+    base_circle = np.array([np.cos(theta), np.sin(theta), np.zeros(segments)]) * radius
+
+    vertices = []
+    for point in base_circle.T:
+        rotated = np.dot(R, point) + np.array(start)
+        vertices.append(tuple(rotated))
+
+    faces = []
+
+    if use_cone:
+        vertices.append(tuple(end))
+
+        apex_idx = segments + 1
+        for i in range(segments):
+            next_i = (i + 1) % segments
+            faces.append((i + 1, next_i + 1, apex_idx))
+    else:
+        top_circle = np.array([np.cos(theta), np.sin(theta), np.ones(segments)]) * radius
+        for point in top_circle.T:
+            point_scaled = np.array([point[0], point[1], height])
+            rotated = np.dot(R, point_scaled) + np.array(start)
+            vertices.append(tuple(rotated))
+        for i in range(segments):
+            next_i = (i + 1) % segments
+            faces.append((i + 1, next_i + 1, next_i + segments + 1))
+            faces.append((i + 1, next_i + segments + 1, i + segments + 1))
+
+    return vertices, faces
+
+def render_mesh_with_skeleton(joints, bones, vertices, faces, output_dir, filename, prefix='pred', root_idx=None):
+    """
+    Render the mesh with its skeleton using PyRender.
+    """
+    loader = DataLoader()
+
+    raw_size = (960, 960)
+    renderer = PyRenderWrapper(raw_size)
+
+    save_dir = os.path.join(output_dir, 'render_results')
+    os.makedirs(save_dir, exist_ok=True)
+
+    loader.joints = joints
+    loader.bones = bones
+    loader.root_idx = root_idx
+
+    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
+    mesh.visual.vertex_colors[:, 3] = 100  # set transparency
+    loader.mesh = mesh
+    v = mesh.vertices
+    xmin, ymin, zmin = v.min(axis=0)
+    xmax, ymax, zmax = v.max(axis=0)
+    loader.bbox_center = np.array([(xmax + xmin) / 2, (ymax + ymin) / 2, (zmax + zmin) / 2])
+    loader.bbox_size = np.array([xmax - xmin, ymax - ymin, zmax - zmin])
+    loader.bbox_scale = max(xmax - xmin, ymax - ymin, zmax - zmin)
+    loader.normalize_coordinates()
+
+    input_dict = loader.query_mesh_rig()
+
+    angles = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
+    distance = np.max(loader.bbox_size) * 2
+
+    subfolder_path = os.path.join(save_dir, filename + '_' + prefix)
+
+    os.makedirs(subfolder_path, exist_ok=True)
+
+    for i, angle in enumerate(angles):
+        renderer.set_camera_view(angle, loader.bbox_center, distance)
+        renderer.align_light_to_camera()
+
+        color = renderer.render(input_dict)[0]
+
+        output_filename = f"{filename}_{prefix}_view{i+1}.png"
+        output_filepath = os.path.join(subfolder_path, output_filename)
+        cv2.imwrite(output_filepath, color)
+
+def save_args(args, output_dir, filename="config.json"):
+    args_dict = vars(args)
+    os.makedirs(output_dir, exist_ok=True)
+    config_path = os.path.join(output_dir, filename)
+    with open(config_path, 'w') as f:
+        json.dump(args_dict, f, indent=4)
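A small sketch of the post-processing entry point above: turning a predicted `(B, 2, 3)` bone-coordinate array into indexed joints and bones, then merging near-duplicate joints (the coordinates are illustrative):

import numpy as np
from utils.save_utils import pred_joints_and_bones, merge_duplicate_joints_and_fix_bones

# Two bones that share a joint; the model emits each bone as a (parent, child) coordinate pair.
bone_coor = np.array([
    [[0.0, 0.0, 0.0], [0.0, 0.1, 0.0]],
    [[0.0, 0.1, 0.0], [0.1, 0.1, 0.0]],
])
joints, bones = pred_joints_and_bones(bone_coor)  # 3 unique joints, 2 index pairs
joints, bones = merge_duplicate_joints_and_fix_bones(joints, bones, tolerance=0.0025)
print(joints.shape, bones.tolist())
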
utils/skeleton_data_loader.py
ADDED
@@ -0,0 +1,97 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from torch import is_tensor
+from torch.utils.data import Dataset
+from torch.nn.utils.rnn import pad_sequence
+from data_utils.save_npz import normalize_to_unit_cube
+
+import numpy as np
+
+class SkeletonData(Dataset):
+    """
+    A PyTorch Dataset to load and process skeleton data.
+    """
+    def __init__(self, data, args, is_training):
+        self.data = data
+
+        self.input_pc_num = args.input_pc_num
+        self.is_training = is_training
+
+        self.hier_order = args.hier_order
+        print(f"[Dataset] Created from {len(self.data)} entries")
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        data = self.data[idx]
+
+        joints = data['joints']
+        vertices = data['vertices']
+        pc_normal = data['pc_w_norm']
+
+        indices = np.random.choice(pc_normal.shape[0], self.input_pc_num, replace=False)
+        pc_normal = pc_normal[indices, :]
+
+        pc_coor = pc_normal[:, :3]
+        normal = pc_normal[:, 3:]
+        if np.linalg.norm(normal, axis=1, keepdims=True).min() < 0.99:
+            print("normal reroll")
+            return self.__getitem__(np.random.randint(0, len(self.data)))
+
+        data_dict = {}
+
+        # normalize normals to unit length
+        normal = normal / np.linalg.norm(normal, axis=1, keepdims=True)
+
+        # scale to [-0.5, 0.5]
+        _, center, scale = normalize_to_unit_cube(vertices.copy(), scale_factor=0.9995)
+        joints = (joints - center) * scale  # align joints with the point cloud first
+
+        bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
+        pc_center = (bounds[0] + bounds[1])[None, :] / 2
+        pc_scale = (bounds[1] - bounds[0]).max() + 1e-5
+        pc_coor = (pc_coor - pc_center) / pc_scale
+        joints = (joints - pc_center) / pc_scale
+
+        joints = joints.clip(-0.5, 0.5)
+
+        data_dict['joints'] = torch.from_numpy(np.asarray(joints).astype(np.float16))
+        data_dict['bones'] = torch.from_numpy(data['bones'].astype(np.int64))
+        pc_coor = pc_coor / np.abs(pc_coor).max() * 0.9995
+        data_dict['pc_normal'] = torch.from_numpy(np.concatenate([pc_coor, normal], axis=-1).astype(np.float16))
+        data_dict['vertices'] = torch.from_numpy(data['vertices'].astype(np.float16))
+        data_dict['faces'] = torch.from_numpy(data['faces'].astype(np.int64))
+        data_dict['uuid'] = data['uuid']
+        data_dict['root_index'] = str(data['root_index'])
+        data_dict['transform_params'] = torch.tensor([
+            center[0], center[1], center[2],
+            scale,
+            pc_center[0][0], pc_center[0][1], pc_center[0][2],
+            pc_scale
+        ], dtype=torch.float32)
+
+        return data_dict
+
+    @classmethod
+    def load(cls, args, is_training=True):
+        loaded_data = np.load(args.dataset_path, allow_pickle=True)
+        data = []
+        for item in loaded_data["arr_0"]:
+            data.append(item)
+        print(f"[Dataset] Loaded {len(data)} entries")
+        return cls(data, args, is_training)
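`__getitem__` returns per-sample tensors whose `joints` and `bones` lengths vary between models, so batching needs a custom collate function; the `pad_sequence` import above suggests padding. A minimal sketch, where the `pad_skeleton_collate` name and the padding values are assumptions, not part of this repo:

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_skeleton_collate(batch):
    # Hypothetical collate: pad the ragged tensors, stack the fixed-size ones.
    out = {
        'joints': pad_sequence([b['joints'] for b in batch], batch_first=True, padding_value=0.0),
        'bones': pad_sequence([b['bones'] for b in batch], batch_first=True, padding_value=-1),
        'pc_normal': torch.stack([b['pc_normal'] for b in batch]),  # fixed input_pc_num points
        'transform_params': torch.stack([b['transform_params'] for b in batch]),
        'uuid': [b['uuid'] for b in batch],
    }
    return out

# loader = torch.utils.data.DataLoader(SkeletonData.load(args), batch_size=4, collate_fn=pad_skeleton_collate)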