ckc99u commited on
Commit
bfb9e36
·
verified ·
1 Parent(s): 0fb16ad

Update RigNet/quick_start.py

Browse files
Files changed (1) hide show
  1. RigNet/quick_start.py +476 -476
RigNet/quick_start.py CHANGED
@@ -1,476 +1,476 @@
1
- # ---------------------------------------------------------------------------------------------------------
2
- # Name: quick_start.py
3
- # Purpose: An easy-to-use demo. Also serves as an interface of the pipeline.
4
- # RigNet Copyright 2020 University of Massachusetts
5
- # RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License.
6
- # Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet.
7
- # ---------------------------------------------------------------------------------------------------------
8
-
9
- import os
10
- from sys import platform
11
- import trimesh
12
- import numpy as np
13
- import open3d as o3d
14
- import itertools as it
15
-
16
- import torch
17
- from torch_geometric.data import Data
18
- from torch_geometric.utils import add_self_loops
19
-
20
- from utils import binvox_rw
21
- from utils.rig_parser import Skel, Info
22
- from utils.tree_utils import TreeNode
23
- from utils.io_utils import assemble_skel_skin
24
- from utils.vis_utils import draw_shifted_pts, show_obj_skel, show_mesh_vox
25
- from utils.cluster_utils import meanshift_cluster, nms_meanshift
26
- from utils.mst_utils import increase_cost_for_outside_bone, primMST_symmetry, loadSkel_recur, inside_check, flip
27
-
28
- from geometric_proc.common_ops import get_bones, calc_surface_geodesic
29
- from geometric_proc.compute_volumetric_geodesic import pts2line, calc_pts2bone_visible_mat
30
-
31
- from gen_dataset import get_tpl_edges, get_geo_edges
32
- from mst_generate import sample_on_bone, getInitId
33
- from run_skinning import post_filter
34
-
35
- from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET
36
- from models.ROOT_GCN import ROOTNET
37
- from models.PairCls_GCN import PairCls as BONENET
38
- from models.SKINNING import SKINNET
39
-
40
-
41
- def normalize_obj(mesh_v):
42
- dims = [max(mesh_v[:, 0]) - min(mesh_v[:, 0]),
43
- max(mesh_v[:, 1]) - min(mesh_v[:, 1]),
44
- max(mesh_v[:, 2]) - min(mesh_v[:, 2])]
45
- scale = 1.0 / max(dims)
46
- pivot = np.array([(min(mesh_v[:, 0]) + max(mesh_v[:, 0])) / 2, min(mesh_v[:, 1]),
47
- (min(mesh_v[:, 2]) + max(mesh_v[:, 2])) / 2])
48
- mesh_v[:, 0] -= pivot[0]
49
- mesh_v[:, 1] -= pivot[1]
50
- mesh_v[:, 2] -= pivot[2]
51
- mesh_v *= scale
52
- return mesh_v, pivot, scale
53
-
54
-
55
- def create_single_data(mesh_filaname):
56
- """
57
- create input data for the network. The data is wrapped by Data structure in pytorch-geometric library
58
- :param mesh_filaname: name of the input mesh
59
- :return: wrapped data, voxelized mesh, and geodesic distance matrix of all vertices
60
- """
61
- mesh = o3d.io.read_triangle_mesh(mesh_filaname)
62
- mesh.compute_vertex_normals()
63
- mesh_v = np.asarray(mesh.vertices)
64
- mesh_vn = np.asarray(mesh.vertex_normals)
65
- mesh_f = np.asarray(mesh.triangles)
66
-
67
- mesh_v, translation_normalize, scale_normalize = normalize_obj(mesh_v)
68
- mesh_normalized = o3d.geometry.TriangleMesh(vertices=o3d.utility.Vector3dVector(mesh_v), triangles=o3d.utility.Vector3iVector(mesh_f))
69
- o3d.io.write_triangle_mesh(mesh_filename.replace("_remesh.obj", "_normalized.obj"), mesh_normalized)
70
-
71
- # vertices
72
- v = np.concatenate((mesh_v, mesh_vn), axis=1)
73
- v = torch.from_numpy(v).float()
74
-
75
- # topology edges
76
- print(" gathering topological edges.")
77
- tpl_e = get_tpl_edges(mesh_v, mesh_f).T
78
- tpl_e = torch.from_numpy(tpl_e).long()
79
- tpl_e, _ = add_self_loops(tpl_e, num_nodes=v.size(0))
80
-
81
- # surface geodesic distance matrix
82
- print(" calculating surface geodesic matrix.")
83
- surface_geodesic = calc_surface_geodesic(mesh)
84
-
85
- # geodesic edges
86
- print(" gathering geodesic edges.")
87
- geo_e = get_geo_edges(surface_geodesic, mesh_v).T
88
- geo_e = torch.from_numpy(geo_e).long()
89
- geo_e, _ = add_self_loops(geo_e, num_nodes=v.size(0))
90
-
91
- # batch
92
- batch = torch.zeros(len(v), dtype=torch.long)
93
-
94
- # voxel
95
- if not os.path.exists(mesh_filaname.replace('_remesh.obj', '_normalized.binvox')):
96
- if platform == "linux" or platform == "linux2":
97
- os.system("./binvox -d 88 -pb " + mesh_filaname.replace("_remesh.obj", "_normalized.obj"))
98
- elif platform == "win32":
99
- os.system("binvox.exe -d 88 " + mesh_filaname.replace("_remesh.obj", "_normalized.obj"))
100
- else:
101
- raise Exception('Sorry, we currently only support windows and linux.')
102
-
103
- with open(mesh_filaname.replace('_remesh.obj', '_normalized.binvox'), 'rb') as fvox:
104
- vox = binvox_rw.read_as_3d_array(fvox)
105
-
106
- data = Data(x=v[:, 3:6], pos=v[:, 0:3], tpl_edge_index=tpl_e, geo_edge_index=geo_e, batch=batch)
107
- return data, vox, surface_geodesic, translation_normalize, scale_normalize
108
-
109
-
110
- def predict_joints(input_data, vox, joint_pred_net, threshold, bandwidth=None, mesh_filename=None):
111
- """
112
- Predict joints
113
- :param input_data: wrapped input data
114
- :param vox: voxelized mesh
115
- :param joint_pred_net: network for predicting joints
116
- :param threshold: density threshold to filter out shifted points
117
- :param bandwidth: bandwidth for meanshift clustering
118
- :param mesh_filename: mesh filename for visualization
119
- :return: wrapped data with predicted joints, pair-wise bone representation added.
120
- """
121
- data_displacement, _, attn_pred, bandwidth_pred = joint_pred_net(input_data)
122
- y_pred = data_displacement + input_data.pos
123
- y_pred_np = y_pred.data.cpu().numpy()
124
- attn_pred_np = attn_pred.data.cpu().numpy()
125
- y_pred_np, index_inside = inside_check(y_pred_np, vox)
126
- attn_pred_np = attn_pred_np[index_inside, :]
127
- y_pred_np = y_pred_np[attn_pred_np.squeeze() > 1e-3]
128
- attn_pred_np = attn_pred_np[attn_pred_np.squeeze() > 1e-3]
129
-
130
- # symmetrize points by reflecting
131
- y_pred_np_reflect = y_pred_np * np.array([[-1, 1, 1]])
132
- y_pred_np = np.concatenate((y_pred_np, y_pred_np_reflect), axis=0)
133
- attn_pred_np = np.tile(attn_pred_np, (2, 1))
134
-
135
- #img = draw_shifted_pts(mesh_filename, y_pred_np, weights=attn_pred_np)
136
- if bandwidth is None:
137
- bandwidth = bandwidth_pred.item()
138
- y_pred_np = meanshift_cluster(y_pred_np, bandwidth, attn_pred_np, max_iter=40)
139
- #img = draw_shifted_pts(mesh_filename, y_pred_np, weights=attn_pred_np)
140
-
141
- Y_dist = np.sum(((y_pred_np[np.newaxis, ...] - y_pred_np[:, np.newaxis, :]) ** 2), axis=2)
142
- density = np.maximum(bandwidth ** 2 - Y_dist, np.zeros(Y_dist.shape))
143
- density = np.sum(density, axis=0)
144
- density_sum = np.sum(density)
145
- y_pred_np = y_pred_np[density / density_sum > threshold]
146
- attn_pred_np = attn_pred_np[density / density_sum > threshold][:, 0]
147
- density = density[density / density_sum > threshold]
148
-
149
- #img = draw_shifted_pts(mesh_filename, y_pred_np, weights=attn_pred_np)
150
- pred_joints = nms_meanshift(y_pred_np, density, bandwidth)
151
- pred_joints, _ = flip(pred_joints)
152
- #img = draw_shifted_pts(mesh_filename, pred_joints)
153
-
154
- # prepare and add new data members
155
- pairs = list(it.combinations(range(pred_joints.shape[0]), 2))
156
- pair_attr = []
157
- for pr in pairs:
158
- dist = np.linalg.norm(pred_joints[pr[0]] - pred_joints[pr[1]])
159
- bone_samples = sample_on_bone(pred_joints[pr[0]], pred_joints[pr[1]])
160
- bone_samples_inside, _ = inside_check(bone_samples, vox)
161
- outside_proportion = len(bone_samples_inside) / (len(bone_samples) + 1e-10)
162
- attr = np.array([dist, outside_proportion, 1])
163
- pair_attr.append(attr)
164
- pairs = np.array(pairs)
165
- pair_attr = np.array(pair_attr)
166
- pairs = torch.from_numpy(pairs).float()
167
- pair_attr = torch.from_numpy(pair_attr).float()
168
- pred_joints = torch.from_numpy(pred_joints).float()
169
- joints_batch = torch.zeros(len(pred_joints), dtype=torch.long)
170
- pairs_batch = torch.zeros(len(pairs), dtype=torch.long)
171
-
172
- input_data.joints = pred_joints
173
- input_data.pairs = pairs
174
- input_data.pair_attr = pair_attr
175
- input_data.joints_batch = joints_batch
176
- input_data.pairs_batch = pairs_batch
177
- return input_data
178
-
179
-
180
- def predict_skeleton(input_data, vox, root_pred_net, bone_pred_net, mesh_filename):
181
- """
182
- Predict skeleton structure based on joints
183
- :param input_data: wrapped data
184
- :param vox: voxelized mesh
185
- :param root_pred_net: network to predict root
186
- :param bone_pred_net: network to predict pairwise connectivity cost
187
- :param mesh_filename: meshfilename for debugging
188
- :return: predicted skeleton structure
189
- """
190
- root_id = getInitId(input_data, root_pred_net)
191
- pred_joints = input_data.joints.data.cpu().numpy()
192
-
193
- with torch.no_grad():
194
- connect_prob, _ = bone_pred_net(input_data, permute_joints=False)
195
- connect_prob = torch.sigmoid(connect_prob)
196
- pair_idx = input_data.pairs.long().data.cpu().numpy()
197
- prob_matrix = np.zeros((len(input_data.joints), len(input_data.joints)))
198
- prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob.data.cpu().numpy().squeeze()
199
- prob_matrix = prob_matrix + prob_matrix.transpose()
200
- cost_matrix = -np.log(prob_matrix + 1e-10)
201
- cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)
202
-
203
- pred_skel = Info()
204
- parent, key, root_id = primMST_symmetry(cost_matrix, root_id, pred_joints)
205
- for i in range(len(parent)):
206
- if parent[i] == -1:
207
- pred_skel.root = TreeNode('root', tuple(pred_joints[i]))
208
- break
209
- loadSkel_recur(pred_skel.root, i, None, pred_joints, parent)
210
- pred_skel.joint_pos = pred_skel.get_joint_dict()
211
- #show_mesh_vox(mesh_filename, vox, pred_skel.root)
212
- try:
213
- img = show_obj_skel(mesh_filename, pred_skel.root)
214
- except:
215
- print("Visualization is not supported on headless servers. Please consider other headless rendering methods.")
216
- return pred_skel
217
-
218
-
219
- def calc_geodesic_matrix(bones, mesh_v, surface_geodesic, mesh_filename, subsampling=False):
220
- """
221
- calculate volumetric geodesic distance from vertices to each bones
222
- :param bones: B*6 numpy array where each row stores the starting and ending joint position of a bone
223
- :param mesh_v: V*3 mesh vertices
224
- :param surface_geodesic: geodesic distance matrix of all vertices
225
- :param mesh_filename: mesh filename
226
- :return: an approaximate volumetric geodesic distance matrix V*B, were (v,b) is the distance from vertex v to bone b
227
- """
228
-
229
- if subsampling:
230
- mesh0 = o3d.io.read_triangle_mesh(mesh_filename)
231
- mesh0 = mesh0.simplify_quadric_decimation(3000)
232
- o3d.io.write_triangle_mesh(mesh_filename.replace(".obj", "_simplified.obj"), mesh0)
233
- mesh_trimesh = trimesh.load(mesh_filename.replace(".obj", "_simplified.obj"))
234
- subsamples_ids = np.random.choice(len(mesh_v), np.min((len(mesh_v), 1500)), replace=False)
235
- subsamples = mesh_v[subsamples_ids, :]
236
- surface_geodesic = surface_geodesic[subsamples_ids, :][:, subsamples_ids]
237
- else:
238
- mesh_trimesh = trimesh.load(mesh_filename)
239
- subsamples = mesh_v
240
- origins, ends, pts_bone_dist = pts2line(subsamples, bones)
241
- pts_bone_visibility = calc_pts2bone_visible_mat(mesh_trimesh, origins, ends)
242
- pts_bone_visibility = pts_bone_visibility.reshape(len(bones), len(subsamples)).transpose()
243
- pts_bone_dist = pts_bone_dist.reshape(len(bones), len(subsamples)).transpose()
244
- # remove visible points which are too far
245
- for b in range(pts_bone_visibility.shape[1]):
246
- visible_pts = np.argwhere(pts_bone_visibility[:, b] == 1).squeeze(1)
247
- if len(visible_pts) == 0:
248
- continue
249
- threshold_b = np.percentile(pts_bone_dist[visible_pts, b], 15)
250
- pts_bone_visibility[pts_bone_dist[:, b] > 1.3 * threshold_b, b] = False
251
-
252
- visible_matrix = np.zeros(pts_bone_visibility.shape)
253
- visible_matrix[np.where(pts_bone_visibility == 1)] = pts_bone_dist[np.where(pts_bone_visibility == 1)]
254
- for c in range(visible_matrix.shape[1]):
255
- unvisible_pts = np.argwhere(pts_bone_visibility[:, c] == 0).squeeze(1)
256
- visible_pts = np.argwhere(pts_bone_visibility[:, c] == 1).squeeze(1)
257
- if len(visible_pts) == 0:
258
- visible_matrix[:, c] = pts_bone_dist[:, c]
259
- continue
260
- for r in unvisible_pts:
261
- dist1 = np.min(surface_geodesic[r, visible_pts])
262
- nn_visible = visible_pts[np.argmin(surface_geodesic[r, visible_pts])]
263
- if np.isinf(dist1):
264
- visible_matrix[r, c] = 8.0 + pts_bone_dist[r, c]
265
- else:
266
- visible_matrix[r, c] = dist1 + visible_matrix[nn_visible, c]
267
- if subsampling:
268
- nn_dist = np.sum((mesh_v[:, np.newaxis, :] - subsamples[np.newaxis, ...])**2, axis=2)
269
- nn_ind = np.argmin(nn_dist, axis=1)
270
- visible_matrix = visible_matrix[nn_ind, :]
271
- os.remove(mesh_filename.replace(".obj", "_simplified.obj"))
272
- return visible_matrix
273
-
274
-
275
- def predict_skinning(input_data, pred_skel, skin_pred_net, surface_geodesic, mesh_filename, subsampling=False):
276
- """
277
- predict skinning
278
- :param input_data: wrapped input data
279
- :param pred_skel: predicted skeleton
280
- :param skin_pred_net: network to predict skinning weights
281
- :param surface_geodesic: geodesic distance matrix of all vertices
282
- :param mesh_filename: mesh filename
283
- :return: predicted rig with skinning weights information
284
- """
285
- global device, output_folder
286
- num_nearest_bone = 5
287
- bones, bone_names, bone_isleaf = get_bones(pred_skel)
288
- mesh_v = input_data.pos.data.cpu().numpy()
289
- print(" calculating volumetric geodesic distance from vertices to bone. This step takes some time...")
290
- geo_dist = calc_geodesic_matrix(bones, mesh_v, surface_geodesic, mesh_filename, subsampling=subsampling)
291
- input_samples = [] # joint_pos (x, y, z), (bone_id, 1/D)*5
292
- loss_mask = []
293
- skin_nn = []
294
- for v_id in range(len(mesh_v)):
295
- geo_dist_v = geo_dist[v_id]
296
- bone_id_near_to_far = np.argsort(geo_dist_v)
297
- this_sample = []
298
- this_nn = []
299
- this_mask = []
300
- for i in range(num_nearest_bone):
301
- if i >= len(bones):
302
- this_sample += bones[bone_id_near_to_far[0]].tolist()
303
- this_sample.append(1.0 / (geo_dist_v[bone_id_near_to_far[0]] + 1e-10))
304
- this_sample.append(bone_isleaf[bone_id_near_to_far[0]])
305
- this_nn.append(0)
306
- this_mask.append(0)
307
- else:
308
- skel_bone_id = bone_id_near_to_far[i]
309
- this_sample += bones[skel_bone_id].tolist()
310
- this_sample.append(1.0 / (geo_dist_v[skel_bone_id] + 1e-10))
311
- this_sample.append(bone_isleaf[skel_bone_id])
312
- this_nn.append(skel_bone_id)
313
- this_mask.append(1)
314
- input_samples.append(np.array(this_sample)[np.newaxis, :])
315
- skin_nn.append(np.array(this_nn)[np.newaxis, :])
316
- loss_mask.append(np.array(this_mask)[np.newaxis, :])
317
-
318
- skin_input = np.concatenate(input_samples, axis=0)
319
- loss_mask = np.concatenate(loss_mask, axis=0)
320
- skin_nn = np.concatenate(skin_nn, axis=0)
321
- skin_input = torch.from_numpy(skin_input).float()
322
- input_data.skin_input = skin_input
323
- input_data.to(device)
324
-
325
- skin_pred = skin_pred_net(input_data)
326
- skin_pred = torch.softmax(skin_pred, dim=1)
327
- skin_pred = skin_pred.data.cpu().numpy()
328
- skin_pred = skin_pred * loss_mask
329
-
330
- skin_nn = skin_nn[:, 0:num_nearest_bone]
331
- skin_pred_full = np.zeros((len(skin_pred), len(bone_names)))
332
- for v in range(len(skin_pred)):
333
- for nn_id in range(len(skin_nn[v, :])):
334
- skin_pred_full[v, skin_nn[v, nn_id]] = skin_pred[v, nn_id]
335
- print(" filtering skinning prediction")
336
- tpl_e = input_data.tpl_edge_index.data.cpu().numpy()
337
- skin_pred_full = post_filter(skin_pred_full, tpl_e, num_ring=1)
338
- skin_pred_full[skin_pred_full < np.max(skin_pred_full, axis=1, keepdims=True) * 0.35] = 0.0
339
- skin_pred_full = skin_pred_full / (skin_pred_full.sum(axis=1, keepdims=True) + 1e-10)
340
- skel_res = assemble_skel_skin(pred_skel, skin_pred_full)
341
- return skel_res
342
-
343
-
344
- def tranfer_to_ori_mesh(filename_ori, filename_remesh, pred_rig):
345
- """
346
- convert the predicted rig of remeshed model to the rig of the original model.
347
- Just assign skinning weight based on nearest neighbor
348
- :param filename_ori: original mesh filename
349
- :param filename_remesh: remeshed mesh filename
350
- :param pred_rig: predicted rig
351
- :return: predicted rig for original mesh
352
- """
353
- mesh_remesh = o3d.io.read_triangle_mesh(filename_remesh)
354
- mesh_ori = o3d.io.read_triangle_mesh(filename_ori)
355
- tranfer_rig = Info()
356
-
357
- vert_remesh = np.asarray(mesh_remesh.vertices)
358
- vert_ori = np.asarray(mesh_ori.vertices)
359
-
360
- vertice_distance = np.sqrt(np.sum((vert_ori[np.newaxis, ...] - vert_remesh[:, np.newaxis, :]) ** 2, axis=2))
361
- vertice_raw_id = np.argmin(vertice_distance, axis=0) # nearest vertex id on the fixed mesh for each vertex on the remeshed mesh
362
-
363
- tranfer_rig.root = pred_rig.root
364
- tranfer_rig.joint_pos = pred_rig.joint_pos
365
- new_skin = []
366
- for v in range(len(vert_ori)):
367
- skin_v = [v]
368
- v_nn = vertice_raw_id[v]
369
- skin_v += pred_rig.joint_skin[v_nn][1:]
370
- new_skin.append(skin_v)
371
- tranfer_rig.joint_skin = new_skin
372
- return tranfer_rig
373
-
374
-
375
- if __name__ == '__main__':
376
- input_folder = "quick_start/"
377
-
378
- # downsample_skinning is used to speed up the calculation of volumetric geodesic distance
379
- # and to save cpu memory in skinning calculation.
380
- # Change to False to be more accurate but less efficient.
381
- downsample_skinning = True
382
-
383
- # load all weights
384
- print("loading all networks...")
385
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
386
-
387
- jointNet = JOINTNET()
388
- jointNet.to(device)
389
- jointNet.eval()
390
- jointNet_checkpoint = torch.load('checkpoints/gcn_meanshift/model_best.pth.tar')
391
- jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
392
- print(" joint prediction network loaded.")
393
-
394
- rootNet = ROOTNET()
395
- rootNet.to(device)
396
- rootNet.eval()
397
- rootNet_checkpoint = torch.load('checkpoints/rootnet/model_best.pth.tar')
398
- rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
399
- print(" root prediction network loaded.")
400
-
401
- boneNet = BONENET()
402
- boneNet.to(device)
403
- boneNet.eval()
404
- boneNet_checkpoint = torch.load('checkpoints/bonenet/model_best.pth.tar')
405
- boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
406
- print(" connection prediction network loaded.")
407
-
408
- skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
409
- skinNet_checkpoint = torch.load('checkpoints/skinnet/model_best.pth.tar')
410
- skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
411
- skinNet.to(device)
412
- skinNet.eval()
413
- print(" skinning prediction network loaded.")
414
-
415
- # Here we provide 16~17 examples. For best results, we will need to override the learned bandwidth and its associated threshold
416
- # To process other input characters, please first try the learned bandwidth (0.0429 in the provided model), and the default threshold 1e-5.
417
- # We also use these two default parameters for processing all test models in batch.
418
-
419
- #model_id, bandwidth, threshold = "smith", None, 1e-5
420
- model_id, bandwidth, threshold = "17872", 0.045, 0.75e-5
421
- #model_id, bandwidth, threshold = "8210", 0.05, 1e-5
422
- #model_id, bandwidth, threshold = "8330", 0.05, 0.8e-5
423
- #model_id, bandwidth, threshold = "9477", 0.043, 2.5e-5
424
- #model_id, bandwidth, threshold = "17364", 0.058, 0.3e-5
425
- #model_id, bandwidth, threshold = "15930", 0.055, 0.4e-5
426
- #model_id, bandwidth, threshold = "8333", 0.04, 2e-5
427
- #model_id, bandwidth, threshold = "8338", 0.052, 0.9e-5
428
- #model_id, bandwidth, threshold = "3318", 0.03, 0.92e-5
429
- #model_id, bandwidth, threshold = "15446", 0.032, 0.58e-5
430
- #model_id, bandwidth, threshold = "1347", 0.062, 3e-5
431
- #model_id, bandwidth, threshold = "11814", 0.06, 0.6e-5
432
- #model_id, bandwidth, threshold = "2982", 0.045, 0.3e-5
433
- #model_id, bandwidth, threshold = "2586", 0.05, 0.6e-5
434
- #model_id, bandwidth, threshold = "8184", 0.05, 0.4e-5
435
- #model_id, bandwidth, threshold = "9000", 0.035, 0.16e-5
436
-
437
- # create data used for inferece
438
- print("creating data for model ID {:s}".format(model_id))
439
- mesh_filename = os.path.join(input_folder, '{:s}_remesh.obj'.format(model_id))
440
- if not os.path.exists(mesh_filename):
441
- mesh_ori_filename = os.path.join(input_folder, '{:s}_ori.obj'.format(model_id))
442
- mesh_ori = o3d.io.read_triangle_mesh(mesh_ori_filename)
443
- if len(np.asarray(mesh_ori.vertices)) == 0:
444
- print(f"Please name your input model as {model_id}_ori.obj")
445
- exit()
446
- mesh_remesh = mesh_ori.simplify_quadric_decimation(4000) # adjust vertices between 1K - 5K
447
- o3d.io.write_triangle_mesh(mesh_filename, mesh_remesh)
448
-
449
- data, vox, surface_geodesic, translation_normalize, scale_normalize = create_single_data(mesh_filename)
450
- data.to(device)
451
-
452
- print("predicting joints")
453
- data = predict_joints(data, vox, jointNet, threshold, bandwidth=bandwidth,
454
- mesh_filename=mesh_filename.replace("_remesh.obj", "_normalized.obj"))
455
- data.to(device)
456
- print("predicting connectivity")
457
- pred_skeleton = predict_skeleton(data, vox, rootNet, boneNet,
458
- mesh_filename=mesh_filename.replace("_remesh.obj", "_normalized.obj"))
459
- print("predicting skinning")
460
- pred_rig = predict_skinning(data, pred_skeleton, skinNet, surface_geodesic,
461
- mesh_filename.replace("_remesh.obj", "_normalized.obj"),
462
- subsampling=downsample_skinning)
463
-
464
- # here we reverse the normalization to the original scale and position
465
- pred_rig.normalize(scale_normalize, -translation_normalize)
466
-
467
- print("Saving result")
468
- if True:
469
- # here we use original mesh tesselation (without remeshing)
470
- mesh_filename_ori = os.path.join(input_folder, '{:s}_ori.obj'.format(model_id))
471
- pred_rig = tranfer_to_ori_mesh(mesh_filename_ori, mesh_filename, pred_rig)
472
- pred_rig.save(mesh_filename_ori.replace('.obj', '_rig.txt'))
473
- else:
474
- # here we use remeshed mesh
475
- pred_rig.save(mesh_filename.replace('.obj', '_rig.txt'))
476
- print("Done!")
 
1
+ # ---------------------------------------------------------------------------------------------------------
2
+ # Name: quick_start.py
3
+ # Purpose: An easy-to-use demo. Also serves as an interface of the pipeline.
4
+ # RigNet Copyright 2020 University of Massachusetts
5
+ # RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License.
6
+ # Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet.
7
+ # ---------------------------------------------------------------------------------------------------------
8
+
9
+ import os
10
+ from sys import platform
11
+ import trimesh
12
+ import numpy as np
13
+ import open3d as o3d
14
+ import itertools as it
15
+
16
+ import torch
17
+ from torch_geometric.data import Data
18
+ from torch_geometric.utils import add_self_loops
19
+
20
+ from utils import binvox_rw
21
+ from utils.rig_parser import Skel, Info
22
+ from utils.tree_utils import TreeNode
23
+ from utils.io_utils import assemble_skel_skin
24
+ from utils.vis_utils import draw_shifted_pts, show_obj_skel, show_mesh_vox
25
+ from utils.cluster_utils import meanshift_cluster, nms_meanshift
26
+ from utils.mst_utils import increase_cost_for_outside_bone, primMST_symmetry, loadSkel_recur, inside_check, flip
27
+
28
+ from geometric_proc.common_ops import get_bones, calc_surface_geodesic
29
+ from geometric_proc.compute_volumetric_geodesic import pts2line, calc_pts2bone_visible_mat
30
+
31
+ from gen_dataset import get_tpl_edges, get_geo_edges
32
+ from mst_generate import sample_on_bone, getInitId
33
+ from run_skinning import post_filter
34
+
35
+ from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET
36
+ from models.ROOT_GCN import ROOTNET
37
+ from models.PairCls_GCN import PairCls as BONENET
38
+ from models.SKINNING import SKINNET
39
+
40
+
41
+ def normalize_obj(mesh_v):
42
+ dims = [max(mesh_v[:, 0]) - min(mesh_v[:, 0]),
43
+ max(mesh_v[:, 1]) - min(mesh_v[:, 1]),
44
+ max(mesh_v[:, 2]) - min(mesh_v[:, 2])]
45
+ scale = 1.0 / max(dims)
46
+ pivot = np.array([(min(mesh_v[:, 0]) + max(mesh_v[:, 0])) / 2, min(mesh_v[:, 1]),
47
+ (min(mesh_v[:, 2]) + max(mesh_v[:, 2])) / 2])
48
+ mesh_v[:, 0] -= pivot[0]
49
+ mesh_v[:, 1] -= pivot[1]
50
+ mesh_v[:, 2] -= pivot[2]
51
+ mesh_v *= scale
52
+ return mesh_v, pivot, scale
53
+
54
+
55
+ def create_single_data(mesh_filename):
56
+ """
57
+ create input data for the network. The data is wrapped by Data structure in pytorch-geometric library
58
+ :param mesh_filename: name of the input mesh
59
+ :return: wrapped data, voxelized mesh, and geodesic distance matrix of all vertices
60
+ """
61
+ mesh = o3d.io.read_triangle_mesh(mesh_filename)
62
+ mesh.compute_vertex_normals()
63
+ mesh_v = np.asarray(mesh.vertices)
64
+ mesh_vn = np.asarray(mesh.vertex_normals)
65
+ mesh_f = np.asarray(mesh.triangles)
66
+
67
+ mesh_v, translation_normalize, scale_normalize = normalize_obj(mesh_v)
68
+ mesh_normalized = o3d.geometry.TriangleMesh(vertices=o3d.utility.Vector3dVector(mesh_v), triangles=o3d.utility.Vector3iVector(mesh_f))
69
+ o3d.io.write_triangle_mesh(mesh_filename.replace("_remesh.obj", "_normalized.obj"), mesh_normalized)
70
+
71
+ # vertices
72
+ v = np.concatenate((mesh_v, mesh_vn), axis=1)
73
+ v = torch.from_numpy(v).float()
74
+
75
+ # topology edges
76
+ print(" gathering topological edges.")
77
+ tpl_e = get_tpl_edges(mesh_v, mesh_f).T
78
+ tpl_e = torch.from_numpy(tpl_e).long()
79
+ tpl_e, _ = add_self_loops(tpl_e, num_nodes=v.size(0))
80
+
81
+ # surface geodesic distance matrix
82
+ print(" calculating surface geodesic matrix.")
83
+ surface_geodesic = calc_surface_geodesic(mesh)
84
+
85
+ # geodesic edges
86
+ print(" gathering geodesic edges.")
87
+ geo_e = get_geo_edges(surface_geodesic, mesh_v).T
88
+ geo_e = torch.from_numpy(geo_e).long()
89
+ geo_e, _ = add_self_loops(geo_e, num_nodes=v.size(0))
90
+
91
+ # batch
92
+ batch = torch.zeros(len(v), dtype=torch.long)
93
+
94
+ # voxel
95
+ if not os.path.exists(mesh_filename.replace('_remesh.obj', '_normalized.binvox')):
96
+ if platform == "linux" or platform == "linux2":
97
+ os.system("./binvox -d 88 -pb " + mesh_filename.replace("_remesh.obj", "_normalized.obj"))
98
+ elif platform == "win32":
99
+ os.system("binvox.exe -d 88 " + mesh_filename.replace("_remesh.obj", "_normalized.obj"))
100
+ else:
101
+ raise Exception('Sorry, we currently only support windows and linux.')
102
+
103
+ with open(mesh_filename.replace('_remesh.obj', '_normalized.binvox'), 'rb') as fvox:
104
+ vox = binvox_rw.read_as_3d_array(fvox)
105
+
106
+ data = Data(x=v[:, 3:6], pos=v[:, 0:3], tpl_edge_index=tpl_e, geo_edge_index=geo_e, batch=batch)
107
+ return data, vox, surface_geodesic, translation_normalize, scale_normalize
108
+
109
+
110
+ def predict_joints(input_data, vox, joint_pred_net, threshold, bandwidth=None, mesh_filename=None):
111
+ """
112
+ Predict joints
113
+ :param input_data: wrapped input data
114
+ :param vox: voxelized mesh
115
+ :param joint_pred_net: network for predicting joints
116
+ :param threshold: density threshold to filter out shifted points
117
+ :param bandwidth: bandwidth for meanshift clustering
118
+ :param mesh_filename: mesh filename for visualization
119
+ :return: wrapped data with predicted joints, pair-wise bone representation added.
120
+ """
121
+ data_displacement, _, attn_pred, bandwidth_pred = joint_pred_net(input_data)
122
+ y_pred = data_displacement + input_data.pos
123
+ y_pred_np = y_pred.data.cpu().numpy()
124
+ attn_pred_np = attn_pred.data.cpu().numpy()
125
+ y_pred_np, index_inside = inside_check(y_pred_np, vox)
126
+ attn_pred_np = attn_pred_np[index_inside, :]
127
+ y_pred_np = y_pred_np[attn_pred_np.squeeze() > 1e-3]
128
+ attn_pred_np = attn_pred_np[attn_pred_np.squeeze() > 1e-3]
129
+
130
+ # symmetrize points by reflecting
131
+ y_pred_np_reflect = y_pred_np * np.array([[-1, 1, 1]])
132
+ y_pred_np = np.concatenate((y_pred_np, y_pred_np_reflect), axis=0)
133
+ attn_pred_np = np.tile(attn_pred_np, (2, 1))
134
+
135
+ #img = draw_shifted_pts(mesh_filename, y_pred_np, weights=attn_pred_np)
136
+ if bandwidth is None:
137
+ bandwidth = bandwidth_pred.item()
138
+ y_pred_np = meanshift_cluster(y_pred_np, bandwidth, attn_pred_np, max_iter=40)
139
+ #img = draw_shifted_pts(mesh_filename, y_pred_np, weights=attn_pred_np)
140
+
141
+ Y_dist = np.sum(((y_pred_np[np.newaxis, ...] - y_pred_np[:, np.newaxis, :]) ** 2), axis=2)
142
+ density = np.maximum(bandwidth ** 2 - Y_dist, np.zeros(Y_dist.shape))
143
+ density = np.sum(density, axis=0)
144
+ density_sum = np.sum(density)
145
+ y_pred_np = y_pred_np[density / density_sum > threshold]
146
+ attn_pred_np = attn_pred_np[density / density_sum > threshold][:, 0]
147
+ density = density[density / density_sum > threshold]
148
+
149
+ #img = draw_shifted_pts(mesh_filename, y_pred_np, weights=attn_pred_np)
150
+ pred_joints = nms_meanshift(y_pred_np, density, bandwidth)
151
+ pred_joints, _ = flip(pred_joints)
152
+ #img = draw_shifted_pts(mesh_filename, pred_joints)
153
+
154
+ # prepare and add new data members
155
+ pairs = list(it.combinations(range(pred_joints.shape[0]), 2))
156
+ pair_attr = []
157
+ for pr in pairs:
158
+ dist = np.linalg.norm(pred_joints[pr[0]] - pred_joints[pr[1]])
159
+ bone_samples = sample_on_bone(pred_joints[pr[0]], pred_joints[pr[1]])
160
+ bone_samples_inside, _ = inside_check(bone_samples, vox)
161
+ outside_proportion = len(bone_samples_inside) / (len(bone_samples) + 1e-10)
162
+ attr = np.array([dist, outside_proportion, 1])
163
+ pair_attr.append(attr)
164
+ pairs = np.array(pairs)
165
+ pair_attr = np.array(pair_attr)
166
+ pairs = torch.from_numpy(pairs).float()
167
+ pair_attr = torch.from_numpy(pair_attr).float()
168
+ pred_joints = torch.from_numpy(pred_joints).float()
169
+ joints_batch = torch.zeros(len(pred_joints), dtype=torch.long)
170
+ pairs_batch = torch.zeros(len(pairs), dtype=torch.long)
171
+
172
+ input_data.joints = pred_joints
173
+ input_data.pairs = pairs
174
+ input_data.pair_attr = pair_attr
175
+ input_data.joints_batch = joints_batch
176
+ input_data.pairs_batch = pairs_batch
177
+ return input_data
178
+
179
+
180
+ def predict_skeleton(input_data, vox, root_pred_net, bone_pred_net, mesh_filename):
181
+ """
182
+ Predict skeleton structure based on joints
183
+ :param input_data: wrapped data
184
+ :param vox: voxelized mesh
185
+ :param root_pred_net: network to predict root
186
+ :param bone_pred_net: network to predict pairwise connectivity cost
187
+ :param mesh_filename: meshfilename for debugging
188
+ :return: predicted skeleton structure
189
+ """
190
+ root_id = getInitId(input_data, root_pred_net)
191
+ pred_joints = input_data.joints.data.cpu().numpy()
192
+
193
+ with torch.no_grad():
194
+ connect_prob, _ = bone_pred_net(input_data, permute_joints=False)
195
+ connect_prob = torch.sigmoid(connect_prob)
196
+ pair_idx = input_data.pairs.long().data.cpu().numpy()
197
+ prob_matrix = np.zeros((len(input_data.joints), len(input_data.joints)))
198
+ prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob.data.cpu().numpy().squeeze()
199
+ prob_matrix = prob_matrix + prob_matrix.transpose()
200
+ cost_matrix = -np.log(prob_matrix + 1e-10)
201
+ cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)
202
+
203
+ pred_skel = Info()
204
+ parent, key, root_id = primMST_symmetry(cost_matrix, root_id, pred_joints)
205
+ for i in range(len(parent)):
206
+ if parent[i] == -1:
207
+ pred_skel.root = TreeNode('root', tuple(pred_joints[i]))
208
+ break
209
+ loadSkel_recur(pred_skel.root, i, None, pred_joints, parent)
210
+ pred_skel.joint_pos = pred_skel.get_joint_dict()
211
+ #show_mesh_vox(mesh_filename, vox, pred_skel.root)
212
+ try:
213
+ img = show_obj_skel(mesh_filename, pred_skel.root)
214
+ except:
215
+ print("Visualization is not supported on headless servers. Please consider other headless rendering methods.")
216
+ return pred_skel
217
+
218
+
219
+ def calc_geodesic_matrix(bones, mesh_v, surface_geodesic, mesh_filename, subsampling=False):
220
+ """
221
+ calculate volumetric geodesic distance from vertices to each bones
222
+ :param bones: B*6 numpy array where each row stores the starting and ending joint position of a bone
223
+ :param mesh_v: V*3 mesh vertices
224
+ :param surface_geodesic: geodesic distance matrix of all vertices
225
+ :param mesh_filename: mesh filename
226
+ :return: an approaximate volumetric geodesic distance matrix V*B, were (v,b) is the distance from vertex v to bone b
227
+ """
228
+
229
+ if subsampling:
230
+ mesh0 = o3d.io.read_triangle_mesh(mesh_filename)
231
+ mesh0 = mesh0.simplify_quadric_decimation(3000)
232
+ o3d.io.write_triangle_mesh(mesh_filename.replace(".obj", "_simplified.obj"), mesh0)
233
+ mesh_trimesh = trimesh.load(mesh_filename.replace(".obj", "_simplified.obj"))
234
+ subsamples_ids = np.random.choice(len(mesh_v), np.min((len(mesh_v), 1500)), replace=False)
235
+ subsamples = mesh_v[subsamples_ids, :]
236
+ surface_geodesic = surface_geodesic[subsamples_ids, :][:, subsamples_ids]
237
+ else:
238
+ mesh_trimesh = trimesh.load(mesh_filename)
239
+ subsamples = mesh_v
240
+ origins, ends, pts_bone_dist = pts2line(subsamples, bones)
241
+ pts_bone_visibility = calc_pts2bone_visible_mat(mesh_trimesh, origins, ends)
242
+ pts_bone_visibility = pts_bone_visibility.reshape(len(bones), len(subsamples)).transpose()
243
+ pts_bone_dist = pts_bone_dist.reshape(len(bones), len(subsamples)).transpose()
244
+ # remove visible points which are too far
245
+ for b in range(pts_bone_visibility.shape[1]):
246
+ visible_pts = np.argwhere(pts_bone_visibility[:, b] == 1).squeeze(1)
247
+ if len(visible_pts) == 0:
248
+ continue
249
+ threshold_b = np.percentile(pts_bone_dist[visible_pts, b], 15)
250
+ pts_bone_visibility[pts_bone_dist[:, b] > 1.3 * threshold_b, b] = False
251
+
252
+ visible_matrix = np.zeros(pts_bone_visibility.shape)
253
+ visible_matrix[np.where(pts_bone_visibility == 1)] = pts_bone_dist[np.where(pts_bone_visibility == 1)]
254
+ for c in range(visible_matrix.shape[1]):
255
+ unvisible_pts = np.argwhere(pts_bone_visibility[:, c] == 0).squeeze(1)
256
+ visible_pts = np.argwhere(pts_bone_visibility[:, c] == 1).squeeze(1)
257
+ if len(visible_pts) == 0:
258
+ visible_matrix[:, c] = pts_bone_dist[:, c]
259
+ continue
260
+ for r in unvisible_pts:
261
+ dist1 = np.min(surface_geodesic[r, visible_pts])
262
+ nn_visible = visible_pts[np.argmin(surface_geodesic[r, visible_pts])]
263
+ if np.isinf(dist1):
264
+ visible_matrix[r, c] = 8.0 + pts_bone_dist[r, c]
265
+ else:
266
+ visible_matrix[r, c] = dist1 + visible_matrix[nn_visible, c]
267
+ if subsampling:
268
+ nn_dist = np.sum((mesh_v[:, np.newaxis, :] - subsamples[np.newaxis, ...])**2, axis=2)
269
+ nn_ind = np.argmin(nn_dist, axis=1)
270
+ visible_matrix = visible_matrix[nn_ind, :]
271
+ os.remove(mesh_filename.replace(".obj", "_simplified.obj"))
272
+ return visible_matrix
273
+
274
+
275
+ def predict_skinning(input_data, pred_skel, skin_pred_net, surface_geodesic, mesh_filename, subsampling=False):
276
+ """
277
+ predict skinning
278
+ :param input_data: wrapped input data
279
+ :param pred_skel: predicted skeleton
280
+ :param skin_pred_net: network to predict skinning weights
281
+ :param surface_geodesic: geodesic distance matrix of all vertices
282
+ :param mesh_filename: mesh filename
283
+ :return: predicted rig with skinning weights information
284
+ """
285
+ global device, output_folder
286
+ num_nearest_bone = 5
287
+ bones, bone_names, bone_isleaf = get_bones(pred_skel)
288
+ mesh_v = input_data.pos.data.cpu().numpy()
289
+ print(" calculating volumetric geodesic distance from vertices to bone. This step takes some time...")
290
+ geo_dist = calc_geodesic_matrix(bones, mesh_v, surface_geodesic, mesh_filename, subsampling=subsampling)
291
+ input_samples = [] # joint_pos (x, y, z), (bone_id, 1/D)*5
292
+ loss_mask = []
293
+ skin_nn = []
294
+ for v_id in range(len(mesh_v)):
295
+ geo_dist_v = geo_dist[v_id]
296
+ bone_id_near_to_far = np.argsort(geo_dist_v)
297
+ this_sample = []
298
+ this_nn = []
299
+ this_mask = []
300
+ for i in range(num_nearest_bone):
301
+ if i >= len(bones):
302
+ this_sample += bones[bone_id_near_to_far[0]].tolist()
303
+ this_sample.append(1.0 / (geo_dist_v[bone_id_near_to_far[0]] + 1e-10))
304
+ this_sample.append(bone_isleaf[bone_id_near_to_far[0]])
305
+ this_nn.append(0)
306
+ this_mask.append(0)
307
+ else:
308
+ skel_bone_id = bone_id_near_to_far[i]
309
+ this_sample += bones[skel_bone_id].tolist()
310
+ this_sample.append(1.0 / (geo_dist_v[skel_bone_id] + 1e-10))
311
+ this_sample.append(bone_isleaf[skel_bone_id])
312
+ this_nn.append(skel_bone_id)
313
+ this_mask.append(1)
314
+ input_samples.append(np.array(this_sample)[np.newaxis, :])
315
+ skin_nn.append(np.array(this_nn)[np.newaxis, :])
316
+ loss_mask.append(np.array(this_mask)[np.newaxis, :])
317
+
318
+ skin_input = np.concatenate(input_samples, axis=0)
319
+ loss_mask = np.concatenate(loss_mask, axis=0)
320
+ skin_nn = np.concatenate(skin_nn, axis=0)
321
+ skin_input = torch.from_numpy(skin_input).float()
322
+ input_data.skin_input = skin_input
323
+ input_data.to(device)
324
+
325
+ skin_pred = skin_pred_net(input_data)
326
+ skin_pred = torch.softmax(skin_pred, dim=1)
327
+ skin_pred = skin_pred.data.cpu().numpy()
328
+ skin_pred = skin_pred * loss_mask
329
+
330
+ skin_nn = skin_nn[:, 0:num_nearest_bone]
331
+ skin_pred_full = np.zeros((len(skin_pred), len(bone_names)))
332
+ for v in range(len(skin_pred)):
333
+ for nn_id in range(len(skin_nn[v, :])):
334
+ skin_pred_full[v, skin_nn[v, nn_id]] = skin_pred[v, nn_id]
335
+ print(" filtering skinning prediction")
336
+ tpl_e = input_data.tpl_edge_index.data.cpu().numpy()
337
+ skin_pred_full = post_filter(skin_pred_full, tpl_e, num_ring=1)
338
+ skin_pred_full[skin_pred_full < np.max(skin_pred_full, axis=1, keepdims=True) * 0.35] = 0.0
339
+ skin_pred_full = skin_pred_full / (skin_pred_full.sum(axis=1, keepdims=True) + 1e-10)
340
+ skel_res = assemble_skel_skin(pred_skel, skin_pred_full)
341
+ return skel_res
342
+
343
+
344
+ def tranfer_to_ori_mesh(filename_ori, filename_remesh, pred_rig):
345
+ """
346
+ convert the predicted rig of remeshed model to the rig of the original model.
347
+ Just assign skinning weight based on nearest neighbor
348
+ :param filename_ori: original mesh filename
349
+ :param filename_remesh: remeshed mesh filename
350
+ :param pred_rig: predicted rig
351
+ :return: predicted rig for original mesh
352
+ """
353
+ mesh_remesh = o3d.io.read_triangle_mesh(filename_remesh)
354
+ mesh_ori = o3d.io.read_triangle_mesh(filename_ori)
355
+ tranfer_rig = Info()
356
+
357
+ vert_remesh = np.asarray(mesh_remesh.vertices)
358
+ vert_ori = np.asarray(mesh_ori.vertices)
359
+
360
+ vertice_distance = np.sqrt(np.sum((vert_ori[np.newaxis, ...] - vert_remesh[:, np.newaxis, :]) ** 2, axis=2))
361
+ vertice_raw_id = np.argmin(vertice_distance, axis=0) # nearest vertex id on the fixed mesh for each vertex on the remeshed mesh
362
+
363
+ tranfer_rig.root = pred_rig.root
364
+ tranfer_rig.joint_pos = pred_rig.joint_pos
365
+ new_skin = []
366
+ for v in range(len(vert_ori)):
367
+ skin_v = [v]
368
+ v_nn = vertice_raw_id[v]
369
+ skin_v += pred_rig.joint_skin[v_nn][1:]
370
+ new_skin.append(skin_v)
371
+ tranfer_rig.joint_skin = new_skin
372
+ return tranfer_rig
373
+
374
+
375
+ if __name__ == '__main__':
376
+ input_folder = "quick_start/"
377
+
378
+ # downsample_skinning is used to speed up the calculation of volumetric geodesic distance
379
+ # and to save cpu memory in skinning calculation.
380
+ # Change to False to be more accurate but less efficient.
381
+ downsample_skinning = True
382
+
383
+ # load all weights
384
+ print("loading all networks...")
385
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
386
+
387
+ jointNet = JOINTNET()
388
+ jointNet.to(device)
389
+ jointNet.eval()
390
+ jointNet_checkpoint = torch.load('checkpoints/gcn_meanshift/model_best.pth.tar')
391
+ jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
392
+ print(" joint prediction network loaded.")
393
+
394
+ rootNet = ROOTNET()
395
+ rootNet.to(device)
396
+ rootNet.eval()
397
+ rootNet_checkpoint = torch.load('checkpoints/rootnet/model_best.pth.tar')
398
+ rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
399
+ print(" root prediction network loaded.")
400
+
401
+ boneNet = BONENET()
402
+ boneNet.to(device)
403
+ boneNet.eval()
404
+ boneNet_checkpoint = torch.load('checkpoints/bonenet/model_best.pth.tar')
405
+ boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
406
+ print(" connection prediction network loaded.")
407
+
408
+ skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
409
+ skinNet_checkpoint = torch.load('checkpoints/skinnet/model_best.pth.tar')
410
+ skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
411
+ skinNet.to(device)
412
+ skinNet.eval()
413
+ print(" skinning prediction network loaded.")
414
+
415
+ # Here we provide 16~17 examples. For best results, we will need to override the learned bandwidth and its associated threshold
416
+ # To process other input characters, please first try the learned bandwidth (0.0429 in the provided model), and the default threshold 1e-5.
417
+ # We also use these two default parameters for processing all test models in batch.
418
+
419
+ #model_id, bandwidth, threshold = "smith", None, 1e-5
420
+ model_id, bandwidth, threshold = "17872", 0.045, 0.75e-5
421
+ #model_id, bandwidth, threshold = "8210", 0.05, 1e-5
422
+ #model_id, bandwidth, threshold = "8330", 0.05, 0.8e-5
423
+ #model_id, bandwidth, threshold = "9477", 0.043, 2.5e-5
424
+ #model_id, bandwidth, threshold = "17364", 0.058, 0.3e-5
425
+ #model_id, bandwidth, threshold = "15930", 0.055, 0.4e-5
426
+ #model_id, bandwidth, threshold = "8333", 0.04, 2e-5
427
+ #model_id, bandwidth, threshold = "8338", 0.052, 0.9e-5
428
+ #model_id, bandwidth, threshold = "3318", 0.03, 0.92e-5
429
+ #model_id, bandwidth, threshold = "15446", 0.032, 0.58e-5
430
+ #model_id, bandwidth, threshold = "1347", 0.062, 3e-5
431
+ #model_id, bandwidth, threshold = "11814", 0.06, 0.6e-5
432
+ #model_id, bandwidth, threshold = "2982", 0.045, 0.3e-5
433
+ #model_id, bandwidth, threshold = "2586", 0.05, 0.6e-5
434
+ #model_id, bandwidth, threshold = "8184", 0.05, 0.4e-5
435
+ #model_id, bandwidth, threshold = "9000", 0.035, 0.16e-5
436
+
437
+ # create data used for inferece
438
+ print("creating data for model ID {:s}".format(model_id))
439
+ mesh_filename = os.path.join(input_folder, '{:s}_remesh.obj'.format(model_id))
440
+ if not os.path.exists(mesh_filename):
441
+ mesh_ori_filename = os.path.join(input_folder, '{:s}_ori.obj'.format(model_id))
442
+ mesh_ori = o3d.io.read_triangle_mesh(mesh_ori_filename)
443
+ if len(np.asarray(mesh_ori.vertices)) == 0:
444
+ print(f"Please name your input model as {model_id}_ori.obj")
445
+ exit()
446
+ mesh_remesh = mesh_ori.simplify_quadric_decimation(4000) # adjust vertices between 1K - 5K
447
+ o3d.io.write_triangle_mesh(mesh_filename, mesh_remesh)
448
+
449
+ data, vox, surface_geodesic, translation_normalize, scale_normalize = create_single_data(mesh_filename)
450
+ data.to(device)
451
+
452
+ print("predicting joints")
453
+ data = predict_joints(data, vox, jointNet, threshold, bandwidth=bandwidth,
454
+ mesh_filename=mesh_filename.replace("_remesh.obj", "_normalized.obj"))
455
+ data.to(device)
456
+ print("predicting connectivity")
457
+ pred_skeleton = predict_skeleton(data, vox, rootNet, boneNet,
458
+ mesh_filename=mesh_filename.replace("_remesh.obj", "_normalized.obj"))
459
+ print("predicting skinning")
460
+ pred_rig = predict_skinning(data, pred_skeleton, skinNet, surface_geodesic,
461
+ mesh_filename.replace("_remesh.obj", "_normalized.obj"),
462
+ subsampling=downsample_skinning)
463
+
464
+ # here we reverse the normalization to the original scale and position
465
+ pred_rig.normalize(scale_normalize, -translation_normalize)
466
+
467
+ print("Saving result")
468
+ if True:
469
+ # here we use original mesh tesselation (without remeshing)
470
+ mesh_filename_ori = os.path.join(input_folder, '{:s}_ori.obj'.format(model_id))
471
+ pred_rig = tranfer_to_ori_mesh(mesh_filename_ori, mesh_filename, pred_rig)
472
+ pred_rig.save(mesh_filename_ori.replace('.obj', '_rig.txt'))
473
+ else:
474
+ # here we use remeshed mesh
475
+ pred_rig.save(mesh_filename.replace('.obj', '_rig.txt'))
476
+ print("Done!")