ckc99u committed
Commit c0c8ef0 · verified · 1 Parent(s): d1621dc

Upload 33 files
.gitattributes CHANGED
@@ -1,35 +1,40 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/ar_demo.gif filter=lfs diff=lfs merge=lfs -text
+ assets/articulation-xl2.0.png filter=lfs diff=lfs merge=lfs -text
+ assets/MagicArticulate_teaser.gif filter=lfs diff=lfs merge=lfs -text
+ assets/sequence_ordering_demo.gif filter=lfs diff=lfs merge=lfs -text
+ data_utils/examples/0a59c5ffa4a1476bac6d540b79947f31_render_results.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
- ---
- title: MagicArt
- emoji: 🏆
- colorFrom: blue
- colorTo: pink
- sdk: gradio
- sdk_version: 6.0.1
- app_file: app.py
- pinned: false
- short_description: obj to rig test
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: MagicArt
+ emoji: 🏆
+ colorFrom: blue
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 6.0.1
+ app_file: app.py
+ pinned: false
+ short_description: obj to rig test
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,326 @@
+ import os
+ import torch
+ import trimesh
+ import numpy as np
+ import gradio as gr
+ from pathlib import Path
+ import tempfile
+ import shutil
+
+ from skeleton_models.skeletongen import SkeletonGPT
+ from data_utils.save_npz import normalize_to_unit_cube
+ from utils.mesh_to_pc import MeshProcessor
+ from utils.save_utils import (
+     pred_joints_and_bones,
+     save_skeleton_to_txt,
+     merge_duplicate_joints_and_fix_bones,
+     save_skeleton_obj,
+     save_mesh
+ )
+
+ # Global model variable
+ model = None
+ args_config = None
+
+ def initialize_model():
+     """Initialize the model once at startup"""
+     global model, args_config
+
+     if model is not None:
+         return
+
+     print("Initializing MagicArticulate model...")
+
+     # Create a simple args object with default parameters
+     class Args:
+         def __init__(self):
+             self.input_pc_num = 8192
+             self.num_beams = 1
+             self.llm = "facebook/opt-350m"
+             self.pad_id = -1
+             self.n_discrete_size = 128
+             self.n_max_bones = 100
+             self.seed = 0
+             self.precision = "fp16"
+             self.pretrained_weights = "checkpoints/checkpoint_trainonv2_hier.pt"  # Default checkpoint
+             self.hier_order = False
+
+     args_config = Args()
+
+     # Load model
+     model = SkeletonGPT(args_config).cuda()
+
+     # Load pretrained weights
+     if os.path.exists(args_config.pretrained_weights):
+         pkg = torch.load(args_config.pretrained_weights, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+         model.load_state_dict(pkg["model"])
+         model.eval()
+         print("Model loaded successfully!")
+     else:
+         print(f"Warning: Pretrained weights not found at {args_config.pretrained_weights}")
+         raise FileNotFoundError("Model checkpoint not found. Please ensure checkpoints are downloaded.")
+
+ def process_mesh(
+     input_file,
+     apply_marching_cubes,
+     hier_order,
+     octree_depth
+ ):
+     """
+     Process the input mesh and generate rigging prediction
+
+     Args:
+         input_file: Uploaded mesh file (.obj, .ply, or .stl)
+         apply_marching_cubes: Whether to apply marching cubes
+         hier_order: Whether to use hierarchical ordering
+         octree_depth: Depth for octree (if using marching cubes)
+
+     Returns:
+         Tuple of (skeleton obj file, rig txt file, normalized mesh file, status message)
+     """
+     try:
+         # Initialize model if not already done
+         if model is None:
+             initialize_model()
+
+         # Create temporary output directory
+         output_dir = tempfile.mkdtemp()
+
+         # Get file information
+         file_name = Path(input_file).stem
+         file_ext = Path(input_file).suffix.lower()
+
+         # Check file type
+         if file_ext not in ['.obj', '.ply', '.stl']:
+             return None, None, None, f"Error: Unsupported file type {file_ext}. Please upload .obj, .ply, or .stl file."
+
+         # Load mesh
+         mesh = trimesh.load(input_file, force='mesh')
+
+         # Convert mesh to point cloud
+         print(f"Converting mesh to point cloud (apply_marching_cubes={apply_marching_cubes})...")
+         pc_list = MeshProcessor.convert_meshes_to_point_clouds(
+             [mesh],
+             args_config.input_pc_num,
+             apply_marching_cubes=apply_marching_cubes,
+             octree_depth=octree_depth
+         )
+         pc_normal = pc_list[0]
+
+         # Normalize point cloud
+         pc_coor = pc_normal[:, :3]
+         normals = pc_normal[:, 3:]
+         pc_coor, center, scale = normalize_to_unit_cube(pc_coor, scale_factor=0.9995)
+
+         pc_coor = pc_coor.astype(np.float32)
+         normals = normals.astype(np.float32)
+
+         # Calculate transform parameters
+         bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
+         pc_center = (bounds[0] + bounds[1])[None, :] / 2
+         pc_scale = ((bounds[1] - bounds[0]).max() + 1e-5)
+
+         transform_params = torch.tensor([
+             center[0], center[1], center[2],
+             scale,
+             pc_center[0][0], pc_center[0][1], pc_center[0][2],
+             pc_scale
+         ], dtype=torch.float32)
+
+         # Prepare batch data
+         pc_normal_tensor = torch.from_numpy(
+             np.concatenate([pc_coor, normals], axis=-1).astype(np.float16)
+         ).unsqueeze(0).cuda()
+
+         batch_data = {
+             'pc_normal': pc_normal_tensor,
+             'file_name': [file_name],
+             'transform_params': transform_params.unsqueeze(0).cuda(),
+             'vertices': torch.from_numpy(mesh.vertices).unsqueeze(0).cuda(),
+             'faces': torch.from_numpy(mesh.faces).unsqueeze(0).cuda()
+         }
+
+         # Generate skeleton
+         print("Generating skeleton...")
+         with torch.no_grad():
+             pred_bone_coords = model.generate(batch_data)
+
+         # Process predictions
+         skeleton = pred_bone_coords[0].cpu().numpy()
+         pred_joints, pred_bones = pred_joints_and_bones(skeleton.squeeze())
+
+         # Post-process: merge duplicate joints
+         if hier_order:
+             pred_root_index = pred_bones[0][0]
+             pred_joints, pred_bones, pred_root_index = merge_duplicate_joints_and_fix_bones(
+                 pred_joints, pred_bones, root_index=pred_root_index
+             )
+         else:
+             pred_joints, pred_bones = merge_duplicate_joints_and_fix_bones(pred_joints, pred_bones)
+             pred_root_index = None
+
+         # Denormalize joints for rig file
+         transform_params_np = transform_params.cpu().numpy()
+         trans = transform_params_np[:3]
+         scale_val = transform_params_np[3]
+         pc_trans = transform_params_np[4:7]
+         pc_scale_val = transform_params_np[7]
+
+         pred_joints_denorm = pred_joints * pc_scale_val + pc_trans
+         pred_joints_denorm = pred_joints_denorm / scale_val + trans
+
+         # Save outputs
+         skel_obj_path = os.path.join(output_dir, f'{file_name}_skel.obj')
+         rig_txt_path = os.path.join(output_dir, f'{file_name}_pred.txt')
+         mesh_obj_path = os.path.join(output_dir, f'{file_name}_mesh.obj')
+
+         # Save skeleton
+         save_skeleton_obj(
+             pred_joints,
+             pred_bones,
+             skel_obj_path,
+             pred_root_index if hier_order else None,
+             use_cone=hier_order
+         )
+
+         # Save rig
+         vertices_np = mesh.vertices
+         save_skeleton_to_txt(
+             pred_joints_denorm,
+             pred_bones,
+             pred_root_index,
+             hier_order,
+             vertices_np,
+             rig_txt_path
+         )
+
+         # Save normalized mesh
+         vertices_norm = (vertices_np - trans) * scale_val
+         vertices_norm = (vertices_norm - pc_trans) / pc_scale_val
+         save_mesh(vertices_norm, mesh.faces, mesh_obj_path)
+
+         status_msg = f"✅ Success! Generated skeleton with {len(pred_joints)} joints and {len(pred_bones)} bones."
+
+         return skel_obj_path, rig_txt_path, mesh_obj_path, status_msg
+
+     except Exception as e:
+         import traceback
+         error_msg = f"❌ Error processing mesh: {str(e)}\n{traceback.format_exc()}"
+         print(error_msg)
+         return None, None, None, error_msg
+
+ # Create Gradio interface
+ def create_interface():
+     """Create the Gradio interface"""
+
+     with gr.Blocks(title="MagicArticulate - 3D Model Rigging") as demo:
+         gr.Markdown("""
+         # 🪄 MagicArticulate: Make Your 3D Models Articulation-Ready
+
+         Upload a 3D mesh (.obj, .ply, or .stl) to automatically generate skeletal rigging.
+
+         **Paper**: [CVPR 2025] MagicArticulate ([Project Page](https://chaoyuesong.github.io/MagicArticulate/))
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 gr.Markdown("### Input")
+                 input_file = gr.File(
+                     label="Upload 3D Mesh",
+                     file_types=[".obj", ".ply", ".stl"],
+                     type="filepath"
+                 )
+
+                 gr.Markdown("### Options")
+                 apply_marching_cubes = gr.Checkbox(
+                     label="Apply Marching Cubes",
+                     value=False,
+                     info="Apply marching cubes for mesh processing (slower but more accurate)"
+                 )
+
+                 hier_order = gr.Checkbox(
+                     label="Hierarchical Ordering",
+                     value=False,
+                     info="Use hierarchical sequence ordering for skeleton generation"
+                 )
+
+                 octree_depth = gr.Slider(
+                     minimum=5,
+                     maximum=9,
+                     value=7,
+                     step=1,
+                     label="Octree Depth",
+                     info="Depth for octree (only used if Marching Cubes is enabled)"
+                 )
+
+                 generate_btn = gr.Button("🚀 Generate Rigging", variant="primary", size="lg")
+
+             with gr.Column(scale=1):
+                 gr.Markdown("### Output")
+                 status_text = gr.Textbox(
+                     label="Status",
+                     lines=3,
+                     interactive=False
+                 )
+
+                 skel_output = gr.File(
+                     label="📥 Skeleton (.obj)",
+                     interactive=False
+                 )
+
+                 rig_output = gr.File(
+                     label="📥 Rig Prediction (.txt)",
+                     interactive=False
+                 )
+
+                 mesh_output = gr.File(
+                     label="📥 Normalized Mesh (.obj)",
+                     interactive=False
+                 )
+
+         gr.Markdown("""
+         ### About
+         MagicArticulate automatically generates skeletal structures for 3D models, making them ready for animation.
+         The system predicts joint positions and bone connections using a transformer-based approach.
+
+         **Outputs**:
+         - **Skeleton (.obj)**: 3D visualization of the generated skeleton
+         - **Rig Prediction (.txt)**: Detailed rigging information (joints, bones, hierarchy)
+         - **Normalized Mesh (.obj)**: The input mesh normalized to unit cube
+
+         **Citation**:
+         ```
+         @inproceedings{song2025magicarticulate,
+           title={MagicArticulate: Make Your 3D Models Articulation-Ready},
+           author={Song, Chaoyue and others},
+           booktitle={CVPR},
+           year={2025}
+         }
+         ```
+         """)
+
+         # Connect the button to the processing function
+         generate_btn.click(
+             fn=process_mesh,
+             inputs=[input_file, apply_marching_cubes, hier_order, octree_depth],
+             outputs=[skel_output, rig_output, mesh_output, status_text]
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     # Initialize model at startup
+     try:
+         initialize_model()
+     except Exception as e:
+         print(f"Warning: Could not initialize model at startup: {e}")
+         print("Model will be initialized on first request.")
+
+     # Launch Gradio app
+     demo = create_interface()
+     demo.queue()
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False
+     )
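For scripted use of this Space, something like the following should work with the `gradio_client` package. This is a sketch, not part of the commit: the local URL, the example file name, and the endpoint name `/process_mesh` (Gradio's default naming, derived from the handler function) are assumptions.

```
# Hypothetical client-side call; assumes the app above is running on :7860
# and that Gradio exposed the click handler as "/process_mesh".
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")
skel_obj, rig_txt, mesh_obj, status = client.predict(
    handle_file("example.obj"),  # hypothetical input mesh
    False,                       # apply_marching_cubes
    False,                       # hier_order
    7,                           # octree_depth
    api_name="/process_mesh",
)
print(status)
```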
assets/MagicArticulate_teaser.gif ADDED

Git LFS Details

  • SHA256: ba4f56e7485a641b9e1c62fbd3717cb54c2f3b3fe519263313ae1280f0877aa7
  • Pointer size: 132 Bytes
  • Size of remote file: 2 MB
assets/ar_demo.gif ADDED

Git LFS Details

  • SHA256: 9b0602987ba2f9299ce051d27e6ff97a1ebe1e02dd0916a4b842e5c9e31f6973
  • Pointer size: 131 Bytes
  • Size of remote file: 776 kB
assets/articulation-xl2.0.png ADDED

Git LFS Details

  • SHA256: 790a198a035f35052a51960b636153d9f51e2cf50f7e0b04eab1e5368f3a07c1
  • Pointer size: 131 Bytes
  • Size of remote file: 257 kB
assets/data_statistics.png ADDED
assets/sequence_ordering_demo.gif ADDED

Git LFS Details

  • SHA256: e672e1436727a4a08baf06060a39fba73b1a912ad1a7e0c772535f8a18299fd4
  • Pointer size: 131 Bytes
  • Size of remote file: 213 kB
assets/skeleton_compare.png ADDED
data_utils/README.md ADDED
@@ -0,0 +1,43 @@
+ ## Preprocessed data
+ We provide preprocessed data saved in NPZ files, which contain the following information:
+ ```
+ 'vertices', 'faces', 'normals', 'joints', 'bones', 'root_index', 'uuid', 'pc_w_norm', 'joint_names', 'skinning_weights_value', 'skinning_weights_row', 'skinning_weights_col', 'skinning_weights_shape'
+ ```
+ You can check `read_npz.py` for how to read the NPZ files and `save_npz.py` for how we save them.
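For quick reference, loading one entry and densifying its skinning weights looks like this (a condensed sketch of what `read_npz.py` below does; the file name is the test split referenced throughout these scripts):

```
import numpy as np
import scipy.sparse as sp

# Each NPZ stores a list of per-model dicts under 'arr_0'.
entry = np.load('articulation_xlv2_test.npz', allow_pickle=True)['arr_0'][0]

# Rebuild the dense (n_vertex, n_joints) skinning matrix from its sparse parts.
skinning = sp.coo_matrix(
    (entry['skinning_weights_value'],
     (entry['skinning_weights_row'], entry['skinning_weights_col'])),
    shape=entry['skinning_weights_shape']
).toarray()
```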
+
+ Before saving them into NPZ files, we extract a mesh (.obj) and a rig (.txt) from the 3D models downloaded from Objaverse-XL using Blender. The rig file follows the format in [RigNet](https://github.com/zhan-xu/RigNet), which includes the following entries:
+ ```
+ joints [joint_name] [x] [y] [z]
+ root [root_joint_name]
+ skin [vertex_index] [joint_name1] [skinning_weight1] [joint_name2] [skinning_weight2] ...
+ hier [parent_joint_name] [child_joint_name]
+ ```
+ For an example, please see `examples/0a59c5ffa4a1476bac6d540b79947f31.txt`.
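A minimal parser for this rig format might look like the sketch below (condensed from `read_rig_file` in `save_npz.py`; it assumes, as that code does, that `joints` lines appear before the `root` and `hier` lines that reference them; `skin` lines are handled analogously):

```
import numpy as np

def read_rig(path):
    joints, names, bones, root = [], [], [], None
    idx = {}  # joint name -> index
    with open(path) as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            if parts[0] == 'joints':
                idx[parts[1]] = len(joints)
                names.append(parts[1])
                joints.append([float(p) for p in parts[2:5]])
            elif parts[0] == 'root':
                root = idx[parts[1]]
            elif parts[0] == 'hier':
                bones.append([idx[parts[1]], idx[parts[2]]])
    return np.array(joints), np.array(bones), root, names
```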
+
+ If you want to convert an NPZ file back to OBJ and TXT files, we give an example; run:
+ ```
+ python convert_npz_to_mesh_rig.py
+ ```
+
+ ## Visualization
+ We provide a method for visualizing 3D models with their skeletons using [Pyrender](https://github.com/mmatl/pyrender), modified from [Lab4D](https://github.com/lab4d-org/lab4d/tree/ppr/). This visualization also serves as input to the VLM for skeleton quality rating. Make sure you have installed the following packages before running the visualization:
+ ```
+ pip install trimesh opencv-python pyrender
+ ```
+
+ We provide an example to demonstrate the process. For this example, we prepare an OBJ file along with a TXT file containing rigging information. Then, run:
+ ```
+ python render_data.py
+ ```
+ You will obtain the following outputs:
+
+ <p align="center">
+ <img width="80%" src="examples/0a59c5ffa4a1476bac6d540b79947f31_render_results.png"/>
+ </p>
+
+ ### Reading rig and mesh from GLBs
+ We provide the script we use for reading a rig (.txt) and mesh (.obj) from GLB files. You can run:
+ ```
+ python read_rig_mesh_from_glb.py
+ ```
+ Remember to download Blender (we use 4.2.0) and also install bpy in your conda environment.
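For reference, the module install might look like the line below (an assumption, not part of the commit: the standalone `bpy` wheel version must match the Blender build and a supported Python version):

```
pip install bpy==4.2.0
```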
data_utils/clean_skin_in_npz.py ADDED
@@ -0,0 +1,95 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import numpy as np
+ import scipy.sparse as sp
+
+ def check_and_clean_skinning_weights(file_path, output_path, tolerance=0.1):
+     """
+     Check whether all rows of the skinning-weight matrix sum to 1 for each item in the NPZ file.
+     Remove invalid items and save a cleaned version.
+
+     Args:
+         file_path: Path to the input NPZ file
+         output_path: Path for the cleaned NPZ file
+         tolerance: Tolerance for floating point comparison
+
+     Returns:
+         tuple: (cleaned_data_list, removed_indices)
+     """
+     data_list = np.load(file_path, allow_pickle=True)['arr_0']
+
+     invalid_indices = []
+     valid_data_list = []
+
+     for data in data_list:
+         is_valid = True
+
+         weights_data = data['skinning_weights_value']
+         weights_row = data['skinning_weights_row']
+         weights_col = data['skinning_weights_col']
+         weights_shape = data['skinning_weights_shape']
+
+         skinning_sparse = sp.coo_matrix(
+             (weights_data, (weights_row, weights_col)),
+             shape=weights_shape
+         )
+
+         skinning_csr = skinning_sparse.tocsr()
+         row_sums = np.array(skinning_csr.sum(axis=1)).flatten()
+
+         invalid_rows = np.where(np.abs(row_sums - 1.0) > tolerance)[0]
+
+         if len(invalid_rows) > 0:
+             min_sum = np.min(row_sums)
+             max_sum = np.max(row_sums)
+             invalid_indices.append((data['uuid'], f"{len(invalid_rows)} rows, range: [{min_sum:.6f}, {max_sum:.6f}]"))
+             is_valid = False
+
+         if is_valid:
+             valid_data_list.append(data)
+
+     # Save the cleaned data (np.savez_compressed pickles the object array itself;
+     # it takes no allow_pickle keyword, so the list is wrapped explicitly)
+     if valid_data_list:
+         np.savez_compressed(output_path, np.asarray(valid_data_list, dtype=object))
+         print(f"Saved {len(valid_data_list)} valid items to {output_path}")
+
+     return valid_data_list, invalid_indices
+
+ def main():
+     # File paths
+     file_path = "articulation_xlv2_train.npz"  # "articulation_xlv2_test.npz"
+     log_file = "invalid_skinning_weights_intrain.txt"  # "invalid_skinning_weights_intest.txt"
+     output_path = "articulation_xlv2_train_updated.npz"  # "articulation_xlv2_test_updated.npz"
+
+     # Clean the data
+     valid_data, invalid_indices = check_and_clean_skinning_weights(file_path, output_path)
+
+     # Log the results
+     with open(log_file, "w") as f:
+         f.write(f"Original file: {file_path}\n")
+         f.write(f"Cleaned file: {output_path}\n")
+         f.write(f"Total items: {len(np.load(file_path, allow_pickle=True)['arr_0'])}\n")
+         f.write(f"Valid items: {len(valid_data)}\n")
+         f.write(f"Removed items: {len(invalid_indices)}\n\n")
+
+         if invalid_indices:
+             f.write("Details of removed items:\n")
+             for uuid_str, details in invalid_indices:
+                 f.write(f"  UUID {uuid_str}: {details}\n")
+
+     print(f"Cleaning complete. Results written to {log_file}")
+
+ if __name__ == "__main__":
+     main()
data_utils/convert_npz_to_mesh_rig.py ADDED
@@ -0,0 +1,107 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ You can convert an NPZ file back to OBJ (mesh) and TXT (rig) files using this Python script.
+ """
+ import os
+ import numpy as np
+ import scipy.sparse as sp
+
+ def export_obj(vertices, faces, normals, output_path):
+     with open(output_path, 'w') as f:
+         for v in vertices:
+             f.write(f"v {v[0]} {v[1]} {v[2]}\n")
+         for n in normals:
+             f.write(f"vn {n[0]} {n[1]} {n[2]}\n")
+         for face in faces:
+             # OBJ format is 1-based, so we add 1 to all indices
+             f.write(f"f {face[0]+1}//{face[0]+1} {face[1]+1}//{face[1]+1} {face[2]+1}//{face[2]+1}\n")
+
+ def export_rig_txt(joints, bones, root_index, joint_names, skinning_weights, output_path):
+     """
+     joints [joint_name] [x] [y] [z]
+     root [root_joint_name]
+     skin [vertex_index] [joint_name1] [weight1] [joint_name2] [weight2] ...
+     hier [parent_joint_name] [child_joint_name]
+     """
+     n_joints = len(joints)
+     n_verts = skinning_weights.shape[0]  # (n_vertex, n_joints)
+
+     with open(output_path, 'w') as f:
+         # 1) joints
+         for i in range(n_joints):
+             x, y, z = joints[i]
+             jn = joint_names[i]
+             f.write(f"joints {jn} {x} {y} {z}\n")
+
+         # 2) root
+         root_name = joint_names[root_index]
+         f.write(f"root {root_name}\n")
+
+         # 3) skin
+         for vidx in range(n_verts):
+             row_weights = skinning_weights[vidx]
+             non_zero_indices = np.where(row_weights != 0)[0]
+             if len(non_zero_indices) == 0:
+                 continue
+
+             line_parts = [f"skin {vidx}"]  # vertex_idx
+             for jidx in non_zero_indices:
+                 w = row_weights[jidx]
+                 jn = joint_names[jidx]
+                 line_parts.append(jn)
+                 line_parts.append(str(w))
+
+             f.write(" ".join(line_parts) + "\n")
+
+         # 4) hier
+         for p_idx, c_idx in bones:
+             p_name = joint_names[p_idx]
+             c_name = joint_names[c_idx]
+             f.write(f"hier {p_name} {c_name}\n")
+
+ if __name__ == "__main__":
+
+     data = np.load('articulation_xlv2_test.npz', allow_pickle=True)
+     data_list = data['arr_0']
+
+     print(f"Loaded {len(data_list)} data entries")
+
+     model_data = data_list[0]
+     print("Data keys:", model_data.keys())
+     # 'vertices', 'faces', 'normals', 'joints', 'bones', 'root_index', 'uuid', 'joint_names',
+     # 'skinning_weights_value', 'skinning_weights_row', 'skinning_weights_col', 'skinning_weights_shape'
+
+     vertices = model_data['vertices']        # (n_vertex, 3)
+     faces = model_data['faces']              # (n_faces, 3)
+     normals = model_data['normals']          # (n_vertex, 3)
+     joints = model_data['joints']            # (n_joints, 3)
+     bones = model_data['bones']              # (n_bones, 2)
+     root_index = model_data['root_index']    # int
+     joint_names = model_data['joint_names']  # list of str
+     uuid_str = model_data['uuid']
+
+     skin_val = model_data['skinning_weights_value']
+     skin_row = model_data['skinning_weights_row']
+     skin_col = model_data['skinning_weights_col']
+     skin_shape = model_data['skinning_weights_shape']
+     skin_sparse = sp.coo_matrix((skin_val, (skin_row, skin_col)), shape=skin_shape)
+     skinning_weights = skin_sparse.toarray()  # (n_vertex, n_joints)
+
+     obj_path = f"{uuid_str}.obj"
+     export_obj(vertices, faces, normals, obj_path)
+     rig_txt_path = f"{uuid_str}.txt"
+     export_rig_txt(joints, bones, root_index, joint_names, skinning_weights, rig_txt_path)
+
+     print("Done!")
data_utils/data_loader.py ADDED
@@ -0,0 +1,121 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import numpy as np
+ import trimesh
+
+ class DataLoader:
+     def __init__(self):
+         self.joint_name_to_idx = {}
+         self.root_name = None  # set when a 'root' line is found in the rig file
+
+     def load_rig_data(self, rig_path):
+         joints = []
+         joints_names = []
+         bones = []
+
+         with open(rig_path, 'r') as f:
+             for line in f:
+                 parts = line.strip().split()
+                 if not parts:
+                     continue
+                 if parts[0] == 'joints':
+                     joint_name = parts[1]
+                     joint_pos = [float(parts[2]), float(parts[3]), float(parts[4])]
+                     self.joint_name_to_idx[joint_name] = len(joints)
+                     joints.append(joint_pos)
+                     joints_names.append(joint_name)
+                 elif parts[0] == 'root':
+                     self.root_name = parts[1]
+                 elif parts[0] == 'hier':
+                     parent_joint = self.joint_name_to_idx[parts[1]]
+                     child_joint = self.joint_name_to_idx[parts[2]]
+                     bones.append([parent_joint, child_joint])
+
+         self.joints = np.array(joints)
+         self.bones = np.array(bones)
+         self.joints_names = joints_names
+         self.root_idx = None
+         if self.root_name is not None:
+             self.root_idx = self.joint_name_to_idx[self.root_name]
+
+     def load_mesh(self, mesh_path):
+         mesh = trimesh.load(mesh_path, process=False)
+         mesh.visual.vertex_colors[:, 3] = 100  # set transparency
+         self.mesh = mesh
+
+         # Compute the bounding box of the mesh
+         v = self.mesh.vertices
+         xmin, ymin, zmin = v.min(axis=0)
+         xmax, ymax, zmax = v.max(axis=0)
+         self.bbox_center = np.array([(xmax + xmin) / 2, (ymax + ymin) / 2, (zmax + zmin) / 2])
+         self.bbox_size = np.array([xmax - xmin, ymax - ymin, zmax - zmin])
+         self.bbox_scale = max(xmax - xmin, ymax - ymin, zmax - zmin)
+
+         # Direction from bounding-box center to the center of mass
+         normal = mesh.center_mass - self.bbox_center
+         normal = normal / (np.linalg.norm(normal) + 1e-5)
+
+         # Choose axis order based on normal direction
+         if abs(normal[1]) > abs(normal[2]):  # if Y component is dominant
+             self.axis_order = [0, 1, 2]  # keep default order
+         else:
+             self.axis_order = [0, 2, 1]  # swap Y and Z
+
+         self.mesh.vertices = self.mesh.vertices[:, self.axis_order]
+         self.joints = self.joints[:, self.axis_order]
+         self.normalize_coordinates()
+
+     def normalize_coordinates(self):
+
+         # Compute scale and offset
+         scale = 1.0 / (self.bbox_scale + 1e-5)
+         offset = -self.bbox_center
+
+         self.mesh.vertices = (self.mesh.vertices + offset) * scale
+         self.joints = (self.joints + offset) * scale
+
+         # Radii for joint spheres and bone cylinders at unit scale
+         self.joint_radius = 0.01
+         self.bone_radius = 0.005
+
+     def query_mesh_rig(self):
+
+         input_dict = {"shape": self.mesh}
+
+         # Create joints as spheres
+         joint_meshes = []
+         for i, joint in enumerate(self.joints):
+
+             sphere = trimesh.creation.icosphere(
+                 radius=self.joint_radius, subdivisions=2
+             )
+             sphere.apply_translation(joint)
+             if i == self.root_idx:
+                 # root joint in green
+                 sphere.visual.vertex_colors = [0, 255, 0, 255]
+             else:
+                 # other joints in blue
+                 sphere.visual.vertex_colors = [0, 0, 255, 255]
+
+             joint_meshes.append(sphere)
+         input_dict["joint_meshes"] = trimesh.util.concatenate(joint_meshes)
+
+         # Create bones as cylinders
+         bone_meshes = []
+         for bone in self.bones:
+             start, end = self.joints[bone[0]], self.joints[bone[1]]
+             cylinder = trimesh.creation.cylinder(radius=self.bone_radius, segment=np.array([[0, 0, 0], end - start]))
+             cylinder.apply_translation(start)
+             cylinder.visual.vertex_colors = [255, 0, 0, 255]  # bones in red
+             bone_meshes.append(cylinder)
+         input_dict["bone_meshes"] = trimesh.util.concatenate(bone_meshes)
+
+         return input_dict
data_utils/examples/0a59c5ffa4a1476bac6d540b79947f31.obj ADDED
The diff for this file is too large to render. See raw diff
 
data_utils/examples/0a59c5ffa4a1476bac6d540b79947f31.txt ADDED
The diff for this file is too large to render. See raw diff
 
data_utils/examples/0a59c5ffa4a1476bac6d540b79947f31_render_results.png ADDED

Git LFS Details

  • SHA256: c2bdbf4ee74444b43fa8343d710fef1bc1680f37ade69e620fcd27997cfa7a5e
  • Pointer size: 131 Bytes
  • Size of remote file: 177 kB
data_utils/issue_data_list.txt ADDED
@@ -0,0 +1,123 @@
+ 0b1f1ccb-db41-5689-b363-fd8ca0145041
+ d4705a2d-2dbf-5175-9fd0-b0cc538b9c4d
+ 12b3d88d-2845-57b7-b483-d3a766beeb0e
+ 778505b7-63da-5c08-bad7-6935fcd73cec
+ 35ed271f-e9d7-528f-b165-e25004ef802b
+ 0096279cc46c4d1d8e8611e611e2418b
+ 00ea25ccad8344cbaedc89d70bb75a49
+ 08b617be44b6466584ba9624f857222c
+ 0998722861ba489695ad8bd4456e76e6
+ 0bd786e936774176ac474694b0f6f876
+ 0c1a7657bea0421dadef56e2080f0297
+ 1073c44309524810b6cd4cef2d6e8008
+ 10b9c6e9bf214dc39476161dfe2eaa8a
+ 147df2ee69df488eb6cb2f88f2f703bb
+ 18ff6fa66b0d483a8758e4602e5b70b0
+ 1cf88736c59a43c88ba7dac44c929dab
+ 1e9544eea98d417db87347dcc16cb69e
+ 21a4bc038cbd415b8e09566148c87c46
+ 2809e172066d4140b1ddc9356490191a
+ 28483d55555f433d8fde4ba141ad5271
+ 31829af6c72146519d348a6d4d2bcc8b
+ 32202338cd5c40beace31deeacd598e5
+ 37fe21828c37413986a07a1bf8c75c93
+ 3857965c400c47c9a846c01eb1f36ed5
+ 404e622bdfd14ab693640ff86c131973
+ 44f8486a0b2c4f9489fc3912b2dcf880
+ 49580a36b07d47808aa91db6e2b9fcdd
+ 4db51555e8fd48a0905ecee93730f863
+ 57a9d6f9fec7430bae67d7d7a9bfdd2c
+ 593eeb44d67c49499d3580d908b9f5cd
+ 5a571bea2d0c4ad5b2cc912c3dc37a59
+ 5cd1f275bdb34d939ffaa07a641a2eef
+ 60ab9787fde64199ab59b728276b5cd8
+ 63453d744e3844d48bc9a7bedfe586a7
+ 6caf784e33084b1389fdea4043560d3f
+ 725ce5eae96b4602a3b8a30f73dcbc4c
+ 7f9c3d9ccbd949449f25f3711780c1e7
+ 80ff2e88de2144bbb21d231db5a02000
+ 835174fcce4a4969851ca1846b92036a
+ 85b73c92393e453faf0f7ec82d40720e
+ 860911c447744c0396b618db994c535e
+ 86d6d90704ff4e9c8fc0f0751bd837a2
+ 934b27da5e4249978bfa9c190ec01f9a
+ 968aecc8c38246f8af3d0d7fa169ca8f
+ 9fc1cb45c8404517aa8cee3bb47c14fd
+ a65a935fd54b4159a2687bffef7cbf81
+ af2f7b1678ea4194a9b8235e7dfd23b3
+ b4cd213509ec4dcba41a280b4b013e63
+ be7a64227e1f4f13b86389edc4926dfa
+ bff3cd47d0574f73980b3af9f7790c58
+ c8ac24a9bf2647fb9e7565eaf3a28558
+ cc1f905b148c4378ad46a40da72e839f
+ ce50fe2e6a654a3bafab950c0f101e59
+ d270505df059467e8fa17974f075f3cf
+ d476d6bfc0364001a6cc73877a59ca65
+ d9a5b67b5c9142e984f76b1afec1939b
+ da9cb8ac53274b9bbd9467b7d83c85fb
+ dc48f3ab2b2844eba788898509a52806
+ e1817fcc5d614723bcb1f49491fe3ed0
+ f1fbc33234374c3a911148a453399186
+ faab16de19484746a4716cb00b738f8e
+ fdb767e69a0748c6bcdfe8764772c0d4
+ ff8ec56b0c664b438d36e84882b304f4
+ 03ea3bf9d47e4e5789d027279e6edbbb
+ 064a05ca3df84e3fbf900f9a1df75577
+ 0ada42e959504b47ba58ca331a8d8549
+ 112ae8160af54eeea6b2483b903634f4
+ 156d6ab3d495476c997887c092aff781
+ 1c92543b1e9245e0a2c1e3770a0e3d11
+ 1e041df547e64db9aaa8d79218d880a8
+ 1e34fd79cbb24db4952db6e9642881d3
+ 1ec08e1e74d04354ac7085c004b01c2c
+ 20dd7f7bdc9a4c36aef491f12afa14d8
+ 242e99d9fe2f4eec91841fd3e8b01021
+ 27dbf22159a5464687f4ed9b347257d3
+ 28647ae054d74d2e9cac4a3dda31bb55
+ 29ff70f5772747f89b0db4aae9c0ade6
+ 2b03620bba824c1ea67945abd5c043f2
+ 314d74658df6431ea50bede8512882cc
+ 38f052a2027346e2943b4c76d2572415
+ 3dbaadb244e44f59b5a6b7490aac6883
+ 400dbd97e4e6429cab24fab8b5a3d845
+ 41790f8edba642ffa281a0660f318db4
+ 4c60ff4ebef241deae699ec8d2de86b5
+ 5de63c02a4374605acb69691450e6653
+ 65df530434624400b030da4579baa4b6
+ 66c66c960e1c4b3aab5f2792f5e71add
+ 6abf66991f584f1ba45d7297f3a128d4
+ 6dd6b05e20604f478d9fd868528b275f
+ 6f76008a68074d2bb59a0189f558ae34
+ 8bb433dfbef3479cbaa3bcdf63b5b6a2
+ 9338c7dbf4054c608c17353358cdb7c6
+ 9544bb7b09874f13a5ecd0429379cbd8
+ 95d2df27650f4beb8d208a21db7366d9
+ 96d50c0f7f6a40ad9e5ae39537d1062e
+ 9e7e71c08e5b4ff9b510afbfb2067152
+ a6cce2749dfb4b4d89c0dc3460ea9d3b
+ ab7e81a8a26d43ecb3131729a999ddcd
+ adae06ba4b7a4cbeab892957bc40331b
+ ba46772fa0234625832da0582c2f615c
+ c4f57ce4bc2b4c46a32414515ba991e9
+ cf09886dc98f4666bed77d6b51a4ef67
+ cfde2bfa5c634a788c2c4c4480f53ba7
+ d0008363ca6c4ea9976494eff45e90bb
+ d403eef8a45d485e905b968cc0a1670a
+ dc8d45c7ae7f453e9f861c79a40d9265
+ eb8e71b3a22f4e719d8157831c408a6e
+ ed896088728f4779b2fd9aa7f527e880
+ f06a196aea294b0fa05dee4be971a12c
+ f3e1bd29da234c8e89e0f208487fe31c
+ f84ffc38cbb9400ca31be98fe89abb01
+ fa31faff8ec04fa49e72e6266dc14cc4
+ fb6bd558e5ff4d3b8709a39d6280460b
+ 808f9ffa-c14a-5d78-b8bf-197bc1f0b29c
+ e1740d44-9be4-58cf-a3e6-f8208b9cdfc6
+ 4acf0253-00b8-5cca-be94-1f2af5bd72ba
+ 0c94fe68-2983-52db-822e-6ea63bd54f65
+ ff9b4de9-a702-5221-bc26-f0c7ec8c4c51
+ b927ce627b6841a688067331853302d6
+ ccfad91e-e66d-5cc3-aff8-99f5b3a824fd
+ 25434b7c-4ab4-58cd-900f-aa1bfcf53233
+ 23d9764b-5035-5025-aae1-2788c1942a7c
+ ecbc08ea-5f9d-5d2f-a496-77ec128bd3fe
data_utils/pyrender_wrapper.py ADDED
@@ -0,0 +1,135 @@
+ # Modified from https://github.com/lab4d-org/lab4d
+
+ import os
+
+ # Must be set before pyrender/PyOpenGL is imported, otherwise the EGL
+ # platform selection has no effect.
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
+
+ import numpy as np
+ import cv2
+ import pyrender
+ import trimesh
+ from pyrender import (
+     IntrinsicsCamera,
+     Mesh,
+     Node,
+     Scene,
+     OffscreenRenderer,
+     MetallicRoughnessMaterial,
+     RenderFlags
+ )
+
+ def look_at(eye, center, up):
+     """Create a look-at (view) matrix."""
+     f = np.array(center, dtype=np.float32) - np.array(eye, dtype=np.float32)
+     f /= np.linalg.norm(f)
+
+     u = np.array(up, dtype=np.float32)
+     u /= np.linalg.norm(u)
+
+     s = np.cross(f, u)
+     u = np.cross(s, f)
+
+     m = np.identity(4, dtype=np.float32)
+     m[0, :3] = s
+     m[1, :3] = u
+     m[2, :3] = -f
+     m[:3, 3] = -np.matmul(m[:3, :3], np.array(eye, dtype=np.float32))
+
+     return m
+
+ class PyRenderWrapper:
+     def __init__(self, image_size=(1024, 1024)) -> None:
+         # renderer
+         self.image_size = image_size
+         render_size = max(image_size)
+         self.r = OffscreenRenderer(render_size, render_size)
+         self.intrinsics = IntrinsicsCamera(
+             render_size, render_size, render_size / 2, render_size / 2
+         )
+         # light
+         self.light_pose = np.eye(4)
+         self.set_light_topdown()
+         self.direc_l = pyrender.DirectionalLight(color=np.ones(3), intensity=5.0)
+         self.material = MetallicRoughnessMaterial(
+             roughnessFactor=0.75, metallicFactor=0.75, alphaMode="BLEND"
+         )
+         self.init_camera()
+
+     def init_camera(self):
+         self.flip_pose = np.eye(4)
+         self.set_camera(np.eye(4))
+
+     def set_camera(self, scene_to_cam):
+         # object to camera transforms
+         self.scene_to_cam = self.flip_pose @ scene_to_cam
+
+     def set_light_topdown(self, gl=False):
+         # top down light, slightly closer to the camera
+         if gl:
+             rot = cv2.Rodrigues(np.asarray([-np.pi / 2, 0, 0]))[0]
+         else:
+             rot = cv2.Rodrigues(np.asarray([np.pi / 2, 0, 0]))[0]
+         self.light_pose[:3, :3] = rot
+
+     def align_light_to_camera(self):
+         self.light_pose = np.linalg.inv(self.scene_to_cam)
+
+     def set_intrinsics(self, intrinsics):
+         """
+         Args:
+             intrinsics: (4,) fx,fy,px,py
+         """
+         self.intrinsics = IntrinsicsCamera(
+             intrinsics[0], intrinsics[1], intrinsics[2], intrinsics[3]
+         )
+
+     def get_cam_to_scene(self):
+         cam_to_scene = np.eye(4)
+         cam_to_scene[:3, :3] = self.scene_to_cam[:3, :3].T
+         cam_to_scene[:3, 3] = -self.scene_to_cam[:3, :3].T @ self.scene_to_cam[:3, 3]
+         return cam_to_scene
+
+     def set_camera_view(self, angle, bbox_center, distance=2.0):
+         # Calculate camera position based on angle and distance from bounding box center
+         camera_position = bbox_center + distance * np.array([np.sin(angle), 0, np.cos(angle)], dtype=np.float32)
+         look_at_matrix = look_at(camera_position, bbox_center, [0, 1, 0])
+         self.scene_to_cam = look_at_matrix @ self.flip_pose
+
+     def render(self, input_dict):
+         # Create separate scenes for transparent objects (mesh) and solid objects (joints and bones)
+         scene_transparent = Scene(ambient_light=np.array([1.0, 1.0, 1.0, 1.0]) * 0.1)
+         scene_solid = Scene(ambient_light=np.array([1.0, 1.0, 1.0, 1.0]) * 0.1)
+
+         mesh_pyrender = Mesh.from_trimesh(input_dict["shape"], smooth=False)
+         mesh_pyrender.primitives[0].material = self.material
+         scene_transparent.add(mesh_pyrender, pose=np.eye(4), name="shape")
+
+         if "joint_meshes" in input_dict:
+             joints_pyrender = Mesh.from_trimesh(input_dict["joint_meshes"], smooth=False)
+             joints_pyrender.primitives[0].material = self.material
+             scene_solid.add(joints_pyrender, pose=np.eye(4), name="joints")
+
+         if "bone_meshes" in input_dict:
+             bones_pyrender = Mesh.from_trimesh(input_dict["bone_meshes"], smooth=False)
+             bones_pyrender.primitives[0].material = self.material
+             scene_solid.add(bones_pyrender, pose=np.eye(4), name="bones")
+
+         # Camera for both scenes
+         scene_transparent.add(self.intrinsics, pose=self.get_cam_to_scene())
+         scene_solid.add(self.intrinsics, pose=self.get_cam_to_scene())
+
+         # Light for both scenes
+         scene_transparent.add(self.direc_l, pose=self.light_pose)
+         scene_solid.add(self.direc_l, pose=self.light_pose)
+
+         # Render transparent scene first
+         color_transparent, depth_transparent = self.r.render(scene_transparent)
+
+         # Render solid scene on top
+         color_solid, depth_solid = self.r.render(scene_solid)
+
+         # Combine the two scenes
+         color_combined = np.where(depth_solid[..., np.newaxis] == 0, color_transparent, color_solid)
+
+         return color_combined, depth_solid
+
+     def delete(self):
+         self.r.delete()
data_utils/read_npz.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import numpy as np
+ import scipy.sparse as sp
+
+ # Load the NPZ file
+ data = np.load('articulation_xlv2_test.npz', allow_pickle=True)
+ data_list = data['arr_0']
+
+ print(f"Loaded {len(data_list)} data entries")
+ print(f"Data keys: {data_list[0].keys()}")
+ # 'vertices', 'faces', 'normals', 'joints', 'bones', 'root_index', 'uuid', 'pc_w_norm', 'joint_names', 'skinning_weights_value',
+ # 'skinning_weights_row', 'skinning_weights_col', 'skinning_weights_shape'
+
+ data = data_list[0]  # check the first data entry
+
+ vertices = data['vertices']    # (n_vertex, 3)
+ faces = data['faces']          # (n_faces, 3)
+ normals = data['normals']      # (n_vertex, 3)
+ joints = data['joints']        # (n_joints, 3)
+ bones = data['bones']          # (n_bones, 2)
+ pc_w_norm = data['pc_w_norm']  # (8192, 6)
+
+ # Extract the sparse skinning weights components
+ skinning_data = data['skinning_weights_value']
+ skinning_rows = data['skinning_weights_row']
+ skinning_cols = data['skinning_weights_col']
+ skinning_shape = data['skinning_weights_shape']
+
+ skinning_sparse = sp.coo_matrix((skinning_data, (skinning_rows, skinning_cols)), shape=skinning_shape)
+ skinning_weights = skinning_sparse.toarray()  # (n_vertex, n_joints)
data_utils/read_rig_mesh_from_glb.py ADDED
@@ -0,0 +1,198 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Blender script for extracting a rig (.txt) and mesh (.obj) from GLBs.
+ This code currently supports GLB files only, but it can easily be modified to load other formats (e.g., FBX, DAE) with minimal changes.
+ """
+
+ import bpy
+ import os
+ import re
+
+ def get_hierarchy_root_joint(joint):
+     """
+     Find the top parent joint node from the given
+     'joint' Blender node (armature bone).
+     """
+     root_joint = joint
+     while root_joint.parent is not None:
+         root_joint = root_joint.parent
+     return root_joint
+
+ def get_meshes_and_armatures():
+     """
+     Get all meshes and armatures in the scene
+     """
+     default_objects = ['Cube', 'Light', 'Camera', 'Icosphere']
+     for obj_name in default_objects:
+         if obj_name in bpy.data.objects:
+             bpy.data.objects.remove(bpy.data.objects[obj_name], do_unlink=True)
+
+     meshes = [obj for obj in bpy.context.scene.objects if obj.type == 'MESH']
+     armatures = [obj for obj in bpy.context.scene.objects if obj.type == 'ARMATURE']
+     return meshes, armatures
+
+ def get_joint_dict(root):
+     """
+     Create a dictionary of joints from the root joint
+     """
+     joint_pos = {}
+     def traverse_bone(bone):
+         joint_pos[bone.name] = {
+             'pos': bone.head_local,
+             'pa': bone.parent.name if bone.parent else 'None',
+             'ch': [child.name for child in bone.children]
+         }
+         for child in bone.children:
+             traverse_bone(child)
+
+     traverse_bone(root)
+     return joint_pos
+
+ def record_info(root, joint_dict, meshes, mesh_vert_offsets, file_info):
+     """
+     - root: root joint
+     - joint_dict
+     - meshes
+     - mesh_vert_offsets: for multi-geometry
+     - file_info
+     """
+     skin_records = {}
+
+     def replace_special_characters(name):
+         return re.sub(r'\W+', '_', name)
+
+     for key, val in joint_dict.items():
+         modified_key = replace_special_characters(key)
+         file_info.write(f'joints {modified_key} {val["pos"][0]:.8f} {val["pos"][1]:.8f} {val["pos"][2]:.8f}\n')
+     file_info.write(f'root {replace_special_characters(root.name)}\n')
+
+     for mesh_index, mesh in enumerate(meshes):
+         vert_offset = mesh_vert_offsets[mesh_index]
+         if mesh.type == 'MESH':
+             for vtx in mesh.data.vertices:
+                 weights = {}
+                 for group in vtx.groups:
+                     bone_name = replace_special_characters(mesh.vertex_groups[group.group].name)
+                     weights[bone_name] = group.weight
+
+                 global_vertex_index = vert_offset + vtx.index
+
+                 skin_record = f"skin {global_vertex_index} " + " ".join(f"{bone} {weight:.4f}" for bone, weight in weights.items())
+
+                 if global_vertex_index not in skin_records:
+                     skin_records[global_vertex_index] = skin_record
+                     file_info.write(skin_record + "\n")
+
+     for key, val in joint_dict.items():
+         if val['pa'] != 'None':
+             parent_name = replace_special_characters(val['pa'])
+             child_name = replace_special_characters(key)
+             file_info.write(f'hier {parent_name} {child_name}\n')
+
+
+ def record_obj(meshes, file_obj):
+     vert_offset = 0
+     norm_offset = 0
+     mesh_vert_offsets = []
+
+     for mesh in meshes:
+         mesh_vert_offsets.append(vert_offset)
+         bpy.context.view_layer.objects.active = mesh
+         bpy.ops.object.mode_set(mode='OBJECT')
+
+         # vertex
+         for v in mesh.data.vertices:
+             file_obj.write(f"v {v.co[0]} {v.co[1]} {v.co[2]}\n")
+         file_obj.write("\n")
+
+         # normal
+         for vn in mesh.data.vertices:
+             normal = vn.normal
+             file_obj.write(f"vn {normal[0]} {normal[1]} {normal[2]}\n")
+         file_obj.write("\n")
+
+         # face
+         for poly in mesh.data.polygons:
+             verts = [v + 1 + vert_offset for v in poly.vertices]
+             file_obj.write(f"f {verts[0]}//{verts[0]} {verts[1]}//{verts[1]} {verts[2]}//{verts[2]}\n")
+
+         vert_count = len(mesh.data.vertices)
+         vert_offset += vert_count
+         norm_offset += vert_count
+
+     return mesh_vert_offsets
+
+ def process_glb(glb_path, rigs_dir, meshes_dir):
+     base_name = os.path.splitext(os.path.basename(glb_path))[0]
+
+     obj_name = os.path.join(meshes_dir, f'{base_name}.obj')
+     info_name = os.path.join(rigs_dir, f'{base_name}.txt')
+
+     # Skip processing if rig info file already exists
+     if os.path.exists(info_name):
+         print(f"{info_name} already exists. Skipping...")
+         return
+
+     if os.path.exists(obj_name):
+         print(f"{obj_name} already exists. Skipping...")
+         return
+
+     bpy.ops.wm.read_factory_settings(use_empty=True)
+     bpy.ops.import_scene.gltf(filepath=glb_path)
+
+     meshes, armatures = get_meshes_and_armatures()
+
+     if not armatures:
+         print(f"No armatures found in {glb_path}. Skipping...")
+         return
+
+     root = armatures[0].data.bones[0]
+     root_bone = get_hierarchy_root_joint(root)  # top-level bone of the hierarchy
+     joint_dict = get_joint_dict(root_bone)
+
+     # save meshes
+     with open(obj_name, 'w') as file_obj:
+         mesh_vert_offsets = record_obj(meshes, file_obj)
+
+     # save rigs
+     with open(info_name, 'w') as file_info:
+         record_info(root_bone, joint_dict, meshes, mesh_vert_offsets, file_info)
+
+     print(f"Processed {glb_path}")
+
+ if __name__ == '__main__':
+
+     src_dir = 'glbs'
+     rigs_dir = 'rigs'
+     meshes_dir = 'meshes'
+     # Ensure output directories exist
+     if not os.path.exists(rigs_dir):
+         os.makedirs(rigs_dir)
+     if not os.path.exists(meshes_dir):
+         os.makedirs(meshes_dir)
+
+     glb_paths = [os.path.join(src_dir, file) for file in os.listdir(src_dir) if file.endswith('.glb')]
+
+     print(len(glb_paths))
+
+     for glb_path in glb_paths:
+         try:
+             process_glb(glb_path, rigs_dir, meshes_dir)
+         except Exception as e:
+             with open('error.txt', 'a') as error_file:
+                 error_file.write(f"{glb_path}: {str(e)}\n")
data_utils/render_data.py ADDED
@@ -0,0 +1,61 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import numpy as np
+ import cv2
+
+ from pyrender_wrapper import PyRenderWrapper
+ from data_loader import DataLoader
+
+ def main():
+     loader = DataLoader()
+
+     raw_size = (960, 960)
+     renderer = PyRenderWrapper(raw_size)
+
+     output_dir = 'render_results'
+     os.makedirs(output_dir, exist_ok=True)
+
+     rig_path = 'examples/0a59c5ffa4a1476bac6d540b79947f31.txt'
+     mesh_path = rig_path.replace('.txt', '.obj')
+
+     filename = os.path.splitext(os.path.basename(rig_path))[0]
+
+     loader.load_rig_data(rig_path)
+     loader.load_mesh(mesh_path)
+     input_dict = loader.query_mesh_rig()
+
+     angles = [0, np.pi/2, np.pi, 3*np.pi/2]
+
+     bbox_center = loader.mesh.bounding_box.centroid
+     bbox_size = loader.mesh.bounding_box.extents
+     distance = np.max(bbox_size) * 2
+
+     subfolder_path = os.path.join(output_dir, filename)
+
+     os.makedirs(subfolder_path, exist_ok=True)
+
+     for i, angle in enumerate(angles):
+         print(f"Rendering view at {np.degrees(angle)} degrees")
+
+         renderer.set_camera_view(angle, bbox_center, distance)
+         renderer.align_light_to_camera()
+
+         color = renderer.render(input_dict)[0]
+
+         output_filename = f"{filename}_view{i+1}.png"
+         output_filepath = os.path.join(subfolder_path, output_filename)
+         cv2.imwrite(output_filepath, color)
+
+ if __name__ == "__main__":
+     main()
data_utils/save_npz.py ADDED
@@ -0,0 +1,256 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ This script shows how we process the meshes and rigs from the input folders and save them in a compressed npz file.
+ """
+ import os
+ import numpy as np
+ from concurrent.futures import ProcessPoolExecutor
+ import skimage.measure
+ import trimesh
+ import mesh2sdf.core
+ import scipy.sparse as sp
+ 
+ def read_obj_file(file_path):
+     vertices = []
+     faces = []
+     normals = []  # per-vertex normals ('vn' entries)
+ 
+     with open(file_path, 'r') as file:
+         for line in file:
+             if line.startswith('v '):
+                 parts = line.split()[1:]
+                 vertices.append([float(parts[0]), float(parts[1]), float(parts[2])])
+             elif line.startswith('vn '):
+                 parts = line.split()[1:]
+                 normals.append([float(parts[0]), float(parts[1]), float(parts[2])])
+             elif line.startswith('f '):
+                 parts = line.split()[1:]
+                 # OBJ indices are 1-based; convert to 0-based. Splitting on '/' keeps only
+                 # the vertex index and handles 'v', 'v/vt', 'v//vn' and 'v/vt/vn' alike.
+                 face = [int(part.split('/')[0]) - 1 for part in parts]
+                 faces.append(face)
+ 
+     return np.array(vertices), np.array(faces), np.array(normals)
+ 
+ def read_rig_file(file_path):
+     """
+     Read a rig from a txt file; our format is the same as RigNet:
+         joints joint_name x y z
+         root root_joint_name
+         skin vertex_idx joint_name weight joint_name weight ...
+         hier parent_joint_name child_joint_name
+     """
+     joints = []
+     bones = []
+     joint_names = []
+     root = None
+ 
+     joint_mapping = {}
+     joint_index = 0
+ 
+     skinning_data = {}  # vertex index -> [(joint_idx, weight), ...]
+ 
+     with open(file_path, 'r') as file:
+         lines = file.readlines()
+ 
+     for line in lines:
+         parts = line.split()
+         if line.startswith('joints'):
+             name = parts[1]
+             position = [float(parts[2]), float(parts[3]), float(parts[4])]
+             joints.append(position)
+             joint_names.append(name)
+             joint_mapping[name] = joint_index
+             joint_index += 1
+         elif line.startswith('hier'):
+             parent_joint = joint_mapping[parts[1]]
+             child_joint = joint_mapping[parts[2]]
+             bones.append([parent_joint, child_joint])
+         elif line.startswith('root'):
+             root = joint_mapping[parts[1]]
+         elif line.startswith('skin'):
+             vertex_idx = int(parts[1])
+ 
+             if vertex_idx not in skinning_data:
+                 skinning_data[vertex_idx] = []
+ 
+             for i in range(2, len(parts), 2):
+                 if i + 1 < len(parts):
+                     joint_name = parts[i]
+                     weight = float(parts[i + 1])
+ 
+                     if joint_name in joint_mapping:
+                         joint_idx = joint_mapping[joint_name]
+                         skinning_data[vertex_idx].append((joint_idx, weight))
+ 
+     return np.array(joints), np.array(bones), root, joint_names, skinning_data
+ 
+ def convert_to_sparse_skinning(skinning_data, num_vertices, num_joints):
+     """Convert skinning weights to a sparse (COO) matrix format."""
+     rows = []
+     cols = []
+     data = []
+ 
+     for vertex_idx, weights in skinning_data.items():
+         for joint_idx, weight in weights:
+             rows.append(vertex_idx)
+             cols.append(joint_idx)
+             data.append(weight)
+ 
+     sparse_skinning = sp.coo_matrix((data, (rows, cols)), shape=(num_vertices, num_joints))
+ 
+     # Return as a tuple of arrays, which can be serialized
+     return (sparse_skinning.data, sparse_skinning.row, sparse_skinning.col, sparse_skinning.shape)
+ 
+ def normalize_to_unit_cube(vertices, normals=None, scale_factor=1.0):
+     min_coords = vertices.min(axis=0)
+     max_coords = vertices.max(axis=0)
+     center = (max_coords + min_coords) / 2.0
+ 
+     vertices -= center
+     scale = 1.0 / np.abs(vertices).max() * scale_factor
+     vertices *= scale
+ 
+     if normals is not None:
+         # Normalize each normal vector to unit length
+         norms = np.linalg.norm(normals, axis=1, keepdims=True)
+         normals = normals / (norms + 1e-8)
+ 
+         return vertices, normals, center, scale
+     else:
+         return vertices, center, scale
+ 
+ def normalize_vertices(vertices, scale=0.9):
+     bbmin, bbmax = vertices.min(0), vertices.max(0)
+     center = (bbmin + bbmax) * 0.5
+     scale = 2.0 * scale / (bbmax - bbmin).max()
+     vertices = (vertices - center) * scale
+     return vertices, center, scale
+ 
+ def export_to_watertight(normalized_mesh, octree_depth: int = 7):
+     """
+     Convert a non-watertight mesh to a watertight one.
+ 
+     Args:
+         normalized_mesh (trimesh.Trimesh): input mesh
+         octree_depth (int): SDF grid resolution, 2 ** octree_depth per axis
+ 
+     Returns:
+         mesh (trimesh.Trimesh): watertight mesh
+     """
+     size = 2 ** octree_depth
+     level = 2 / size
+ 
+     scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices)
+ 
+     sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size)
+ 
+     vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level)
+ 
+     # watertight mesh
+     vertices = vertices / size * 2 - 1  # -1 to 1
+     vertices = vertices / to_orig_scale + to_orig_center
+     mesh = trimesh.Trimesh(vertices, faces, normals=normals)
+ 
+     return mesh
+ 
+ def process_mesh_to_pc(mesh, marching_cubes=True, sample_num=8192):
+     if marching_cubes:
+         mesh = export_to_watertight(mesh)
+     return_mesh = mesh
+     points, face_idx = mesh.sample(sample_num, return_index=True)
+     points, _, _ = normalize_to_unit_cube(points, scale_factor=0.9995)
+     normals = mesh.face_normals[face_idx]
+ 
+     pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
+     return pc_normal, return_mesh
+ 
+ def process_single_file(args):
+     mesh_file, rig_file = args
+     mesh_name = os.path.basename(mesh_file).split('.')[0]
+     rig_name = os.path.basename(rig_file).split('.')[0]
+ 
+     if mesh_name != rig_name:
+         print(f"Skipping files {mesh_file} and {rig_file} because their names do not match.")
+         return None
+ 
+     vertices, faces, normals = read_obj_file(mesh_file)
+ 
+     joints, bones, root, joint_names, skinning_data = read_rig_file(rig_file)
+ 
+     # Normalize the mesh to the unit cube centered at the origin
+     vertices, normals, center, scale = normalize_to_unit_cube(vertices, normals, scale_factor=0.5)
+ 
+     # Apply the same transformation to the joints
+     joints -= center
+     joints *= scale
+ 
+     # Create a trimesh object for processing
+     mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
+ 
+     # Sample a point cloud with normals
+     pc_normal, _ = process_mesh_to_pc(mesh)
+ 
+     # Convert the skinning data to sparse format
+     sparse_skinning = convert_to_sparse_skinning(skinning_data, len(vertices), len(joints))
+ 
+     return {
+         'vertices': vertices,
+         'faces': faces,
+         'normals': normals,
+         'joints': joints,
+         'bones': bones,
+         'root_index': root,
+         'uuid': mesh_name,
+         'pc_w_norm': pc_normal,
+         'joint_names': joint_names,
+         'skinning_weights_value': sparse_skinning[0],  # values
+         'skinning_weights_rows': sparse_skinning[1],   # row indices
+         'skinning_weights_cols': sparse_skinning[2],   # column indices
+         'skinning_weights_shape': sparse_skinning[3]   # shape of the matrix
+     }
+ 
+ def process_files(mesh_folder, rig_folder, output_file, num_workers=8):
+     file_pairs = []
+ 
+     for root, _, files in os.walk(rig_folder):
+         for file in files:
+             if file.endswith('.txt'):
+                 rig_file = os.path.join(root, file)
+                 obj_base_name = os.path.splitext(file)[0]
+                 mesh_file = os.path.join(mesh_folder, obj_base_name + '.obj')
+                 if os.path.exists(mesh_file):
+                     file_pairs.append((mesh_file, rig_file))
+                 else:
+                     print(f"Mesh file not found: {mesh_file}")
+ 
+     with ProcessPoolExecutor(max_workers=num_workers) as executor:
+         data_list = list(executor.map(process_single_file, file_pairs))
+ 
+     data_list = [data for data in data_list if data is not None]
+ 
+     # np.savez_compressed has no allow_pickle argument (a stray keyword would be saved
+     # as an extra array); object arrays are pickled automatically on save.
+     np.savez_compressed(output_file, np.asarray(data_list, dtype=object))
+ 
+ def main():
+     # Example usage
+     mesh_folder = 'meshes/'
+     rig_folder = 'rigs/'
+     output_file = 'results.npz'
+ 
+     process_files(mesh_folder, rig_folder, output_file)
+ 
+ if __name__ == "__main__":
+     main()
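
For reference, a minimal sketch of how the per-sample sparse skinning arrays stored above can be reassembled into a `scipy` matrix after loading (the `results.npz` name matches the example in `main()`):

```python
import numpy as np
import scipy.sparse as sp

# Load the object array written by process_files above.
data_list = np.load('results.npz', allow_pickle=True)['arr_0']
sample = data_list[0]

# Rebuild the COO skinning matrix from the four stored arrays.
skinning = sp.coo_matrix(
    (sample['skinning_weights_value'],
     (sample['skinning_weights_rows'], sample['skinning_weights_cols'])),
    shape=tuple(sample['skinning_weights_shape']),
)
dense_weights = skinning.toarray()  # (num_vertices, num_joints)
```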
data_utils/update_npz_rm_issue_data.py ADDED
@@ -0,0 +1,59 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import numpy as np
+ 
+ def filter_npz_by_filenames(npz_path, txt_path, output_path):
+     data_list = np.load(npz_path, allow_pickle=True)['arr_0']
+ 
+     with open(txt_path, 'r') as f:
+         exclude_filenames = set(line.strip() for line in f if line.strip())
+ 
+     # Filter the data list
+     filtered_data = []
+     excluded_count = 0
+ 
+     for item in data_list:
+         filename = item['uuid']
+ 
+         if filename in exclude_filenames:
+             excluded_count += 1
+             print(filename)
+         else:
+             filtered_data.append(item)
+ 
+     # Report the counts and save the filtered data
+     kept_count = len(filtered_data)
+     total_count = len(data_list)
+     print(f"Original items: {total_count}")
+     print(f"Kept items: {kept_count}")
+     print(f"Removed items: {excluded_count}")
+ 
+     print("Saving filtered data")
+     # As in save_npz.py: np.savez_compressed pickles object arrays automatically
+     # and takes no allow_pickle argument.
+     np.savez_compressed(output_path, np.asarray(filtered_data, dtype=object))
+ 
+ def main():
+     issue_list = "data_utils/issue_data_list.txt"  # Change this to your text file path
+     npz_path_train = "articulation_xlv2_train.npz"  # Change this to your NPZ file path
+     output_path_train = "articulation_xlv2_train_update.npz"
+     npz_path_test = "articulation_xlv2_test.npz"  # Change this to your NPZ file path
+     output_path_test = "articulation_xlv2_test_update.npz"
+ 
+     filter_npz_by_filenames(npz_path_train, issue_list, output_path_train)
+     filter_npz_by_filenames(npz_path_test, issue_list, output_path_test)
+ 
+ if __name__ == "__main__":
+     main()
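
The exclusion list is read one identifier per line and matched against each item's `uuid`, so a hypothetical `issue_data_list.txt` (names invented for illustration) would look like:

```
mesh_with_broken_rig_001
another_problematic_uuid
```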
demo.py ADDED
@@ -0,0 +1,214 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import torch
+ import trimesh
+ import argparse
+ import numpy as np
+ 
+ from tqdm import tqdm
+ 
+ from accelerate import Accelerator
+ from accelerate.utils import set_seed
+ from accelerate.utils import DistributedDataParallelKwargs
+ 
+ from skeleton_models.skeletongen import SkeletonGPT
+ from data_utils.save_npz import normalize_to_unit_cube
+ from utils.mesh_to_pc import MeshProcessor
+ from utils.save_utils import save_mesh, pred_joints_and_bones, save_skeleton_to_txt, save_args, \
+     merge_duplicate_joints_and_fix_bones, save_skeleton_obj, render_mesh_with_skeleton
+ 
+ class Dataset:
+     def __init__(self, input_list, input_pc_num=8192, apply_marching_cubes=True, octree_depth=7, output_dir=None):
+         super().__init__()
+         self.data = []
+         self.output_dir = output_dir
+ 
+         mesh_list = []
+         for input_path in input_list:
+             ext = os.path.splitext(input_path)[1].lower()
+             if ext in ['.ply', '.stl', '.obj']:
+                 cur_data = trimesh.load(input_path, force='mesh')
+                 mesh_list.append(cur_data)
+             else:
+                 print(f"Unsupported file type: {ext}")
+         if apply_marching_cubes:
+             print("First apply marching cubes, then sample the point cloud; this takes some time...")
+         pc_list = MeshProcessor.convert_meshes_to_point_clouds(mesh_list, input_pc_num, apply_marching_cubes=apply_marching_cubes, octree_depth=octree_depth)
+         for input_path, cur_data, mesh in zip(input_list, pc_list, mesh_list):
+             self.data.append({'pc_normal': cur_data, 'faces': mesh.faces, 'vertices': mesh.vertices, 'file_name': os.path.splitext(os.path.basename(input_path))[0]})
+         print(f"dataset total data samples: {len(self.data)}")
+ 
+     def __len__(self):
+         return len(self.data)
+ 
+     def __getitem__(self, idx):
+         data_dict = {}
+         data_dict['pc_normal'] = self.data[idx]['pc_normal']
+         # normalize the point cloud coordinates
+         pc_coor = data_dict['pc_normal'][:, :3]
+         normals = data_dict['pc_normal'][:, 3:]
+         pc_coor, center, scale = normalize_to_unit_cube(pc_coor, scale_factor=0.9995)
+ 
+         data_dict['file_name'] = self.data[idx]['file_name']
+         pc_coor = pc_coor.astype(np.float32)
+         normals = normals.astype(np.float32)
+ 
+         point_cloud = trimesh.PointCloud(pc_coor)
+         point_cloud.metadata['normals'] = normals
+ 
+         try:
+             point_cloud.export(os.path.join(self.output_dir, f"{data_dict['file_name']}.ply"))
+         except Exception as e:
+             print(f"failed to save point cloud: {e}")
+ 
+         assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), "normals should be unit vectors, something is wrong"
+         data_dict['pc_normal'] = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
+ 
+         vertices = self.data[idx]['vertices']
+         faces = self.data[idx]['faces']
+         bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
+         pc_center = (bounds[0] + bounds[1])[None, :] / 2
+         pc_scale = (bounds[1] - bounds[0]).max() + 1e-5
+         data_dict['transform_params'] = torch.tensor([
+             center[0], center[1], center[2],
+             scale,
+             pc_center[0][0], pc_center[0][1], pc_center[0][2],
+             pc_scale
+         ], dtype=torch.float32)
+         data_dict['vertices'] = vertices
+         data_dict['faces'] = faces
+         return data_dict
+ 
+ def get_args():
+     parser = argparse.ArgumentParser("SkeletonGPT", add_help=False)
+ 
+     parser.add_argument("--input_pc_num", default=8192, type=int)
+     parser.add_argument("--num_beams", default=1, type=int)
+     parser.add_argument('--input_dir', default=None, type=str, help="input mesh directory")
+     parser.add_argument('--input_path', default=None, type=str, help="input mesh path")
+     parser.add_argument("--output_dir", default="outputs", type=str)
+     parser.add_argument('--llm', default="facebook/opt-350m", type=str, help="the LLM backend")
+     parser.add_argument("--pad_id", default=-1, type=int, help="padding id")
+     parser.add_argument("--n_discrete_size", default=128, type=int, help="discretized 3D space")
+     parser.add_argument("--n_max_bones", default=100, type=int, help="max number of bones")
+     parser.add_argument('--dataset_path', default="combine_256_updated", type=str, help="data path")
+     parser.add_argument("--seed", default=0, type=int)
+     parser.add_argument("--precision", default="fp16", type=str)
+     parser.add_argument("--batchsize_per_gpu", default=1, type=int)
+     parser.add_argument('--pretrained_weights', default=None, type=str)
+     parser.add_argument('--save_name', default="infer_results", type=str)
+     parser.add_argument("--save_render", default=False, action="store_true", help="save renderings of the mesh with skeleton")
+     parser.add_argument("--apply_marching_cubes", default=False, action="store_true")
+     parser.add_argument("--octree_depth", default=7, type=int)
+     parser.add_argument("--hier_order", default=False, action="store_true")
+ 
+     args = parser.parse_args()
+     return args
+ 
+ if __name__ == "__main__":
+     args = get_args()
+ 
+     output_dir = f'{args.output_dir}/{args.save_name}'
+     os.makedirs(output_dir, exist_ok=True)
+     save_args(args, output_dir)
+ 
+     kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+     accelerator = Accelerator(
+         kwargs_handlers=[kwargs],
+         mixed_precision=args.precision,
+     )
+ 
+     model = SkeletonGPT(args).cuda()
+ 
+     if args.pretrained_weights is not None:
+         pkg = torch.load(args.pretrained_weights, map_location=torch.device("cpu"))
+         model.load_state_dict(pkg["model"])
+     else:
+         raise ValueError("Pretrained weights must be provided.")
+     model.eval()
+     set_seed(args.seed)
+ 
+     # create the dataset
+     if args.input_dir is not None:
+         input_list = sorted(os.listdir(args.input_dir))
+         input_list = [os.path.join(args.input_dir, x) for x in input_list if x.endswith(('.ply', '.obj', '.stl'))]
+         dataset = Dataset(input_list, args.input_pc_num, args.apply_marching_cubes, args.octree_depth, output_dir)
+     elif args.input_path is not None:
+         dataset = Dataset([args.input_path], args.input_pc_num, args.apply_marching_cubes, args.octree_depth, output_dir)
+     else:
+         raise ValueError("input_dir or input_path must be provided.")
+ 
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=1,
+         drop_last=False,
+         shuffle=False,
+     )
+ 
+     dataloader, model = accelerator.prepare(dataloader, model)
+ 
+     for curr_iter, batch_data_label in tqdm(enumerate(dataloader), total=len(dataloader)):
+         with accelerator.autocast():
+             pred_bone_coords = model.generate(batch_data_label)
+ 
+         # determine the output file names
+         file_name = os.path.basename(batch_data_label['file_name'][0])
+         pred_skel_filename = os.path.join(output_dir, f'{file_name}_skel.obj')
+         pred_rig_filename = os.path.join(output_dir, f"{file_name}_pred.txt")
+         mesh_filename = os.path.join(output_dir, f"{file_name}_mesh.obj")
+ 
+         transform_params = batch_data_label['transform_params'][0].cpu().numpy()
+         trans = transform_params[:3]
+         scale = transform_params[3]
+         pc_trans = transform_params[4:7]
+         pc_scale = transform_params[7]
+         vertices = batch_data_label['vertices'][0].cpu().numpy()
+         faces = batch_data_label['faces'][0].cpu().numpy()
+ 
+         skeleton = pred_bone_coords[0].cpu().numpy()
+         pred_joints, pred_bones = pred_joints_and_bones(skeleton.squeeze())
+ 
+         # Post-process: merge duplicate or nearby joints and deduplicate bones.
+         if args.hier_order:
+             pred_root_index = pred_bones[0][0]
+             pred_joints, pred_bones, pred_root_index = merge_duplicate_joints_and_fix_bones(pred_joints, pred_bones, root_index=pred_root_index)
+         else:
+             pred_joints, pred_bones = merge_duplicate_joints_and_fix_bones(pred_joints, pred_bones)
+             pred_root_index = None
+ 
+         # When saving the rig to txt, denormalize the skeleton to the same scale as the input mesh
+         pred_joints_denorm = pred_joints * pc_scale + pc_trans  # first align with the point cloud
+         pred_joints_denorm = pred_joints_denorm / scale + trans  # then align with the original mesh
+ 
+         save_skeleton_to_txt(pred_joints_denorm, pred_bones, pred_root_index, args.hier_order, vertices, pred_rig_filename)
+ 
+         # save the skeleton
+         if args.hier_order:
+             save_skeleton_obj(pred_joints, pred_bones, pred_skel_filename, pred_root_index, use_cone=True)
+         else:
+             save_skeleton_obj(pred_joints, pred_bones, pred_skel_filename, use_cone=False)
+ 
+         # When saving the mesh and rendering, use normalized vertices in (-0.5, 0.5)
+         vertices_norm = (vertices - trans) * scale
+         vertices_norm = (vertices_norm - pc_trans) / pc_scale
+         save_mesh(vertices_norm, faces, mesh_filename)
+ 
+         # render the mesh with the skeleton
+         if args.save_render:
+             if args.hier_order:
+                 render_mesh_with_skeleton(pred_joints, pred_bones, vertices_norm, faces, output_dir, file_name, prefix='pred', root_idx=pred_root_index)
+             else:
+                 render_mesh_with_skeleton(pred_joints, pred_bones, vertices_norm, faces, output_dir, file_name, prefix='pred')
demo.sh ADDED
@@ -0,0 +1,4 @@
+ CUDA_VISIBLE_DEVICES=0 python demo.py --input_dir ./examples \
+     --pretrained_weights skeleton_ckpt/checkpoint_trainonv2_hier.pth \
+     --save_name infer_results_demo_hier --input_pc_num 8192 \
+     --save_render --apply_marching_cubes --hier_order
download.py ADDED
@@ -0,0 +1,19 @@
+ from huggingface_hub import hf_hub_download
+ 
+ file_path = hf_hub_download(
+     repo_id="Maikou/Michelangelo",
+     filename="checkpoints/aligned_shape_latents/shapevae-256.ckpt",
+     local_dir="third_partys/Michelangelo"
+ )
+ 
+ file_path = hf_hub_download(
+     repo_id="Seed3D/MagicArticulate",
+     filename="skeleton_ckpt/checkpoint_trainonv2_hier.pth",
+     local_dir=""
+ )
+ 
+ file_path = hf_hub_download(
+     repo_id="Seed3D/MagicArticulate",
+     filename="skeleton_ckpt/checkpoint_trainonv2_spatial.pth",
+     local_dir=""
+ )
requirements.txt ADDED
@@ -0,0 +1,37 @@
+ #trimesh==4.2.3
+ #accelerate==0.28.0
+ #mesh2sdf==1.1.0
+ #transformers==4.39.3
+ #numpy==1.26.4
+ #pyrender==0.1.45
+ #tqdm
+ #opencv-python==4.9.0.80
+ #omegaconf==2.3.0
+ #einops==0.7.0
+ ##======= HF ===================
+ 
+ # MagicArticulate requirements for the Gradio demo
+ # Compatible with CUDA 11.8 and Python 3.10
+ --extra-index-url https://download.pytorch.org/whl/cu118
+ torch==2.1.1
+ torchvision==0.16.1
+ torchaudio==2.1.1
+ 
+ # Gradio for the web interface
+ gradio==4.44.0
+ 
+ # 3D mesh processing
+ trimesh==4.4.3
+ accelerate==0.28.0
+ mesh2sdf==1.1.0
+ transformers==4.39.3
+ numpy==1.26.4
+ pyrender==0.1.45
+ tqdm
+ opencv-python==4.9.0.80
+ omegaconf==2.3.0
+ einops==0.7.0
+ 
+ flash-attn==2.6.3
+ huggingface_hub
+ gradio-client>=1.0.0
skeleton_models/shape_opt.py ADDED
@@ -0,0 +1,406 @@
+ # Modified from https://github.com/buaacyw/MeshAnything
+ from transformers import AutoModelForCausalLM, AutoConfig, OPTConfig
+ from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTModel, OPTDecoder, OPTLearnedPositionalEmbedding, OPTDecoderLayer
+ from typing import List, Optional, Tuple, Union
+ from transformers.modeling_outputs import (
+     CausalLMOutputWithPast,
+ )
+ import torch
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from transformers.utils import replace_return_docstrings
+ from transformers.modeling_outputs import BaseModelOutputWithPast
+ 
+ class ShapeOPTConfig(OPTConfig):
+     model_type = "shape_opt"
+ 
+ class ShapeOPT(OPTForCausalLM):
+     config_class = ShapeOPTConfig
+     def __init__(self, config: ShapeOPTConfig):
+         super(OPTForCausalLM, self).__init__(config)
+         self.model = ShapeOPTModel(config)
+         self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="OPTConfig")
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         bone_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         r"""
+         Args:
+             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                 provide it.
+ 
+                 Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                 [`PreTrainedTokenizer.__call__`] for details.
+ 
+                 [What are input IDs?](../glossary#input-ids)
+             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ 
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+ 
+                 [What are attention masks?](../glossary#attention-mask)
+             head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                 Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+ 
+                 - 1 indicates the head is **not masked**,
+                 - 0 indicates the head is **masked**.
+ 
+             past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                 Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                 shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                 tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+ 
+                 Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                 cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ 
+                 If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                 that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                 all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                 This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                 than the model's internal embedding lookup matrix.
+             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                 for more detail.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ 
+         Returns:
+ 
+         Example:
+ 
+         ```python
+         >>> from transformers import AutoTokenizer, OPTForCausalLM
+ 
+         >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+         >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+ 
+         >>> prompt = "Hey, are you conscious? Can you talk to me?"
+         >>> inputs = tokenizer(prompt, return_tensors="pt")
+ 
+         >>> # Generate
+         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+         "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
+         ```"""
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model.decoder(
+             input_ids=input_ids,
+             bone_ids=bone_ids,
+             attention_mask=attention_mask,
+             head_mask=head_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         logits = self.lm_head(outputs[0]).contiguous()
+ 
+         loss = None
+         if labels is not None:
+             # move labels to the correct device to enable model parallelism
+             labels = labels.to(logits.device)
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+ 
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+ 
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+ 
+ class ShapeOPTModel(OPTModel):
+     config_class = ShapeOPTConfig
+     def __init__(self, config: ShapeOPTConfig):
+         super(OPTModel, self).__init__(config)
+         self.decoder = ShapeOPTDecoder(config)
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+ class ShapeOPTDecoder(OPTDecoder):
+     config_class = ShapeOPTConfig
+     def __init__(self, config: ShapeOPTConfig):
+         super(OPTDecoder, self).__init__(config)
+         self.config = config
+         self.dropout = config.dropout
+         self.layerdrop = config.layerdrop
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         assert config.word_embed_proj_dim == config.hidden_size
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
+         self.hidden_size = config.hidden_size
+         self.word_embed_proj_dim = config.word_embed_proj_dim
+         self.n_discrete_size = config.n_discrete_size
+ 
+         self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
+         self.token_embed_positions = OPTBonePositionalEmbedding(config.bone_per_token + 3, config.word_embed_proj_dim)
+ 
+         self.bone_per_token = config.bone_per_token
+         self.cond_length = config.cond_length
+         self.cond_embed = nn.Embedding(2, config.word_embed_proj_dim)
+         # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
+         # with checkpoints that have been fine-tuned before transformers v4.20.1
+         # see https://github.com/facebookresearch/metaseq/pull/164
+         if config.do_layer_norm_before and not config._remove_final_layer_norm:
+             self.final_layer_norm = nn.LayerNorm(
+                 config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
+             )
+         else:
+             self.final_layer_norm = None
+ 
+         self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ 
+         self.gradient_checkpointing = False
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         bone_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         r"""
+         Args:
+             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                 provide it.
+ 
+                 Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                 [`PreTrainedTokenizer.__call__`] for details.
+ 
+                 [What are input IDs?](../glossary#input-ids)
+             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ 
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+ 
+                 [What are attention masks?](../glossary#attention-mask)
+             head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                 Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+ 
+                 - 1 indicates the head is **not masked**,
+                 - 0 indicates the head is **masked**.
+ 
+             past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                 Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                 shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+ 
+                 Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                 cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ 
+                 If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                 that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                 all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ 
+             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                 This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                 than the model's internal embedding lookup matrix.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                 for more detail.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+         """
+         # OPT Decoder
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+ 
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         # Transformer Decoder
+         if input_ids is not None and inputs_embeds is not None:  # when training
+             pass
+         elif input_ids is not None:  # when running inference
+             assert not self.training
+             input_shape = input_ids.size()
+             input_ids = input_ids.view(-1, input_shape[-1])
+             inputs_embeds = self.embed_tokens(input_ids)
+             bone_embeds = self.token_embed_positions(attention_mask[:, self.cond_length:], bone_ids, input_ids,
+                                                      self.bone_per_token)
+             inputs_embeds += bone_embeds
+             cond_embed_query = torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), device=inputs_embeds.device,
+                                           dtype=inputs_embeds.dtype).long()
+             inputs_embeds = inputs_embeds + self.cond_embed(cond_embed_query)
+ 
+         elif inputs_embeds is not None:  # when generating the first skeleton token
+             assert not self.training
+             total_length = inputs_embeds.shape[1]
+             cond_embed_query = torch.zeros((inputs_embeds.shape[0], total_length), device=inputs_embeds.device,
+                                            dtype=inputs_embeds.dtype).long()
+             inputs_embeds = inputs_embeds + self.cond_embed(cond_embed_query)
+         else:
+             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+ 
+         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+         # embed positions
+         if self._use_flash_attention_2:
+             # 2d mask is passed through the layers
+             assert attention_mask is not None
+             causal_attention_mask = attention_mask if 0 in attention_mask else None
+         else:
+             raise ValueError("Only flash_attention_2 is supported")
+ 
+         pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
+ 
+         hidden_states = inputs_embeds + pos_embeds
+ 
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = () if use_cache else None
+ 
+         # check if head_mask has a correct number of layers specified if desired
+         for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
+             if attn_mask is not None:
+                 if attn_mask.size()[0] != (len(self.layers)):
+                     raise ValueError(
+                         f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                         f" {head_mask.size()[0]}."
+                     )
+ 
+         for idx, decoder_layer in enumerate(self.layers):
+             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+ 
+             if self.training:
+                 dropout_probability = torch.rand([])
+                 if dropout_probability < self.layerdrop:
+                     continue
+ 
+             past_key_value = past_key_values[idx] if past_key_values is not None else None
+ 
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     causal_attention_mask,
+                     head_mask[idx] if head_mask is not None else None,
+                     None,
+                     output_attentions,
+                     use_cache,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=causal_attention_mask,
+                     layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                     past_key_value=past_key_value,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                 )
+ 
+             hidden_states = layer_outputs[0]
+ 
+             if use_cache:
+                 next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+ 
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+ 
+         if self.final_layer_norm is not None:
+             hidden_states = self.final_layer_norm(hidden_states)
+ 
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+ 
+         next_cache = next_decoder_cache if use_cache else None
+         if not return_dict:
+             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+ 
+ class OPTBonePositionalEmbedding(nn.Embedding):
+     """
+     This module learns positional embeddings up to a fixed maximum size.
+     """
+ 
+     def __init__(self, num_embeddings: int, embedding_dim: int):
+         super().__init__(num_embeddings, embedding_dim)
+ 
+     def forward(self, attention_mask=None, bone_ids=None, input_ids=None, bone_per_token=None):
+         """`input_ids_shape` is expected to be [bsz x seqlen]."""
+         if bone_ids is not None:
+             return super().forward(bone_ids)
+ 
+         # Incremental decoding: only the newest token is given, so derive its
+         # intra-bone position from how many tokens have been attended so far.
+         assert input_ids.shape[1] == 1
+         idx_in_extra = torch.isin(input_ids, torch.LongTensor([0, 1, 2]).to(input_ids.device))
+         cur_ids = input_ids.clone().detach()
+ 
+         cur_index = (attention_mask.sum(dim=1, keepdim=True) - 2) % bone_per_token + 3
+         cur_ids[~idx_in_extra] = cur_index[~idx_in_extra]
+ 
+         return super().forward(cur_ids)
+ 
+ AutoConfig.register("shape_opt", ShapeOPTConfig)
+ AutoModelForCausalLM.register(ShapeOPTConfig, ShapeOPT)
+ 
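For intuition, the intra-bone position id computed in `OPTBonePositionalEmbedding.forward` cycles through the coordinate slots of the current bone during incremental decoding. A hypothetical walk-through of just that formula, assuming `bone_per_token = 6` and ids 0-2 reserved for the special-token embeddings:

```python
# Ids 3..8 index the six coordinate slots of a bone and repeat per bone;
# tokens_so_far stands in for attention_mask.sum(dim=1) as decoding proceeds.
bone_per_token = 6
for tokens_so_far in range(2, 15):
    pos_id = (tokens_so_far - 2) % bone_per_token + 3
    print(tokens_so_far, pos_id)  # prints 3, 4, 5, 6, 7, 8, 3, 4, ...
```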
skeleton_models/skeletongen.py ADDED
@@ -0,0 +1,198 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+ from torch import nn
+ from transformers import AutoModelForCausalLM
+ from third_partys.Michelangelo.encode import load_model
+ from skeleton_models.shape_opt import ShapeOPTConfig
+ 
+ def undiscretize(t, low, high, num_discrete):
+     assert (t >= 0).all() and (t <= num_discrete - 1).all()
+     assert high > low
+     t = t.float()
+     t /= num_discrete
+     t = t * (high - low) + low
+     assert (t < high).all() and (t >= low).all()
+     return t
+ 
+ class SkeletonGPT(nn.Module):
+     def __init__(self, args):
+         super().__init__()
+ 
+         self.args = args
+         self.point_encoder = load_model()
+ 
+         self.cond_length = 257
+         self.cond_dim = 768
+ 
+         self.n_discrete_size = args.n_discrete_size
+ 
+         self.bone_per_token = 6  # 2 joints per bone x 3 coordinates
+         self.max_length = int(args.n_max_bones * self.bone_per_token + 2 + self.cond_length)
+         self.pad_id = -1
+ 
+         self.coor_continuous_range = (-0.5, 0.5)
+ 
+         vocab_size = self.n_discrete_size + 3  # 3 for bos, eos, pad
+         self.config = ShapeOPTConfig.from_pretrained(
+             args.llm,
+             n_positions=self.max_length,
+             max_position_embeddings=self.max_length,
+             vocab_size=vocab_size,
+             _attn_implementation="flash_attention_2"
+         )
+ 
+         self.bos_token_id = 0
+         self.eos_token_id = 1
+         self.pad_token_id = 2
+ 
+         self.config.bos_token_id = self.bos_token_id
+         self.config.eos_token_id = self.eos_token_id
+         self.config.pad_token_id = self.pad_token_id
+         self.config._attn_implementation = "flash_attention_2"
+         self.config.n_discrete_size = self.n_discrete_size
+         self.config.bone_per_token = self.bone_per_token
+         self.config.cond_length = self.cond_length
+ 
+         self.config.word_embed_proj_dim = self.config.hidden_size  # 1024
+ 
+         self.transformer = AutoModelForCausalLM.from_config(
+             config=self.config, attn_implementation="flash_attention_2")
+ 
+         self.cond_head_proj = nn.Linear(self.cond_dim, self.config.word_embed_proj_dim)
+         self.cond_proj = nn.Linear(self.cond_dim, self.config.word_embed_proj_dim)
+ 
+         self.eval()
+ 
+     def detokenize(self, input_ids):
+         # input_ids: torch.Tensor of shape (batch_size, seq_length)
+         batch_size = input_ids.size(0)
+ 
+         continuous_coors_list = []
+         num_bones_list = []
+ 
+         for i in range(batch_size):
+             cur_ids = input_ids[i]  # Shape: (seq_length,)
+ 
+             # Remove padding tokens
+             cur_ids = cur_ids[cur_ids != self.pad_id]  # Shape: (effective_seq_length,)
+ 
+             # Check that the length is a multiple of 6 (2 joints * 3 coordinates)
+             if cur_ids.numel() % 6 != 0:
+                 return None
+                 # raise ValueError(f"Invalid length of input_ids in sample {i}. It should be a multiple of 6.")
+ 
+             num_bones = cur_ids.numel() // 6
+             num_bones_list.append(num_bones)
+ 
+             # Reshape into (num_bones, 6)
+             bone_coords = cur_ids.view(num_bones, 6)  # Shape: (num_bones, 6)
+ 
+             # Undiscretize the coordinates into per-bone joint positions
+             bones_coors = torch.zeros((num_bones, 2, 3), dtype=torch.float16, device=cur_ids.device)
+ 
+             for j in range(num_bones):
+                 bone_coord = bone_coords[j]  # Shape: (6,)
+ 
+                 # Split into two joints
+                 joint1_ids = bone_coord[:3]
+                 joint2_ids = bone_coord[3:]
+ 
+                 # Undiscretize the joint coordinates
+                 joint1_coords = undiscretize(joint1_ids, self.coor_continuous_range[0], self.coor_continuous_range[1], self.n_discrete_size)
+                 joint2_coords = undiscretize(joint2_ids, self.coor_continuous_range[0], self.coor_continuous_range[1], self.n_discrete_size)
+ 
+                 # Assign to bones_coors
+                 bones_coors[j, 0, :] = joint1_coords
+                 bones_coors[j, 1, :] = joint2_coords
+ 
+             continuous_coors_list.append(bones_coors)
+ 
+         max_num_bones = max(num_bones_list)
+ 
+         # Initialize the continuous_coors tensor with NaNs
+         continuous_coors = torch.full(
+             (batch_size, max_num_bones, 2, 3),
+             float('nan'),
+             dtype=torch.float16,
+             device=input_ids.device
+         )
+ 
+         # Place the bones_coors into continuous_coors
+         for i in range(batch_size):
+             num_bones = num_bones_list[i]
+             continuous_coors[i, :num_bones, :, :] = continuous_coors_list[i]
+ 
+         return continuous_coors  # Shape: (batch_size, max_num_bones, 2, 3)
+ 
+     # def forward(self, data_dict: dict, is_eval: bool = False) -> dict:
+     #     return self.generate(data_dict)
+ 
+     def process_point_feature(self, point_feature):
+         encode_feature = torch.zeros(self.args.batchsize_per_gpu, self.cond_length, self.config.word_embed_proj_dim,
+                                      device=self.cond_head_proj.weight.device, dtype=self.cond_head_proj.weight.dtype)
+         encode_feature[:, 0] = self.cond_head_proj(point_feature[:, 0])
+         shape_latents = self.point_encoder.to_shape_latents(point_feature[:, 1:])
+ 
+         encode_feature[:, 1:] = self.cond_proj(shape_latents)
+ 
+         return encode_feature
+ 
+     @torch.no_grad()
+     def generate(self, data_dict) -> dict:
+         point_feature = self.point_encoder.encode_latents(data_dict["pc_normal"])
+         processed_point_feature = self.process_point_feature(point_feature=point_feature)
+         generate_length = self.max_length - self.cond_length
+         net_device = next(self.parameters()).device
+         outputs = torch.ones(self.args.batchsize_per_gpu, generate_length).long().to(net_device) * self.eos_token_id
+         # batch x ntokens
+         if self.args.num_beams is not None and "pc_normal" in data_dict:
+             results = self.transformer.generate(
+                 inputs_embeds=processed_point_feature,
+                 max_new_tokens=generate_length,  # all bones plus the two special tokens
+                 num_beams=self.args.num_beams,
+                 bos_token_id=self.bos_token_id,
+                 eos_token_id=self.eos_token_id,
+                 pad_token_id=self.pad_token_id,
+             )
+         else:
+             results = self.transformer.generate(
+                 inputs_embeds=processed_point_feature,
+                 max_new_tokens=generate_length,
+                 do_sample=True,
+                 top_k=50,
+                 top_p=0.95,
+                 bos_token_id=self.bos_token_id,
+                 eos_token_id=self.eos_token_id,
+                 pad_token_id=self.pad_token_id,
+             )
+         assert results.shape[1] <= generate_length  # B x ID; bos is not included since it is predicted
+         outputs[:, :results.shape[1]] = results
+         # batch x ntokens ====> batch x ntokens x D
+         outputs = outputs[:, 1:-1]  # bos and eos removed
+ 
+         outputs[outputs == self.bos_token_id] = self.pad_id
+         outputs[outputs == self.eos_token_id] = self.pad_id
+         outputs[outputs == self.pad_token_id] = self.pad_id
+ 
+         outputs[outputs != self.pad_id] -= 3
+ 
+         gen_joints = self.detokenize(outputs)
+ 
+         return gen_joints
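
`undiscretize` maps token bins back to continuous coordinates; the training-side counterpart is not part of this file, but a minimal sketch of a matching `discretize` (hypothetical helper, assuming the same `(-0.5, 0.5)` range and 128 bins) would be:

```python
import torch

def discretize(t, low=-0.5, high=0.5, num_discrete=128):
    # Inverse of `undiscretize` above: map coords in [low, high) to bins [0, num_discrete - 1].
    t = (t - low) / (high - low)  # -> [0, 1)
    return (t * num_discrete).long().clamp(0, num_discrete - 1)
```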
utils/eval_utils.py ADDED
@@ -0,0 +1,57 @@
+ # Modified from https://github.com/zhan-xu/RigNet
+ 
+ import numpy as np
+ 
+ ##### for quantitative calculation
+ def chamfer_dist(pt1, pt2):
+     pt1 = pt1[np.newaxis, :, :]
+     pt2 = pt2[:, np.newaxis, :]
+     dist = np.sqrt(np.sum((pt1 - pt2) ** 2, axis=2))
+     min_left = np.mean(np.min(dist, axis=0))
+     min_right = np.mean(np.min(dist, axis=1))
+     return (min_left + min_right) / 2
+ 
+ def oneway_chamfer(pt_src, pt_dst):
+     pt1 = pt_src[np.newaxis, :, :]
+     pt2 = pt_dst[:, np.newaxis, :]
+     dist = np.sqrt(np.sum((pt1 - pt2) ** 2, axis=2))
+     avg_dist = np.mean(np.min(dist, axis=0))
+     return avg_dist
+ 
+ def joint2bone_chamfer_dist(joints1, bones1, joints2, bones2):
+     bone_sample_1 = sample_skel(joints1, bones1)
+     bone_sample_2 = sample_skel(joints2, bones2)
+     dist1 = oneway_chamfer(joints1, bone_sample_2)
+     dist2 = oneway_chamfer(joints2, bone_sample_1)
+     return (dist1 + dist2) / 2
+ 
+ def bone2bone_chamfer_dist(joints1, bones1, joints2, bones2):
+     bone_sample_1 = sample_skel(joints1, bones1)
+     bone_sample_2 = sample_skel(joints2, bones2)
+     return chamfer_dist(bone_sample_1, bone_sample_2)
+ 
+ def sample_bone(p_pos, ch_pos):
+     ray = ch_pos - p_pos
+ 
+     bone_length = np.linalg.norm(p_pos - ch_pos)
+     num_step = np.round(bone_length / 0.005).astype(int)
+     i_step = np.arange(0, num_step + 1)
+     unit_step = ray / (num_step + 1e-30)
+     unit_step = np.repeat(unit_step[np.newaxis, :], num_step + 1, axis=0)
+     res = p_pos + unit_step * i_step[:, np.newaxis]
+     return res
+ 
+ def sample_skel(joints, bones):
+     bone_sample = []
+     for parent_idx, child_idx in bones:
+         p_pos = joints[parent_idx]
+         ch_pos = joints[child_idx]
+         res = sample_bone(p_pos, ch_pos)
+         bone_sample.append(res)
+ 
+     if bone_sample:
+         bone_sample = np.concatenate(bone_sample, axis=0)
+     else:
+         bone_sample = np.empty((0, 3))
+ 
+     return bone_sample
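
A quick sanity check of these metrics on a toy pair of two-joint skeletons (values chosen only for illustration):

```python
import numpy as np

joints_a = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
joints_b = np.array([[0.0, 0.05, 0.0], [0.0, 0.95, 0.0]])
bones = [(0, 1)]  # a single parent-child bone in both skeletons

print(chamfer_dist(joints_a, joints_b))                          # joint-to-joint
print(joint2bone_chamfer_dist(joints_a, bones, joints_b, bones))
print(bone2bone_chamfer_dist(joints_a, bones, joints_b, bones))
```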
utils/mesh_to_pc.py ADDED
@@ -0,0 +1,84 @@
+ # Modified from https://github.com/buaacyw/MeshAnything
+ import mesh2sdf.core
+ import numpy as np
+ import skimage.measure
+ import trimesh
+ import time
+ from typing import List, Tuple
+ 
+ class MeshProcessor:
+     """A class to handle mesh normalization, watertight conversion and point cloud sampling."""
+ 
+     @staticmethod
+     def normalize_mesh_vertices(vertices: np.ndarray, scaling_factor: float = 0.95) -> Tuple[np.ndarray, np.ndarray, float]:
+         """
+         Normalize mesh vertices to be centered at the origin and scaled appropriately.
+         """
+         min_bounds = vertices.min(axis=0)
+         max_bounds = vertices.max(axis=0)
+ 
+         center = (min_bounds + max_bounds) * 0.5
+         max_dimension = (max_bounds - min_bounds).max()
+         scale = 2.0 * scaling_factor / max_dimension
+ 
+         normalized_vertices = (vertices - center) * scale
+         return normalized_vertices, center, scale
+ 
+     @staticmethod
+     def convert_to_watertight(mesh: trimesh.Trimesh, octree_depth: int = 7) -> trimesh.Trimesh:
+         """
+         Convert to a watertight mesh using mesh2sdf and marching cubes.
+         """
+         grid_size = 2 ** octree_depth
+         iso_level = 2 / grid_size
+ 
+         # Normalize vertices for SDF computation
+         normalized_vertices, original_center, original_scale = MeshProcessor.normalize_mesh_vertices(mesh.vertices)
+ 
+         # Compute the signed distance field
+         sdf = mesh2sdf.core.compute(normalized_vertices, mesh.faces, size=grid_size)
+ 
+         # Run the marching cubes algorithm
+         vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), iso_level)
+ 
+         # Transform vertices back to the original coordinate system
+         vertices = vertices / grid_size * 2 - 1  # Map to [-1, 1] range
+         vertices = vertices / original_scale + original_center
+ 
+         # Create the new watertight mesh
+         watertight_mesh = trimesh.Trimesh(vertices, faces, normals=normals)
+         return watertight_mesh
+ 
+     @staticmethod
+     def convert_meshes_to_point_clouds(
+         meshes: List[trimesh.Trimesh],
+         points_per_mesh: int = 8192,
+         apply_marching_cubes: bool = False,
+         octree_depth: int = 7
+     ) -> List[np.ndarray]:
+         """
+         Process a list of meshes into point clouds with normals.
+         """
+         point_clouds_with_normals = []
+ 
+         for mesh in meshes:
+             # Optionally convert to a watertight mesh first
+             if apply_marching_cubes:
+                 start_time = time.time()
+                 mesh = MeshProcessor.convert_to_watertight(mesh, octree_depth=octree_depth)
+                 processing_time = time.time() - start_time
+                 print(f"Marching cubes complete! Time: {processing_time:.2f}s")
+ 
+             # Sample points and get the corresponding face normals
+             points, face_indices = mesh.sample(points_per_mesh, return_index=True)
+             point_normals = mesh.face_normals[face_indices]
+ 
+             # Combine points and normals
+             points_with_normals = np.concatenate([points, point_normals], axis=-1, dtype=np.float16)
+             point_clouds_with_normals.append(points_with_normals)
+ 
+         return point_clouds_with_normals
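
A minimal usage sketch (the input path is a placeholder):

```python
import trimesh

mesh = trimesh.load("examples/model.obj", force="mesh")
point_clouds = MeshProcessor.convert_meshes_to_point_clouds(
    [mesh], points_per_mesh=8192, apply_marching_cubes=True, octree_depth=7
)
pc = point_clouds[0]  # (8192, 6) float16: xyz coordinates + unit normals
```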
utils/save_utils.py ADDED
@@ -0,0 +1,578 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import numpy as np
+ import cv2
+ import json
+ import trimesh
+
+ from collections import deque, defaultdict
+ from scipy.cluster.hierarchy import linkage, fcluster
+ from scipy.spatial.distance import cdist
+
+ from data_utils.pyrender_wrapper import PyRenderWrapper
+ from data_utils.data_loader import DataLoader
+
+ def save_mesh(vertices, faces, filename):
+     mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
+     mesh.export(filename, file_type='obj')
+
+ def pred_joints_and_bones(bone_coor):
+     """
+     Recover joints (j, 3) and bones (b, 2) from bone coordinates (b, 2, 3),
+     preserving the parent-child relationship.
+     """
+     parent_coords = bone_coor[:, 0, :]  # (b, 3)
+     child_coords = bone_coor[:, 1, :]   # (b, 3)
+
+     all_coords = np.vstack([parent_coords, child_coords])  # (2b, 3)
+     pred_joints, indices = np.unique(all_coords, axis=0, return_inverse=True)
+
+     b = bone_coor.shape[0]
+     parent_indices = indices[:b]
+     child_indices = indices[b:]
+
+     pred_bones = np.column_stack([parent_indices, child_indices])
+
+     # Drop degenerate bones whose endpoints collapsed to the same joint
+     valid_bones = pred_bones[parent_indices != child_indices]
+
+     return pred_joints, valid_bones
+
+ def find_connected_components(joints, bones):
+     """Find connected components in the skeleton graph."""
+     n_joints = len(joints)
+     graph = defaultdict(list)
+
+     # Build adjacency list
+     for parent, child in bones:
+         graph[parent].append(child)
+         graph[child].append(parent)
+
+     visited = [False] * n_joints
+     components = []
+
+     for i in range(n_joints):
+         if not visited[i]:
+             component = []
+             queue = deque([i])
+             visited[i] = True
+
+             while queue:
+                 node = queue.popleft()
+                 component.append(node)
+
+                 for neighbor in graph[node]:
+                     if not visited[neighbor]:
+                         visited[neighbor] = True
+                         queue.append(neighbor)
+
+             components.append(component)
+
+     return components
+
+ def ensure_skeleton_connectivity(joints, bones, root_index=None, merge_distance_threshold=0.01):
+     """
+     Ensure the skeleton forms a single connected component. Repeatedly pick the
+     globally closest pair of components:
+     - if their distance < merge_distance_threshold: merge the two closest joints
+     - otherwise: connect them with a new bone
+     """
+     current_joints = joints.copy()
+     current_bones = list(bones)
+     current_root = root_index
+
+     iteration = 0
+     while True:
+         components = find_connected_components(current_joints, current_bones)
+         if len(components) == 1:
+             # print("Successfully ensured skeleton connectivity")
+             break
+
+         # Find the globally closest pair of components
+         min_distance = float('inf')
+         best_pair = None
+
+         for i in range(len(components)):
+             for j in range(i + 1, len(components)):
+                 comp1_joints = current_joints[components[i]]
+                 comp2_joints = current_joints[components[j]]
+
+                 distances = cdist(comp1_joints, comp2_joints)
+                 min_idx = np.unravel_index(np.argmin(distances), distances.shape)
+                 distance = distances[min_idx]
+
+                 if distance < min_distance:
+                     min_distance = distance
+                     best_pair = (i, j, components[i][min_idx[0]], components[j][min_idx[1]], min_idx)
+
+         if best_pair is None:
+             print("Warning: Could not find valid component pair to connect")
+             break
+
+         comp1_idx, comp2_idx, joint1_idx, joint2_idx, min_idx = best_pair
+
+         if min_distance < merge_distance_threshold:
+             # Merge the joints
+             # print(f"Iteration {iteration + 1}: Merging closest joints {joint1_idx} and {joint2_idx} "
+             #       f"(distance: {min_distance:.4f})")
+
+             # Always merge joint2 into joint1
+             merge_map = {joint2_idx: joint1_idx}
+
+             # Update bones
+             updated_bones = []
+             for parent, child in current_bones:
+                 new_parent = merge_map.get(parent, parent)
+                 new_child = merge_map.get(child, child)
+                 if new_parent != new_child:  # Remove self-loops
+                     updated_bones.append([new_parent, new_child])
+
+             # Update root
+             if current_root == joint2_idx:
+                 current_root = joint1_idx
+
+             # Remove the merged joint and update indices
+             joint_to_remove = joint2_idx
+             mask = np.ones(len(current_joints), dtype=bool)
+             mask[joint_to_remove] = False
+             current_joints = current_joints[mask]
+
+             # Create index mapping for remaining joints
+             old_to_new = {}
+             new_idx = 0
+             for old_idx in range(len(mask)):
+                 if mask[old_idx]:
+                     old_to_new[old_idx] = new_idx
+                     new_idx += 1
+
+             # Update bone indices
+             current_bones = [[old_to_new[parent], old_to_new[child]]
+                              for parent, child in updated_bones
+                              if parent in old_to_new and child in old_to_new]
+
+             # Update root index
+             if current_root is not None and current_root in old_to_new:
+                 current_root = old_to_new[current_root]
+
+         else:
+             # Connect the two components with a new bone
+             # print(f"Iteration {iteration + 1}: Connecting closest components with bone {joint1_idx} -> {joint2_idx} "
+             #       f"(distance: {min_distance:.4f})")
+             current_bones.append([joint1_idx, joint2_idx])
+
+         iteration += 1
+
+         # Prevent infinite loops
+         if iteration > len(joints):
+             print(f"Warning: Maximum iterations reached ({iteration}), stopping")
+             break
+
+     current_bones = np.array(current_bones) if len(current_bones) > 0 else np.array([]).reshape(0, 2)
+
+     # Final connectivity verification
+     final_components = find_connected_components(current_joints, current_bones)
+     if len(final_components) > 1:
+         print(f"Warning: Still have {len(final_components)} disconnected components after {iteration} iterations")
+
+     return current_joints, current_bones, current_root
+
+ def merge_duplicate_joints_and_fix_bones(joints, bones, tolerance=0.0025, root_index=None):
+     """
+     Merge duplicate joints that lie within `tolerance` of each other, and remap
+     bones to maintain connectivity. Bones that become duplicates after joint
+     merging are also collapsed.
+     """
+     n_joints = len(joints)
+
+     # Find groups of joints to merge
+     merge_groups = []
+     used = [False] * n_joints
+
+     for i in range(n_joints):
+         if used[i]:
+             continue
+
+         # Find all joints within tolerance distance of joint i
+         group = [i]
+         for j in range(i + 1, n_joints):
+             if not used[j]:
+                 dist = np.linalg.norm(joints[i] - joints[j])
+                 if dist < tolerance:
+                     group.append(j)
+                     used[j] = True
+
+         used[i] = True
+         merge_groups.append(group)
+
+         # if len(group) > 1:
+         #     print(f"found duplicate joints group: {group}")
+
+     # Build merge map: choose a representative joint per group
+     merge_map = {}
+     for group in merge_groups:
+         if root_index is not None and root_index in group:
+             representative = root_index
+         else:
+             representative = group[0]  # otherwise take the first joint as representative
+         for joint_idx in group:
+             merge_map[joint_idx] = representative
+
+     # Track root joint change
+     intermediate_root_index = None
+     if root_index is not None:
+         intermediate_root_index = merge_map.get(root_index, root_index)
+         # if intermediate_root_index != root_index:
+         #     print(f"root joint index changed from {root_index} to {intermediate_root_index}")
+
+     # Update bones: remove self-loops, then merge duplicate bones
+     updated_bones = []
+
+     for parent, child in bones:
+         new_parent = merge_map.get(parent, parent)
+         new_child = merge_map.get(child, child)
+
+         if new_parent != new_child:  # remove self-loop bones
+             updated_bones.append([new_parent, new_child])
+
+     # Remove duplicate bones
+     unique_bones = []
+     seen_bones = set()
+
+     for bone in updated_bones:
+         bone_key = tuple(bone)  # keep the [parent, child] order
+         if bone_key not in seen_bones:
+             seen_bones.add(bone_key)
+             unique_bones.append(bone)
+
+     # Re-index joints to drop the ones no longer referenced
+     used_joint_indices = set()
+     for parent, child in unique_bones:
+         used_joint_indices.add(parent)
+         used_joint_indices.add(child)
+     if intermediate_root_index is not None:
+         used_joint_indices.add(intermediate_root_index)
+
+     used_joint_indices = sorted(list(used_joint_indices))
+
+     # New indices for the surviving joints
+     old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(used_joint_indices)}
+
+     final_joints = joints[used_joint_indices]
+     final_bones = np.array([[old_to_new[parent], old_to_new[child]]
+                             for parent, child in unique_bones])
+
+     final_root_index = None
+     if intermediate_root_index is not None:
+         final_root_index = old_to_new[intermediate_root_index]
+         if root_index is not None and final_root_index != root_index:
+             print(f"final root index: {root_index} -> {final_root_index}")
+
+     removed_joints = n_joints - len(final_joints)
+     removed_bones = len(bones) - len(final_bones)
+
+     # if removed_joints > 0 or removed_bones > 0:
+     #     print(f"merge results:")
+     #     print(f"  joint number: {n_joints} -> {len(final_joints)} (removed {removed_joints})")
+     #     print(f"  bone number: {len(bones)} -> {len(final_bones)} (removed {removed_bones})")
+
+     # Ensure skeleton connectivity with a relaxed threshold
+     final_joints, final_bones, final_root_index = ensure_skeleton_connectivity(
+         final_joints, final_bones, final_root_index,
+         merge_distance_threshold=tolerance * 8  # more relaxed threshold for connectivity
+     )
+
+     if root_index is not None:
+         return final_joints, final_bones, final_root_index
+     else:
+         return final_joints, final_bones
+
+
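
A toy sketch of the merge step, assuming these helpers are importable (coordinates are made up):

import numpy as np

# Joints 0 and 1 sit within the 0.0025 default tolerance and get merged;
# the two bones then collapse into one after deduplication
joints = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.001], [0.0, 1.0, 0.0]])
bones = np.array([[0, 2], [1, 2]])

merged_joints, merged_bones = merge_duplicate_joints_and_fix_bones(joints, bones)
print(len(merged_joints), len(merged_bones))  # expected: 2 joints, 1 bone
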
+ def save_skeleton_to_txt(pred_joints, pred_bones, pred_root_index, hier_order, vertices, filename='skeleton.txt'):
+     """
+     Save the skeleton to a txt file in the RigNet format (joints, root, hier lines).
+
+     If hier_order is True: pred_root_index gives the root joint, and the
+     parent-child relationship is taken directly from the bone order.
+     Otherwise: the joint nearest to the mesh centroid becomes the root, and the
+     hierarchy is rebuilt by BFS from that root.
+     """
+     num_joints = pred_joints.shape[0]
+
+     # Assign joint names
+     joint_names = [f'joint{i}' for i in range(num_joints)]
+
+     adjacency = defaultdict(list)
+     for bone in pred_bones:
+         idx_a, idx_b = bone
+         adjacency[idx_a].append(idx_b)
+         adjacency[idx_b].append(idx_a)
+
+     # Find the root joint
+     if hier_order:
+         root_idx = pred_root_index
+     else:
+         centroid = np.mean(vertices, axis=0)
+         distances = np.linalg.norm(pred_joints - centroid, axis=1)
+         root_idx = np.argmin(distances)
+
+     root_name = joint_names[root_idx]
+
+     # Build the hierarchy
+     parent_map = {}
+
+     if hier_order:
+         visited = set()
+
+         for parent_idx, child_idx in pred_bones:
+             if child_idx not in parent_map:
+                 parent_map[child_idx] = parent_idx
+                 visited.add(child_idx)
+             visited.add(parent_idx)
+
+         parent_map[root_idx] = None
+
+     else:
+         visited = set([root_idx])
+         queue = deque([root_idx])
+         parent_map[root_idx] = None
+
+         while queue:
+             current_idx = queue.popleft()
+             for neighbor_idx in adjacency[current_idx]:
+                 if neighbor_idx not in visited:
+                     parent_map[neighbor_idx] = current_idx
+                     visited.add(neighbor_idx)
+                     queue.append(neighbor_idx)
+
+     if len(visited) != num_joints:
+         print(f"Warning: bones are not fully connected, leaving {num_joints - len(visited)} joints unconnected.")
+
+     # Joint lines
+     joints_lines = []
+     for idx, coord in enumerate(pred_joints):
+         name = joint_names[idx]
+         joints_line = f'joints {name} {coord[0]:.8f} {coord[1]:.8f} {coord[2]:.8f}'
+         joints_lines.append(joints_line)
+
+     # Root line
+     root_line = f'root {root_name}'
+
+     # Hierarchy lines
+     hier_lines = []
+     for child_idx, parent_idx in parent_map.items():
+         if parent_idx is not None:
+             parent_name = joint_names[parent_idx]
+             child_name = joint_names[child_idx]
+             hier_line = f'hier {parent_name} {child_name}'
+             hier_lines.append(hier_line)
+
+     with open(filename, 'w') as file:
+         for line in joints_lines:
+             file.write(line + '\n')
+
+         file.write(root_line + '\n')
+
+         for line in hier_lines:
+             file.write(line + '\n')
+
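
For a two-joint skeleton, the resulting file looks like this (toy values, derived from the format strings above):

joints joint0 0.00000000 0.50000000 0.00000000
joints joint1 0.00000000 0.00000000 0.00000000
root joint0
hier joint0 joint1
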
+ def save_skeleton_obj(joints, bones, save_path, root_index=None, radius_sphere=0.01,
+                       radius_bone=0.005, segments=16, stacks=16, use_cone=False):
+     """
+     Save the skeleton to an obj file: a red sphere per joint and a blue
+     cylinder (or cone) per bone. If the root index is known, its sphere is
+     colored green instead.
+     """
+     all_vertices = []
+     all_colors = []
+     all_faces = []
+     vertex_offset = 0
+
+     # Create spheres for joints
+     for i, joint in enumerate(joints):
+         # Pick the color
+         if root_index is not None and i == root_index:
+             color = (0, 1, 0)  # green for the root joint
+         else:
+             color = (1, 0, 0)  # red for all other joints
+
+         # Create the joint sphere
+         sphere_vertices, sphere_faces = create_sphere(joint, radius=radius_sphere, segments=segments, stacks=stacks)
+         all_vertices.extend(sphere_vertices)
+         all_colors.extend([color] * len(sphere_vertices))
+
+         # Shift face indices by the current vertex offset
+         adjusted_sphere_faces = [(v1 + vertex_offset, v2 + vertex_offset, v3 + vertex_offset) for (v1, v2, v3) in sphere_faces]
+         all_faces.extend(adjusted_sphere_faces)
+         vertex_offset += len(sphere_vertices)
+
+     # Create bones
+     for bone in bones:
+         parent_idx, child_idx = bone
+         parent = joints[parent_idx]
+         child = joints[child_idx]
+
+         try:
+             bone_vertices, bone_faces = create_bone(parent, child, radius=radius_bone, segments=segments, use_cone=use_cone)
+         except ValueError as e:
+             print(f"Skipping connection {parent_idx}-{child_idx}, reason: {e}")
+             continue
+
+         all_vertices.extend(bone_vertices)
+         all_colors.extend([(0, 0, 1)] * len(bone_vertices))  # blue
+
+         # Shift face indices by the current vertex offset
+         adjusted_bone_faces = [(v1 + vertex_offset, v2 + vertex_offset, v3 + vertex_offset) for (v1, v2, v3) in bone_faces]
+         all_faces.extend(adjusted_bone_faces)
+         vertex_offset += len(bone_vertices)
+
+     # Write the obj file (vertices carry per-vertex colors)
+     obj_lines = []
+     for v, c in zip(all_vertices, all_colors):
+         obj_lines.append(f"v {v[0]} {v[1]} {v[2]} {c[0]} {c[1]} {c[2]}")
+     obj_lines.append("")
+
+     for face in all_faces:
+         obj_lines.append(f"f {face[0]} {face[1]} {face[2]}")
+
+     with open(save_path, 'w') as obj_file:
+         obj_file.write("\n".join(obj_lines))
+
+ def create_sphere(center, radius=0.01, segments=16, stacks=16):
+     vertices = []
+     faces = []
+     for i in range(stacks + 1):
+         lat = np.pi / 2 - i * np.pi / stacks
+         xy = radius * np.cos(lat)
+         z = radius * np.sin(lat)
+         for j in range(segments):
+             lon = j * 2 * np.pi / segments
+             x = xy * np.cos(lon) + center[0]
+             y = xy * np.sin(lon) + center[1]
+             vertices.append((x, y, z + center[2]))
+     for i in range(stacks):
+         for j in range(segments):
+             first = i * segments + j
+             second = first + segments
+             third = first + 1 if (j + 1) < segments else i * segments
+             fourth = second + 1 if (j + 1) < segments else (i + 1) * segments
+             faces.append((first + 1, second + 1, fourth + 1))
+             faces.append((first + 1, fourth + 1, third + 1))
+     return vertices, faces
+
+ def create_bone(start, end, radius=0.005, segments=16, use_cone=False):
+     dir_vector = np.array(end) - np.array(start)
+     height = np.linalg.norm(dir_vector)
+     if height == 0:
+         raise ValueError("Start and end points cannot be the same for a bone.")
+     dir_vector = dir_vector / height
+
+     # Rotation aligning the +z axis with the bone direction (Rodrigues' formula)
+     z = np.array([0, 0, 1])
+     if np.allclose(dir_vector, z):
+         R = np.identity(3)
+     elif np.allclose(dir_vector, -z):
+         R = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]])
+     else:
+         v = np.cross(z, dir_vector)
+         s = np.linalg.norm(v)
+         c = np.dot(z, dir_vector)
+         kmat = np.array([[0, -v[2], v[1]],
+                          [v[2], 0, -v[0]],
+                          [-v[1], v[0], 0]])
+         R = np.identity(3) + kmat + np.matmul(kmat, kmat) * ((1 - c) / (s**2))
+
+     theta = np.linspace(0, 2 * np.pi, segments, endpoint=False)
+     base_circle = np.array([np.cos(theta), np.sin(theta), np.zeros(segments)]) * radius
+
+     vertices = []
+     for point in base_circle.T:
+         rotated = np.dot(R, point) + np.array(start)
+         vertices.append(tuple(rotated))
+
+     faces = []
+
+     if use_cone:
+         vertices.append(tuple(end))
+
+         apex_idx = segments + 1  # 1-based obj index of the apex vertex
+         for i in range(segments):
+             next_i = (i + 1) % segments
+             faces.append((i + 1, next_i + 1, apex_idx))
+     else:
+         top_circle = np.array([np.cos(theta), np.sin(theta), np.ones(segments)]) * radius
+         for point in top_circle.T:
+             point_scaled = np.array([point[0], point[1], height])
+             rotated = np.dot(R, point_scaled) + np.array(start)
+             vertices.append(tuple(rotated))
+         for i in range(segments):
+             next_i = (i + 1) % segments
+             faces.append((i + 1, next_i + 1, next_i + segments + 1))
+             faces.append((i + 1, next_i + segments + 1, i + segments + 1))
+
+     return vertices, faces
+
+ def render_mesh_with_skeleton(joints, bones, vertices, faces, output_dir, filename, prefix='pred', root_idx=None):
+     """
+     Render the mesh with its skeleton from four viewpoints using PyRender.
+     """
+     loader = DataLoader()
+
+     raw_size = (960, 960)
+     renderer = PyRenderWrapper(raw_size)
+
+     save_dir = os.path.join(output_dir, 'render_results')
+     os.makedirs(save_dir, exist_ok=True)
+
+     loader.joints = joints
+     loader.bones = bones
+     loader.root_idx = root_idx
+
+     mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
+     mesh.visual.vertex_colors[:, 3] = 100  # make the mesh semi-transparent
+     loader.mesh = mesh
+     v = mesh.vertices
+     xmin, ymin, zmin = v.min(axis=0)
+     xmax, ymax, zmax = v.max(axis=0)
+     loader.bbox_center = np.array([(xmax + xmin) / 2, (ymax + ymin) / 2, (zmax + zmin) / 2])
+     loader.bbox_size = np.array([xmax - xmin, ymax - ymin, zmax - zmin])
+     loader.bbox_scale = max(xmax - xmin, ymax - ymin, zmax - zmin)
+     loader.normalize_coordinates()
+
+     input_dict = loader.query_mesh_rig()
+
+     angles = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
+     distance = np.max(loader.bbox_size) * 2
+
+     subfolder_path = os.path.join(save_dir, filename + '_' + prefix)
+     os.makedirs(subfolder_path, exist_ok=True)
+
+     for i, angle in enumerate(angles):
+         renderer.set_camera_view(angle, loader.bbox_center, distance)
+         renderer.align_light_to_camera()
+
+         color = renderer.render(input_dict)[0]
+
+         output_filename = f"{filename}_{prefix}_view{i+1}.png"
+         output_filepath = os.path.join(subfolder_path, output_filename)
+         cv2.imwrite(output_filepath, color)
+
+
+ def save_args(args, output_dir, filename="config.json"):
+     args_dict = vars(args)
+     os.makedirs(output_dir, exist_ok=True)
+     config_path = os.path.join(output_dir, filename)
+     with open(config_path, 'w') as f:
+         json.dump(args_dict, f, indent=4)
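
Taken together, a typical post-processing pass over a model's bone predictions might look like this sketch (illustrative; `pred_bone_coords` stands in for a (b, 2, 3) prediction tensor and the vertices are placeholder points):

import numpy as np

# Hypothetical model output: b bones, each a (parent_xyz, child_xyz) pair
pred_bone_coords = np.random.rand(30, 2, 3).astype(np.float32)

joints, bones = pred_joints_and_bones(pred_bone_coords)
joints, bones = merge_duplicate_joints_and_fix_bones(joints, bones, tolerance=0.0025)

# With hier_order=False the root is picked as the joint nearest the centroid
save_skeleton_to_txt(joints, bones, pred_root_index=None, hier_order=False,
                     vertices=np.random.rand(100, 3), filename='skeleton.txt')
save_skeleton_obj(joints, bones, 'skeleton.obj')
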
utils/skeleton_data_loader.py ADDED
@@ -0,0 +1,97 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+ from torch import is_tensor
+ from torch.utils.data import Dataset
+ from torch.nn.utils.rnn import pad_sequence
+ from data_utils.save_npz import normalize_to_unit_cube
+
+ import numpy as np
+
+ class SkeletonData(Dataset):
+     """
+     A PyTorch Dataset to load and process skeleton data.
+     """
+     def __init__(self, data, args, is_training):
+         self.data = data
+
+         self.input_pc_num = args.input_pc_num
+         self.is_training = is_training
+
+         self.hier_order = args.hier_order
+         print(f"[Dataset] Created from {len(self.data)} entries")
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         data = self.data[idx]
+
+         joints = data['joints']
+         vertices = data['vertices']
+         pc_normal = data['pc_w_norm']
+
+         # Randomly subsample the point cloud down to input_pc_num points
+         indices = np.random.choice(pc_normal.shape[0], self.input_pc_num, replace=False)
+         pc_normal = pc_normal[indices, :]
+
+         pc_coor = pc_normal[:, :3]
+         normal = pc_normal[:, 3:]
+         # Degenerate normals indicate a bad sample; retry with a random item
+         if np.linalg.norm(normal, axis=1, keepdims=True).min() < 0.99:
+             print("normal reroll")
+             return self.__getitem__(np.random.randint(0, len(self.data)))
+
+         data_dict = {}
+
+         # Normalize normals to unit length
+         normal = normal / np.linalg.norm(normal, axis=1, keepdims=True)
+
+         # Scale coordinates into [-0.5, 0.5]
+         _, center, scale = normalize_to_unit_cube(vertices.copy(), scale_factor=0.9995)
+         joints = (joints - center) * scale  # align joints with the point cloud first
+
+         bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
+         pc_center = (bounds[0] + bounds[1])[None, :] / 2
+         pc_scale = (bounds[1] - bounds[0]).max() + 1e-5
+         pc_coor = (pc_coor - pc_center) / pc_scale
+         joints = (joints - pc_center) / pc_scale
+
+         joints = joints.clip(-0.5, 0.5)
+
+         data_dict['joints'] = torch.from_numpy(np.asarray(joints).astype(np.float16))
+         data_dict['bones'] = torch.from_numpy(data['bones'].astype(np.int64))
+         pc_coor = pc_coor / np.abs(pc_coor).max() * 0.9995
+         data_dict['pc_normal'] = torch.from_numpy(np.concatenate([pc_coor, normal], axis=-1).astype(np.float16))
+         data_dict['vertices'] = torch.from_numpy(data['vertices'].astype(np.float16))
+         data_dict['faces'] = torch.from_numpy(data['faces'].astype(np.int64))
+         data_dict['uuid'] = data['uuid']
+         data_dict['root_index'] = str(data['root_index'])
+         # Record the transforms so predictions can be mapped back to mesh space
+         data_dict['transform_params'] = torch.tensor([
+             center[0], center[1], center[2],
+             scale,
+             pc_center[0][0], pc_center[0][1], pc_center[0][2],
+             pc_scale
+         ], dtype=torch.float32)
+
+         return data_dict
+
+     @classmethod
+     def load(cls, args, is_training=True):
+         loaded_data = np.load(args.dataset_path, allow_pickle=True)
+         data = []
+         for item in loaded_data["arr_0"]:
+             data.append(item)
+         print(f"[Dataset] Loaded {len(data)} entries")
+         return cls(data, args, is_training)
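
A minimal loading sketch (illustrative; the namespace fields mirror exactly what the class reads — `dataset_path`, `input_pc_num`, `hier_order` — and the .npz path is hypothetical):

from types import SimpleNamespace

args = SimpleNamespace(dataset_path='data/skeletons.npz',  # hypothetical path
                       input_pc_num=8192, hier_order=True)

dataset = SkeletonData.load(args, is_training=True)
sample = dataset[0]
print(sample['pc_normal'].shape)  # (8192, 6): xyz + unit normal, float16
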