ckc99u commited on
Commit
bf5fbab
·
verified ·
1 Parent(s): dd0269a

Update RigNet/quick_start.py

Browse files
Files changed (1) hide show
  1. RigNet/quick_start.py +117 -233
RigNet/quick_start.py CHANGED
@@ -51,6 +51,7 @@ def normalize_obj(mesh_v):
51
  mesh_v *= scale
52
  return mesh_v, pivot, scale
53
 
 
54
  def create_single_data(mesh_filename):
55
  """
56
  create input data for the network. The data is wrapped by Data structure in pytorch-geometric library
@@ -64,8 +65,7 @@ def create_single_data(mesh_filename):
64
  mesh_f = np.asarray(mesh.triangles)
65
  mesh_v, translation_normalize, scale_normalize = normalize_obj(mesh_v)
66
  mesh_normalized = o3d.geometry.TriangleMesh(vertices=o3d.utility.Vector3dVector(mesh_v), triangles=o3d.utility.Vector3iVector(mesh_f))
67
- normalized_obj = mesh_filename.replace("_remesh.obj", "_normalized.obj")
68
- o3d.io.write_triangle_mesh(normalized_obj, mesh_normalized)
69
 
70
  # vertices
71
  v = np.concatenate((mesh_v, mesh_vn), axis=1)
@@ -90,153 +90,37 @@ def create_single_data(mesh_filename):
90
  # batch
91
  batch = torch.zeros(len(v), dtype=torch.long)
92
 
93
- # voxel - Use trimesh instead of binvox (no X11 required)
94
  binvox_file = mesh_filename.replace('_remesh.obj', '_normalized.binvox')
 
95
 
96
  if not os.path.exists(binvox_file):
97
- print(f" voxelizing mesh (trimesh-based, no binvox)...")
98
-
99
- try:
100
- # Load mesh with trimesh
101
- mesh_tri = trimesh.load(normalized_obj)
102
-
103
- # Voxelize: create a 88x88x88 grid
104
- # Calculate pitch to fit mesh in 88^3 grid
105
- bounds = mesh_tri.bounds
106
- dims = bounds[1] - bounds[0]
107
- max_dim = max(dims)
108
- pitch = max_dim / 88.0
109
-
110
- # Create voxel grid
111
- voxel_grid = mesh_tri.voxelized(pitch=pitch)
112
-
113
- # Get current voxel matrix
114
- vox_matrix = voxel_grid.matrix
115
- current_shape = vox_matrix.shape
116
-
117
- print(f" Original voxel shape: {current_shape}")
118
-
119
- # Resize to exactly 88x88x88 by padding/cropping each dimension
120
- target_shape = (88, 88, 88)
121
- resized = np.zeros(target_shape, dtype=bool)
122
-
123
- # Calculate how much to copy in each dimension
124
- x_size = min(current_shape[0], target_shape[0])
125
- y_size = min(current_shape[1], target_shape[1])
126
- z_size = min(current_shape[2], target_shape[2])
127
-
128
- # Copy the overlapping region
129
- resized[:x_size, :y_size, :z_size] = vox_matrix[:x_size, :y_size, :z_size]
130
-
131
- vox_matrix = resized
132
- print(f" Resized voxel shape: {vox_matrix.shape}")
133
-
134
- # Create binvox-compatible object with ALL required attributes
135
- class Voxels:
136
- def __init__(self, data, dims, translate, scale, axis_order):
137
- self.data = data
138
- self.dims = dims
139
- self.translate = translate
140
- self.scale = scale
141
- self.axis_order = axis_order
142
-
143
- vox_obj = Voxels(
144
- data=vox_matrix,
145
- dims=[88, 88, 88],
146
- translate=[0.0, 0.0, 0.0],
147
- scale=1.0,
148
- axis_order='xyz'
149
- )
150
-
151
- # Save as binvox format for caching
152
- with open(binvox_file, 'wb') as f:
153
- binvox_rw.write(vox_obj, f)
154
-
155
- print(f" ✓ Voxelization complete: {binvox_file}")
156
-
157
- except Exception as e:
158
- print(f" ERROR: Trimesh voxelization failed: {e}")
159
- import traceback
160
- traceback.print_exc()
161
- raise Exception(f"Voxelization failed: {e}")
162
-
163
- # Load voxel data
164
- with open(binvox_file, 'rb') as fvox:
165
- vox = binvox_rw.read_as_3d_array(fvox)
166
-
167
- data = Data(x=v[:, 3:6], pos=v[:, 0:3], tpl_edge_index=tpl_e, geo_edge_index=geo_e, batch=batch)
168
- return data, vox, surface_geodesic, translation_normalize, scale_normalize
169
-
170
-
171
- # def create_single_data(mesh_filename):
172
- # """
173
- # create input data for the network. The data is wrapped by Data structure in pytorch-geometric library
174
- # :param mesh_filename: name of the input mesh
175
- # :return: wrapped data, voxelized mesh, and geodesic distance matrix of all vertices
176
- # """
177
- # mesh = o3d.io.read_triangle_mesh(mesh_filename)
178
- # mesh.compute_vertex_normals()
179
- # mesh_v = np.asarray(mesh.vertices)
180
- # mesh_vn = np.asarray(mesh.vertex_normals)
181
- # mesh_f = np.asarray(mesh.triangles)
182
- # mesh_v, translation_normalize, scale_normalize = normalize_obj(mesh_v)
183
- # mesh_normalized = o3d.geometry.TriangleMesh(vertices=o3d.utility.Vector3dVector(mesh_v), triangles=o3d.utility.Vector3iVector(mesh_f))
184
- # o3d.io.write_triangle_mesh(mesh_filename.replace("_remesh.obj", "_normalized.obj"), mesh_normalized)
185
-
186
- # # vertices
187
- # v = np.concatenate((mesh_v, mesh_vn), axis=1)
188
- # v = torch.from_numpy(v).float()
189
-
190
- # # topology edges
191
- # print(" gathering topological edges.")
192
- # tpl_e = get_tpl_edges(mesh_v, mesh_f).T
193
- # tpl_e = torch.from_numpy(tpl_e).long()
194
- # tpl_e, _ = add_self_loops(tpl_e, num_nodes=v.size(0))
195
-
196
- # # surface geodesic distance matrix
197
- # print(" calculating surface geodesic matrix.")
198
- # surface_geodesic = calc_surface_geodesic(mesh)
199
-
200
- # # geodesic edges
201
- # print(" gathering geodesic edges.")
202
- # geo_e = get_geo_edges(surface_geodesic, mesh_v).T
203
- # geo_e = torch.from_numpy(geo_e).long()
204
- # geo_e, _ = add_self_loops(geo_e, num_nodes=v.size(0))
205
-
206
- # # batch
207
- # batch = torch.zeros(len(v), dtype=torch.long)
208
-
209
- # # voxel - FIXED: Use absolute path and better error handling
210
- # binvox_file = mesh_filename.replace('_remesh.obj', '_normalized.binvox')
211
- # normalized_obj = mesh_filename.replace("_remesh.obj", "_normalized.obj")
212
-
213
- # if not os.path.exists(binvox_file):
214
- # print(f" voxelizing mesh with binvox...")
215
 
216
- # # Use absolute path to binvox (installed in Dockerfile)
217
- # if platform == "linux" or platform == "linux2":
218
- # cmd = f"binvox -d 88 -pb {normalized_obj}"
219
- # elif platform == "win32":
220
- # cmd = f"binvox.exe -d 88 {normalized_obj}"
221
- # else:
222
- # raise Exception('Sorry, we currently only support windows and linux.')
223
 
224
- # print(f" Running: {cmd}")
225
- # exit_code = os.system(cmd)
226
 
227
- # if exit_code != 0:
228
- # raise Exception(f"binvox command failed with exit code {exit_code}. Command: {cmd}")
229
 
230
- # if not os.path.exists(binvox_file):
231
- # raise Exception(f"binvox did not create output file: {binvox_file}")
232
 
233
- # print(f" ✓ Voxelization complete: {binvox_file}")
234
 
235
- # with open(binvox_file, 'rb') as fvox:
236
- # vox = binvox_rw.read_as_3d_array(fvox)
237
 
238
- # data = Data(x=v[:, 3:6], pos=v[:, 0:3], tpl_edge_index=tpl_e, geo_edge_index=geo_e, batch=batch)
239
- # return data, vox, surface_geodesic, translation_normalize, scale_normalize
240
 
241
 
242
  def predict_joints(input_data, vox, joint_pred_net, threshold, bandwidth=None, mesh_filename=None):
@@ -512,97 +396,97 @@ if __name__ == '__main__':
512
  # Change to False to be more accurate but less efficient.
513
  downsample_skinning = True
514
 
515
- # load all weights
516
- print("loading all networks...")
517
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
518
-
519
- jointNet = JOINTNET()
520
- jointNet.to(device)
521
- jointNet.eval()
522
- jointNet_checkpoint = torch.load('checkpoints/gcn_meanshift/model_best.pth.tar')
523
- jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
524
- print(" joint prediction network loaded.")
525
-
526
- rootNet = ROOTNET()
527
- rootNet.to(device)
528
- rootNet.eval()
529
- rootNet_checkpoint = torch.load('checkpoints/rootnet/model_best.pth.tar')
530
- rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
531
- print(" root prediction network loaded.")
532
-
533
- boneNet = BONENET()
534
- boneNet.to(device)
535
- boneNet.eval()
536
- boneNet_checkpoint = torch.load('checkpoints/bonenet/model_best.pth.tar')
537
- boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
538
- print(" connection prediction network loaded.")
539
-
540
- skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
541
- skinNet_checkpoint = torch.load('checkpoints/skinnet/model_best.pth.tar')
542
- skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
543
- skinNet.to(device)
544
- skinNet.eval()
545
- print(" skinning prediction network loaded.")
546
-
547
- # Here we provide 16~17 examples. For best results, we will need to override the learned bandwidth and its associated threshold
548
- # To process other input characters, please first try the learned bandwidth (0.0429 in the provided model), and the default threshold 1e-5.
549
- # We also use these two default parameters for processing all test models in batch.
550
-
551
- #model_id, bandwidth, threshold = "smith", None, 1e-5
552
- model_id, bandwidth, threshold = "17872", 0.045, 0.75e-5
553
- #model_id, bandwidth, threshold = "8210", 0.05, 1e-5
554
- #model_id, bandwidth, threshold = "8330", 0.05, 0.8e-5
555
- #model_id, bandwidth, threshold = "9477", 0.043, 2.5e-5
556
- #model_id, bandwidth, threshold = "17364", 0.058, 0.3e-5
557
- #model_id, bandwidth, threshold = "15930", 0.055, 0.4e-5
558
- #model_id, bandwidth, threshold = "8333", 0.04, 2e-5
559
- #model_id, bandwidth, threshold = "8338", 0.052, 0.9e-5
560
- #model_id, bandwidth, threshold = "3318", 0.03, 0.92e-5
561
- #model_id, bandwidth, threshold = "15446", 0.032, 0.58e-5
562
- #model_id, bandwidth, threshold = "1347", 0.062, 3e-5
563
- #model_id, bandwidth, threshold = "11814", 0.06, 0.6e-5
564
- #model_id, bandwidth, threshold = "2982", 0.045, 0.3e-5
565
- #model_id, bandwidth, threshold = "2586", 0.05, 0.6e-5
566
- #model_id, bandwidth, threshold = "8184", 0.05, 0.4e-5
567
- #model_id, bandwidth, threshold = "9000", 0.035, 0.16e-5
568
-
569
- # create data used for inferece
570
- print("creating data for model ID {:s}".format(model_id))
571
- mesh_filename = os.path.join(input_folder, '{:s}_remesh.obj'.format(model_id))
572
- if not os.path.exists(mesh_filename):
573
- mesh_ori_filename = os.path.join(input_folder, '{:s}_ori.obj'.format(model_id))
574
- mesh_ori = o3d.io.read_triangle_mesh(mesh_ori_filename)
575
- if len(np.asarray(mesh_ori.vertices)) == 0:
576
- print(f"Please name your input model as {model_id}_ori.obj")
577
- exit()
578
- mesh_remesh = mesh_ori.simplify_quadric_decimation(4000) # adjust vertices between 1K - 5K
579
- o3d.io.write_triangle_mesh(mesh_filename, mesh_remesh)
580
-
581
- data, vox, surface_geodesic, translation_normalize, scale_normalize = create_single_data(mesh_filename)
582
- data.to(device)
583
-
584
- print("predicting joints")
585
- data = predict_joints(data, vox, jointNet, threshold, bandwidth=bandwidth,
586
- mesh_filename=mesh_filename.replace("_remesh.obj", "_normalized.obj"))
587
- data.to(device)
588
- print("predicting connectivity")
589
- pred_skeleton = predict_skeleton(data, vox, rootNet, boneNet,
590
- mesh_filename=mesh_filename.replace("_remesh.obj", "_normalized.obj"))
591
- print("predicting skinning")
592
- pred_rig = predict_skinning(data, pred_skeleton, skinNet, surface_geodesic,
593
- mesh_filename.replace("_remesh.obj", "_normalized.obj"),
594
- subsampling=downsample_skinning)
595
-
596
- # here we reverse the normalization to the original scale and position
597
- pred_rig.normalize(scale_normalize, -translation_normalize)
598
-
599
- print("Saving result")
600
- if True:
601
- # here we use original mesh tesselation (without remeshing)
602
- mesh_filename_ori = os.path.join(input_folder, '{:s}_ori.obj'.format(model_id))
603
- pred_rig = tranfer_to_ori_mesh(mesh_filename_ori, mesh_filename, pred_rig)
604
- pred_rig.save(mesh_filename_ori.replace('.obj', '_rig.txt'))
605
- else:
606
- # here we use remeshed mesh
607
- pred_rig.save(mesh_filename.replace('.obj', '_rig.txt'))
608
- print("Done!")
 
51
  mesh_v *= scale
52
  return mesh_v, pivot, scale
53
 
54
+
55
  def create_single_data(mesh_filename):
56
  """
57
  create input data for the network. The data is wrapped by Data structure in pytorch-geometric library
 
65
  mesh_f = np.asarray(mesh.triangles)
66
  mesh_v, translation_normalize, scale_normalize = normalize_obj(mesh_v)
67
  mesh_normalized = o3d.geometry.TriangleMesh(vertices=o3d.utility.Vector3dVector(mesh_v), triangles=o3d.utility.Vector3iVector(mesh_f))
68
+ o3d.io.write_triangle_mesh(mesh_filename.replace("_remesh.obj", "_normalized.obj"), mesh_normalized)
 
69
 
70
  # vertices
71
  v = np.concatenate((mesh_v, mesh_vn), axis=1)
 
90
  # batch
91
  batch = torch.zeros(len(v), dtype=torch.long)
92
 
93
+ # voxel - FIXED: Use absolute path and better error handling
94
  binvox_file = mesh_filename.replace('_remesh.obj', '_normalized.binvox')
95
+ normalized_obj = mesh_filename.replace("_remesh.obj", "_normalized.obj")
96
 
97
  if not os.path.exists(binvox_file):
98
+ print(f" voxelizing mesh with binvox...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
+ # Use absolute path to binvox (installed in Dockerfile)
101
+ if platform == "linux" or platform == "linux2":
102
+ cmd = f"binvox -d 88 -pb {normalized_obj}"
103
+ elif platform == "win32":
104
+ cmd = f"binvox.exe -d 88 {normalized_obj}"
105
+ else:
106
+ raise Exception('Sorry, we currently only support windows and linux.')
107
 
108
+ print(f" Running: {cmd}")
109
+ exit_code = os.system(cmd)
110
 
111
+ if exit_code != 0:
112
+ raise Exception(f"binvox command failed with exit code {exit_code}. Command: {cmd}")
113
 
114
+ if not os.path.exists(binvox_file):
115
+ raise Exception(f"binvox did not create output file: {binvox_file}")
116
 
117
+ print(f" ✓ Voxelization complete: {binvox_file}")
118
 
119
+ with open(binvox_file, 'rb') as fvox:
120
+ vox = binvox_rw.read_as_3d_array(fvox)
121
 
122
+ data = Data(x=v[:, 3:6], pos=v[:, 0:3], tpl_edge_index=tpl_e, geo_edge_index=geo_e, batch=batch)
123
+ return data, vox, surface_geodesic, translation_normalize, scale_normalize
124
 
125
 
126
  def predict_joints(input_data, vox, joint_pred_net, threshold, bandwidth=None, mesh_filename=None):
 
396
  # Change to False to be more accurate but less efficient.
397
  downsample_skinning = True
398
 
399
+ # # load all weights
400
+ # print("loading all networks...")
401
+ # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
402
+
403
+ # jointNet = JOINTNET()
404
+ # jointNet.to(device)
405
+ # jointNet.eval()
406
+ # jointNet_checkpoint = torch.load('checkpoints/gcn_meanshift/model_best.pth.tar')
407
+ # jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
408
+ # print(" joint prediction network loaded.")
409
+
410
+ # rootNet = ROOTNET()
411
+ # rootNet.to(device)
412
+ # rootNet.eval()
413
+ # rootNet_checkpoint = torch.load('checkpoints/rootnet/model_best.pth.tar')
414
+ # rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
415
+ # print(" root prediction network loaded.")
416
+
417
+ # boneNet = BONENET()
418
+ # boneNet.to(device)
419
+ # boneNet.eval()
420
+ # boneNet_checkpoint = torch.load('checkpoints/bonenet/model_best.pth.tar')
421
+ # boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
422
+ # print(" connection prediction network loaded.")
423
+
424
+ # skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
425
+ # skinNet_checkpoint = torch.load('checkpoints/skinnet/model_best.pth.tar')
426
+ # skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
427
+ # skinNet.to(device)
428
+ # skinNet.eval()
429
+ # print(" skinning prediction network loaded.")
430
+
431
+ # # Here we provide 16~17 examples. For best results, we will need to override the learned bandwidth and its associated threshold
432
+ # # To process other input characters, please first try the learned bandwidth (0.0429 in the provided model), and the default threshold 1e-5.
433
+ # # We also use these two default parameters for processing all test models in batch.
434
+
435
+ # #model_id, bandwidth, threshold = "smith", None, 1e-5
436
+ # model_id, bandwidth, threshold = "17872", 0.045, 0.75e-5
437
+ # #model_id, bandwidth, threshold = "8210", 0.05, 1e-5
438
+ # #model_id, bandwidth, threshold = "8330", 0.05, 0.8e-5
439
+ # #model_id, bandwidth, threshold = "9477", 0.043, 2.5e-5
440
+ # #model_id, bandwidth, threshold = "17364", 0.058, 0.3e-5
441
+ # #model_id, bandwidth, threshold = "15930", 0.055, 0.4e-5
442
+ # #model_id, bandwidth, threshold = "8333", 0.04, 2e-5
443
+ # #model_id, bandwidth, threshold = "8338", 0.052, 0.9e-5
444
+ # #model_id, bandwidth, threshold = "3318", 0.03, 0.92e-5
445
+ # #model_id, bandwidth, threshold = "15446", 0.032, 0.58e-5
446
+ # #model_id, bandwidth, threshold = "1347", 0.062, 3e-5
447
+ # #model_id, bandwidth, threshold = "11814", 0.06, 0.6e-5
448
+ # #model_id, bandwidth, threshold = "2982", 0.045, 0.3e-5
449
+ # #model_id, bandwidth, threshold = "2586", 0.05, 0.6e-5
450
+ # #model_id, bandwidth, threshold = "8184", 0.05, 0.4e-5
451
+ # #model_id, bandwidth, threshold = "9000", 0.035, 0.16e-5
452
+
453
+ # # create data used for inferece
454
+ # print("creating data for model ID {:s}".format(model_id))
455
+ # mesh_filename = os.path.join(input_folder, '{:s}_remesh.obj'.format(model_id))
456
+ # if not os.path.exists(mesh_filename):
457
+ # mesh_ori_filename = os.path.join(input_folder, '{:s}_ori.obj'.format(model_id))
458
+ # mesh_ori = o3d.io.read_triangle_mesh(mesh_ori_filename)
459
+ # if len(np.asarray(mesh_ori.vertices)) == 0:
460
+ # print(f"Please name your input model as {model_id}_ori.obj")
461
+ # exit()
462
+ # mesh_remesh = mesh_ori.simplify_quadric_decimation(4000) # adjust vertices between 1K - 5K
463
+ # o3d.io.write_triangle_mesh(mesh_filename, mesh_remesh)
464
+
465
+ # data, vox, surface_geodesic, translation_normalize, scale_normalize = create_single_data(mesh_filename)
466
+ # data.to(device)
467
+
468
+ # print("predicting joints")
469
+ # data = predict_joints(data, vox, jointNet, threshold, bandwidth=bandwidth,
470
+ # mesh_filename=mesh_filename.replace("_remesh.obj", "_normalized.obj"))
471
+ # data.to(device)
472
+ # print("predicting connectivity")
473
+ # pred_skeleton = predict_skeleton(data, vox, rootNet, boneNet,
474
+ # mesh_filename=mesh_filename.replace("_remesh.obj", "_normalized.obj"))
475
+ # print("predicting skinning")
476
+ # pred_rig = predict_skinning(data, pred_skeleton, skinNet, surface_geodesic,
477
+ # mesh_filename.replace("_remesh.obj", "_normalized.obj"),
478
+ # subsampling=downsample_skinning)
479
+
480
+ # # here we reverse the normalization to the original scale and position
481
+ # pred_rig.normalize(scale_normalize, -translation_normalize)
482
+
483
+ # print("Saving result")
484
+ # if True:
485
+ # # here we use original mesh tesselation (without remeshing)
486
+ # mesh_filename_ori = os.path.join(input_folder, '{:s}_ori.obj'.format(model_id))
487
+ # pred_rig = tranfer_to_ori_mesh(mesh_filename_ori, mesh_filename, pred_rig)
488
+ # pred_rig.save(mesh_filename_ori.replace('.obj', '_rig.txt'))
489
+ # else:
490
+ # # here we use remeshed mesh
491
+ # pred_rig.save(mesh_filename.replace('.obj', '_rig.txt'))
492
+ # print("Done!")