ckc99u commited on
Commit
a5ba94d
·
verified ·
1 Parent(s): 2295ea6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +488 -132
app.py CHANGED
@@ -1,158 +1,514 @@
 
 
 
 
 
 
1
  import os
2
  import sys
3
- import shutil
4
- import platform
5
  import torch
6
  import numpy as np
7
  import open3d as o3d
8
- import gradio as gr
 
 
 
9
  from torch_geometric.data import Data
 
10
 
11
- # --- 0. Compatibility Patch for PyG 2.0+ ---
12
- # RigNet uses older PyG where you could assign arbitrary attributes (data.joints = ...).
13
- # In PyG 2.0+, 'Data' stores everything in a store. We subclass to allow dot-notation access.
14
- class RigNetData(Data):
15
- def __setattr__(self, key, value):
16
- if key in ['joints', 'pairs', 'pair_attr', 'joints_batch', 'pairs_batch', 'skin_input']:
17
- self[key] = value
18
- else:
19
- super().__setattr__(key, value)
20
-
21
- def __getattr__(self, key):
22
- # Fallback to getting from the dictionary if standard attribute fails
23
- if key in self.keys:
24
- return self[key]
25
- return super().__getattr__(key)
26
-
27
- # Monkey-patch RigNet's create_single_data later to use this class,
28
- # or just convert the object after creation.
29
 
30
- # --- 1. Setup Environment & Paths ---
31
- RIGNET_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "RigNet")
32
- if RIGNET_PATH not in sys.path:
33
- sys.path.append(RIGNET_PATH)
34
 
35
- # Ensure binvox executable
36
- BINVOX_SRC = os.path.join(RIGNET_PATH, "binvox")
37
- BINVOX_DEST = "binvox.exe" if platform.system() == "Windows" else "binvox"
38
- if platform.system() == "Windows": BINVOX_SRC += ".exe"
39
-
40
- if os.path.exists(BINVOX_SRC):
41
- shutil.copy(BINVOX_SRC, BINVOX_DEST)
42
- if platform.system() != "Windows": os.system(f"chmod +x {BINVOX_DEST}")
43
- else:
44
- print(f"Warning: binvox not found at {BINVOX_SRC}")
45
-
46
- # --- 2. Import RigNet Modules ---
47
- try:
48
- # We need to intercept imports to inject our patched Data class if possible,
49
- # but easier to just wrap the result of create_single_data.
50
- from quick_start import (
51
- create_single_data as original_create_single_data, # Rename to wrap
52
- predict_joints, predict_skeleton,
53
- predict_skinning, tranfer_to_ori_mesh
54
- )
55
- from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET
56
- from models.ROOT_GCN import ROOTNET
57
- from models.PairCls_GCN import PairCls as BONENET
58
- from models.SKINNING import SKINNET
59
- except ImportError as e:
60
- print(f"Error importing RigNet: {e}")
61
-
62
- # --- 3. Load Models ---
63
- device = torch.device("cpu")
64
- print(f"Loading RigNet models on {device}...")
65
-
66
- def load_checkpoint(model, filename):
67
- filepath = os.path.join(RIGNET_PATH, "checkpoints", filename)
68
- # map_location=device is crucial for CPU-only spaces
69
- checkpoint = torch.load(filepath, map_location=device)
70
- model.load_state_dict(checkpoint['state_dict'])
71
- return model
72
-
73
- # Initialize models
74
- jointNet = JOINTNET().to(device); jointNet.eval()
75
- load_checkpoint(jointNet, 'gcn_meanshift/model_best.pth.tar')
 
 
 
 
 
 
76
 
77
- rootNet = ROOTNET().to(device); rootNet.eval()
78
- load_checkpoint(rootNet, 'rootnet/model_best.pth.tar')
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- boneNet = BONENET().to(device); boneNet.eval()
81
- load_checkpoint(boneNet, 'bonenet/model_best.pth.tar')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True).to(device); skinNet.eval()
84
- load_checkpoint(skinNet, 'skinnet/model_best.pth.tar')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- print("Models loaded.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- # --- 4. Wrapper to fix Data object ---
89
- def create_single_data_patched(mesh_filename):
90
- # Call original function
91
- data, vox, surf_geo, t_norm, s_norm = original_create_single_data(mesh_filename)
92
-
93
- # Convert to our flexible class
94
- new_data = RigNetData(
95
- x=data.x,
96
- pos=data.pos,
97
- tpl_edge_index=data.tpl_edge_index,
98
- geo_edge_index=data.geo_edge_index,
99
- batch=data.batch
100
- )
101
- return new_data, vox, surf_geo, t_norm, s_norm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- # --- 5. Inference Pipeline ---
104
- def rignet_inference(input_mesh_path):
105
- if not input_mesh_path: return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- working_dir = os.path.dirname(input_mesh_path)
108
- base_name = os.path.basename(input_mesh_path).replace(".obj", "")
109
- mesh_filename = os.path.join(working_dir, f"{base_name}_remesh.obj")
110
 
111
- print(f"Processing: {input_mesh_path}")
 
 
 
 
 
 
 
 
 
 
112
  try:
113
- # 1. Preprocess
114
- mesh_ori = o3d.io.read_triangle_mesh(input_mesh_path)
115
- if len(np.asarray(mesh_ori.vertices)) == 0: raise ValueError("Empty mesh")
 
 
 
 
 
 
 
 
 
 
 
116
  mesh_remesh = mesh_ori.simplify_quadric_decimation(4000)
117
- o3d.io.write_triangle_mesh(mesh_filename, mesh_remesh)
118
-
119
- # 2. Create Data (Patched)
120
- data, vox, surface_geodesic, t_norm, s_norm = create_single_data_patched(mesh_filename)
121
- data = data.to(device)
122
-
123
- # 3. Inference
124
- mesh_norm = mesh_filename.replace("_remesh.obj", "_normalized.obj")
125
 
 
126
  print("Predicting joints...")
127
- data = predict_joints(data, vox, jointNet, 1e-5, bandwidth=0.0429, mesh_filename=mesh_norm)
128
- data = data.to(device)
129
-
130
- print("Predicting connectivity...")
131
- skel = predict_skeleton(data, vox, rootNet, boneNet, mesh_filename=mesh_norm)
132
-
133
- print("Predicting skinning...")
134
- rig = predict_skinning(data, skel, skinNet, surface_geodesic, mesh_norm, subsampling=True)
135
-
136
- # 4. Export
137
- rig.normalize(s_norm, -t_norm)
138
- final_rig = tranfer_to_ori_mesh(input_mesh_path, mesh_filename, rig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
- out_path = os.path.join(working_dir, f"{base_name}_rig.txt")
141
- final_rig.save(out_path)
142
- return out_path
143
-
144
  except Exception as e:
145
- print(f"Error: {e}")
146
- raise gr.Error(f"Processing failed: {str(e)}")
 
 
147
 
148
- # --- 6. Launch ---
149
- iface = gr.Interface(
150
- fn=rignet_inference,
151
- inputs=gr.Model3D(label="Input .obj"),
152
- outputs=gr.File(label="Rig Output .txt"),
153
- title="RigNet Demo",
154
- description="Upload a mesh to generate a rig. (Uses PyTorch 2.1 CPU)"
155
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  if __name__ == "__main__":
158
- iface.launch()
 
 
1
+ """
2
+ RigNet Gradio Demo
3
+ Automatic rigging for 3D character models
4
+ Based on: https://github.com/zhan-xu/RigNet
5
+ """
6
+
7
  import os
8
  import sys
9
+ import gradio as gr
 
10
  import torch
11
  import numpy as np
12
  import open3d as o3d
13
+ import trimesh
14
+ from pathlib import Path
15
+ import tempfile
16
+ import shutil
17
  from torch_geometric.data import Data
18
+ from torch_geometric.utils import add_self_loops
19
 
20
+ # Import RigNet modules (assuming they're in the same directory structure)
21
+ from utils import binvox_rw
22
+ from utils.rig_parser import Info
23
+ from utils.io_utils import assemble_skel_skin
24
+ from utils.cluster_utils import meanshift_cluster, nms_meanshift
25
+ from utils.mst_utils import inside_check, flip, increase_cost_for_outside_bone, primMST_symmetry, loadSkel_recur
26
+ from geometric_proc.common_ops import get_bones, calc_surface_geodesic
27
+ from gen_dataset import get_tpl_edges, get_geo_edges
28
+ from mst_generate import sample_on_bone, getInitId
29
+ from run_skinning import post_filter
30
+ from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET
31
+ from models.ROOT_GCN import ROOTNET
32
+ from models.PairCls_GCN import PairCls as BONENET
33
+ from models.SKINNING import SKINNET
34
+ import itertools as it
35
+ from geometric_proc.compute_volumetric_geodesic import pts2line, calc_pts2bone_visible_mat
 
 
36
 
37
+ # Global variables
38
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
39
+ models_loaded = False
40
+ jointNet, rootNet, boneNet, skinNet = None, None, None, None
41
 
42
+ def load_models():
43
+ """Load all pre-trained RigNet models"""
44
+ global jointNet, rootNet, boneNet, skinNet, models_loaded
45
+
46
+ if models_loaded:
47
+ return
48
+
49
+ print("Loading RigNet models...")
50
+
51
+ # Joint prediction network
52
+ jointNet = JOINTNET()
53
+ jointNet.to(device)
54
+ jointNet.eval()
55
+ jointNet_checkpoint = torch.load('checkpoints/gcn_meanshift/model_best.pth.tar',
56
+ map_location=device)
57
+ jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
58
+ print("✓ Joint prediction network loaded")
59
+
60
+ # Root prediction network
61
+ rootNet = ROOTNET()
62
+ rootNet.to(device)
63
+ rootNet.eval()
64
+ rootNet_checkpoint = torch.load('checkpoints/rootnet/model_best.pth.tar',
65
+ map_location=device)
66
+ rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
67
+ print(" Root prediction network loaded")
68
+
69
+ # Bone connection network
70
+ boneNet = BONENET()
71
+ boneNet.to(device)
72
+ boneNet.eval()
73
+ boneNet_checkpoint = torch.load('checkpoints/bonenet/model_best.pth.tar',
74
+ map_location=device)
75
+ boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
76
+ print("✓ Connection prediction network loaded")
77
+
78
+ # Skinning prediction network
79
+ skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
80
+ skinNet_checkpoint = torch.load('checkpoints/skinnet/model_best.pth.tar',
81
+ map_location=device)
82
+ skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
83
+ skinNet.to(device)
84
+ skinNet.eval()
85
+ print("✓ Skinning prediction network loaded")
86
+
87
+ models_loaded = True
88
+ print("All models loaded successfully!")
89
 
90
+ def normalize_obj(mesh_v):
91
+ """Normalize mesh vertices to unit scale"""
92
+ dims = [max(mesh_v[:, 0]) - min(mesh_v[:, 0]),
93
+ max(mesh_v[:, 1]) - min(mesh_v[:, 1]),
94
+ max(mesh_v[:, 2]) - min(mesh_v[:, 2])]
95
+ scale = 1.0 / max(dims)
96
+ pivot = np.array([(min(mesh_v[:, 0]) + max(mesh_v[:, 0])) / 2,
97
+ min(mesh_v[:, 1]),
98
+ (min(mesh_v[:, 2]) + max(mesh_v[:, 2])) / 2])
99
+ mesh_v[:, 0] -= pivot[0]
100
+ mesh_v[:, 1] -= pivot[1]
101
+ mesh_v[:, 2] -= pivot[2]
102
+ mesh_v *= scale
103
+ return mesh_v, pivot, scale
104
 
105
+ def create_single_data(mesh_filename):
106
+ """Create input data for the network"""
107
+ mesh = o3d.io.read_triangle_mesh(mesh_filename)
108
+ mesh.compute_vertex_normals()
109
+ mesh_v = np.asarray(mesh.vertices)
110
+ mesh_vn = np.asarray(mesh.vertex_normals)
111
+ mesh_f = np.asarray(mesh.triangles)
112
+
113
+ mesh_v, translation_normalize, scale_normalize = normalize_obj(mesh_v)
114
+
115
+ # Save normalized mesh
116
+ mesh_normalized = o3d.geometry.TriangleMesh(
117
+ vertices=o3d.utility.Vector3dVector(mesh_v),
118
+ triangles=o3d.utility.Vector3iVector(mesh_f))
119
+ normalized_path = mesh_filename.replace("_remesh.obj", "_normalized.obj")
120
+ o3d.io.write_triangle_mesh(normalized_path, mesh_normalized)
121
+
122
+ # Vertices
123
+ v = np.concatenate((mesh_v, mesh_vn), axis=1)
124
+ v = torch.from_numpy(v).float()
125
+
126
+ # Topology edges
127
+ print(" Gathering topological edges...")
128
+ tpl_e = get_tpl_edges(mesh_v, mesh_f).T
129
+ tpl_e = torch.from_numpy(tpl_e).long()
130
+ tpl_e, _ = add_self_loops(tpl_e, num_nodes=v.size(0))
131
+
132
+ # Surface geodesic distance matrix
133
+ print(" Calculating surface geodesic matrix...")
134
+ surface_geodesic = calc_surface_geodesic(mesh)
135
+
136
+ # Geodesic edges
137
+ print(" Gathering geodesic edges...")
138
+ geo_e = get_geo_edges(surface_geodesic, mesh_v).T
139
+ geo_e = torch.from_numpy(geo_e).long()
140
+ geo_e, _ = add_self_loops(geo_e, num_nodes=v.size(0))
141
+
142
+ # Batch
143
+ batch = torch.zeros(len(v), dtype=torch.long)
144
+
145
+ # Voxelization
146
+ vox_path = mesh_filename.replace('_remesh.obj', '_normalized.binvox')
147
+ if not os.path.exists(vox_path):
148
+ # Use binvox command
149
+ if sys.platform == "linux" or sys.platform == "linux2":
150
+ os.system(f"./binvox -d 88 -pb {normalized_path}")
151
+ elif sys.platform == "win32":
152
+ os.system(f"binvox.exe -d 88 {normalized_path}")
153
+
154
+ with open(vox_path, 'rb') as fvox:
155
+ vox = binvox_rw.read_as_3d_array(fvox)
156
+
157
+ data = Data(x=v[:, 3:6], pos=v[:, 0:3], tpl_edge_index=tpl_e,
158
+ geo_edge_index=geo_e, batch=batch)
159
+
160
+ return data, vox, surface_geodesic, translation_normalize, scale_normalize
161
 
162
+ def predict_joints(input_data, vox, threshold=1e-5, bandwidth=None):
163
+ """Predict skeleton joints"""
164
+ data_displacement, _, attn_pred, bandwidth_pred = jointNet(input_data)
165
+ y_pred = data_displacement + input_data.pos
166
+ y_pred_np = y_pred.data.cpu().numpy()
167
+ attn_pred_np = attn_pred.data.cpu().numpy()
168
+
169
+ y_pred_np, index_inside = inside_check(y_pred_np, vox)
170
+ attn_pred_np = attn_pred_np[index_inside, :]
171
+ y_pred_np = y_pred_np[attn_pred_np.squeeze() > 1e-3]
172
+ attn_pred_np = attn_pred_np[attn_pred_np.squeeze() > 1e-3]
173
+
174
+ # Symmetrize points by reflecting
175
+ y_pred_np_reflect = y_pred_np * np.array([[-1, 1, 1]])
176
+ y_pred_np = np.concatenate((y_pred_np, y_pred_np_reflect), axis=0)
177
+ attn_pred_np = np.tile(attn_pred_np, (2, 1))
178
+
179
+ if bandwidth is None:
180
+ bandwidth = bandwidth_pred.item()
181
+
182
+ y_pred_np = meanshift_cluster(y_pred_np, bandwidth, attn_pred_np, max_iter=40)
183
+
184
+ Y_dist = np.sum(((y_pred_np[np.newaxis, ...] - y_pred_np[:, np.newaxis, :]) ** 2), axis=2)
185
+ density = np.maximum(bandwidth ** 2 - Y_dist, np.zeros(Y_dist.shape))
186
+ density = np.sum(density, axis=0)
187
+ density_sum = np.sum(density)
188
+
189
+ y_pred_np = y_pred_np[density / density_sum > threshold]
190
+ attn_pred_np = attn_pred_np[density / density_sum > threshold][:, 0]
191
+ density = density[density / density_sum > threshold]
192
+
193
+ pred_joints = nms_meanshift(y_pred_np, density, bandwidth)
194
+ pred_joints, _ = flip(pred_joints)
195
+
196
+ # Prepare pair-wise bone data
197
+ pairs = list(it.combinations(range(pred_joints.shape[0]), 2))
198
+ pair_attr = []
199
+ for pr in pairs:
200
+ dist = np.linalg.norm(pred_joints[pr[0]] - pred_joints[pr[1]])
201
+ bone_samples = sample_on_bone(pred_joints[pr[0]], pred_joints[pr[1]])
202
+ bone_samples_inside, _ = inside_check(bone_samples, vox)
203
+ outside_proportion = len(bone_samples_inside) / (len(bone_samples) + 1e-10)
204
+ attr = np.array([dist, outside_proportion, 1])
205
+ pair_attr.append(attr)
206
+
207
+ pairs = np.array(pairs)
208
+ pair_attr = np.array(pair_attr)
209
+ pairs = torch.from_numpy(pairs).float()
210
+ pair_attr = torch.from_numpy(pair_attr).float()
211
+ pred_joints = torch.from_numpy(pred_joints).float()
212
+
213
+ joints_batch = torch.zeros(len(pred_joints), dtype=torch.long)
214
+ pairs_batch = torch.zeros(len(pairs), dtype=torch.long)
215
+
216
+ input_data.joints = pred_joints
217
+ input_data.pairs = pairs
218
+ input_data.pair_attr = pair_attr
219
+ input_data.joints_batch = joints_batch
220
+ input_data.pairs_batch = pairs_batch
221
+
222
+ return input_data
223
 
224
+ def predict_skeleton(input_data, vox):
225
+ """Predict skeleton structure"""
226
+ root_id = getInitId(input_data, rootNet)
227
+ pred_joints = input_data.joints.data.cpu().numpy()
228
+
229
+ with torch.no_grad():
230
+ connect_prob, _ = boneNet(input_data, permute_joints=False)
231
+ connect_prob = torch.sigmoid(connect_prob)
232
+
233
+ pair_idx = input_data.pairs.long().data.cpu().numpy()
234
+ prob_matrix = np.zeros((len(input_data.joints), len(input_data.joints)))
235
+ prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob.data.cpu().numpy().squeeze()
236
+ prob_matrix = prob_matrix + prob_matrix.transpose()
237
+
238
+ cost_matrix = -np.log(prob_matrix + 1e-10)
239
+ cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)
240
+
241
+ pred_skel = Info()
242
+ parent, key, root_id = primMST_symmetry(cost_matrix, root_id, pred_joints)
243
+
244
+ for i in range(len(parent)):
245
+ if parent[i] == -1:
246
+ from utils.tree_utils import TreeNode
247
+ pred_skel.root = TreeNode('root', tuple(pred_joints[i]))
248
+ break
249
+
250
+ loadSkel_recur(pred_skel.root, i, None, pred_joints, parent)
251
+ pred_skel.joint_pos = pred_skel.get_joint_dict()
252
+
253
+ return pred_skel
254
 
255
+ def calc_geodesic_matrix(bones, mesh_v, surface_geodesic, mesh_filename, subsampling=True):
256
+ """Calculate volumetric geodesic distance from vertices to bones"""
257
+ if subsampling:
258
+ mesh0 = o3d.io.read_triangle_mesh(mesh_filename)
259
+ mesh0 = mesh0.simplify_quadric_decimation(3000)
260
+ simplified_path = mesh_filename.replace(".obj", "_simplified.obj")
261
+ o3d.io.write_triangle_mesh(simplified_path, mesh0)
262
+ mesh_trimesh = trimesh.load(simplified_path)
263
+
264
+ subsamples_ids = np.random.choice(len(mesh_v), np.min((len(mesh_v), 1500)), replace=False)
265
+ subsamples = mesh_v[subsamples_ids, :]
266
+ surface_geodesic = surface_geodesic[subsamples_ids, :][:, subsamples_ids]
267
+ else:
268
+ mesh_trimesh = trimesh.load(mesh_filename)
269
+ subsamples = mesh_v
270
+
271
+ origins, ends, pts_bone_dist = pts2line(subsamples, bones)
272
+ pts_bone_visibility = calc_pts2bone_visible_mat(mesh_trimesh, origins, ends)
273
+ pts_bone_visibility = pts_bone_visibility.reshape(len(bones), len(subsamples)).transpose()
274
+ pts_bone_dist = pts_bone_dist.reshape(len(bones), len(subsamples)).transpose()
275
+
276
+ # Remove visible points which are too far
277
+ for b in range(pts_bone_visibility.shape[1]):
278
+ visible_pts = np.argwhere(pts_bone_visibility[:, b] == 1).squeeze(1)
279
+ if len(visible_pts) == 0:
280
+ continue
281
+ threshold_b = np.percentile(pts_bone_dist[visible_pts, b], 15)
282
+ pts_bone_visibility[pts_bone_dist[:, b] > 1.3 * threshold_b, b] = False
283
+
284
+ visible_matrix = np.zeros(pts_bone_visibility.shape)
285
+ visible_matrix[np.where(pts_bone_visibility == 1)] = pts_bone_dist[np.where(pts_bone_visibility == 1)]
286
+
287
+ for c in range(visible_matrix.shape[1]):
288
+ unvisible_pts = np.argwhere(pts_bone_visibility[:, c] == 0).squeeze(1)
289
+ visible_pts = np.argwhere(pts_bone_visibility[:, c] == 1).squeeze(1)
290
+
291
+ if len(visible_pts) == 0:
292
+ visible_matrix[:, c] = pts_bone_dist[:, c]
293
+ continue
294
+
295
+ for r in unvisible_pts:
296
+ dist1 = np.min(surface_geodesic[r, visible_pts])
297
+ nn_visible = visible_pts[np.argmin(surface_geodesic[r, visible_pts])]
298
+ if np.isinf(dist1):
299
+ visible_matrix[r, c] = 8.0 + pts_bone_dist[r, c]
300
+ else:
301
+ visible_matrix[r, c] = dist1 + visible_matrix[nn_visible, c]
302
+
303
+ if subsampling:
304
+ nn_dist = np.sum((mesh_v[:, np.newaxis, :] - subsamples[np.newaxis, ...]) ** 2, axis=2)
305
+ nn_ind = np.argmin(nn_dist, axis=1)
306
+ visible_matrix = visible_matrix[nn_ind, :]
307
+ os.remove(simplified_path)
308
+
309
+ return visible_matrix
310
 
311
+ def predict_skinning(input_data, pred_skel, surface_geodesic, mesh_filename, subsampling=True):
312
+ """Predict skinning weights"""
313
+ num_nearest_bone = 5
314
+ bones, bone_names, bone_isleaf = get_bones(pred_skel)
315
+ mesh_v = input_data.pos.data.cpu().numpy()
316
+
317
+ print(" Calculating volumetric geodesic distance...")
318
+ geo_dist = calc_geodesic_matrix(bones, mesh_v, surface_geodesic, mesh_filename, subsampling=subsampling)
319
+
320
+ input_samples = []
321
+ loss_mask = []
322
+ skin_nn = []
323
+
324
+ for v_id in range(len(mesh_v)):
325
+ geo_dist_v = geo_dist[v_id]
326
+ bone_id_near_to_far = np.argsort(geo_dist_v)
327
+ this_sample = []
328
+ this_nn = []
329
+ this_mask = []
330
+
331
+ for i in range(num_nearest_bone):
332
+ if i >= len(bones):
333
+ this_sample += bones[bone_id_near_to_far[0]].tolist()
334
+ this_sample.append(1.0 / (geo_dist_v[bone_id_near_to_far[0]] + 1e-10))
335
+ this_sample.append(bone_isleaf[bone_id_near_to_far[0]])
336
+ this_nn.append(0)
337
+ this_mask.append(0)
338
+ else:
339
+ skel_bone_id = bone_id_near_to_far[i]
340
+ this_sample += bones[skel_bone_id].tolist()
341
+ this_sample.append(1.0 / (geo_dist_v[skel_bone_id] + 1e-10))
342
+ this_sample.append(bone_isleaf[skel_bone_id])
343
+ this_nn.append(skel_bone_id)
344
+ this_mask.append(1)
345
+
346
+ input_samples.append(np.array(this_sample)[np.newaxis, :])
347
+ skin_nn.append(np.array(this_nn)[np.newaxis, :])
348
+ loss_mask.append(np.array(this_mask)[np.newaxis, :])
349
+
350
+ skin_input = np.concatenate(input_samples, axis=0)
351
+ loss_mask = np.concatenate(loss_mask, axis=0)
352
+ skin_nn = np.concatenate(skin_nn, axis=0)
353
+
354
+ skin_input = torch.from_numpy(skin_input).float()
355
+ input_data.skin_input = skin_input
356
+ input_data.to(device)
357
+
358
+ skin_pred = skinNet(input_data)
359
+ skin_pred = torch.softmax(skin_pred, dim=1)
360
+ skin_pred = skin_pred.data.cpu().numpy()
361
+ skin_pred = skin_pred * loss_mask
362
+
363
+ skin_nn = skin_nn[:, 0:num_nearest_bone]
364
+ skin_pred_full = np.zeros((len(skin_pred), len(bone_names)))
365
 
366
+ for v in range(len(skin_pred)):
367
+ for nn_id in range(len(skin_nn[v, :])):
368
+ skin_pred_full[v, skin_nn[v, nn_id]] = skin_pred[v, nn_id]
369
 
370
+ print(" Filtering skinning prediction...")
371
+ tpl_e = input_data.tpl_edge_index.data.cpu().numpy()
372
+ skin_pred_full = post_filter(skin_pred_full, tpl_e, num_ring=1)
373
+ skin_pred_full[skin_pred_full < np.max(skin_pred_full, axis=1, keepdims=True) * 0.35] = 0.0
374
+ skin_pred_full = skin_pred_full / (skin_pred_full.sum(axis=1, keepdims=True) + 1e-10)
375
+
376
+ skel_res = assemble_skel_skin(pred_skel, skin_pred_full)
377
+ return skel_res
378
+
379
+ def process_model(input_file, bandwidth_val, threshold_val):
380
+ """Main processing function for Gradio interface"""
381
  try:
382
+ # Load models if not already loaded
383
+ load_models()
384
+
385
+ # Create temporary directory for processing
386
+ temp_dir = tempfile.mkdtemp()
387
+
388
+ # Copy input file
389
+ input_path = Path(input_file.name)
390
+ temp_input = os.path.join(temp_dir, "input_ori.obj")
391
+ shutil.copy(input_path, temp_input)
392
+
393
+ # Remesh the input
394
+ print("Preprocessing: Remeshing input...")
395
+ mesh_ori = o3d.io.read_triangle_mesh(temp_input)
396
  mesh_remesh = mesh_ori.simplify_quadric_decimation(4000)
397
+ temp_remesh = os.path.join(temp_dir, "input_remesh.obj")
398
+ o3d.io.write_triangle_mesh(temp_remesh, mesh_remesh)
399
+
400
+ # Create data
401
+ print("Creating data...")
402
+ data, vox, surface_geodesic, translation, scale = create_single_data(temp_remesh)
403
+ data.to(device)
 
404
 
405
+ # Predict joints
406
  print("Predicting joints...")
407
+ data = predict_joints(data, vox, threshold=threshold_val, bandwidth=bandwidth_val)
408
+ data.to(device)
409
+
410
+ # Predict skeleton
411
+ print("Predicting skeleton connectivity...")
412
+ pred_skeleton = predict_skeleton(data, vox)
413
+
414
+ # Predict skinning
415
+ print("Predicting skinning weights...")
416
+ normalized_mesh = temp_remesh.replace("_remesh.obj", "_normalized.obj")
417
+ pred_rig = predict_skinning(data, pred_skeleton, surface_geodesic,
418
+ normalized_mesh, subsampling=True)
419
+
420
+ # Denormalize
421
+ pred_rig.normalize(scale, -translation)
422
+
423
+ # Save result
424
+ output_file = os.path.join(temp_dir, "output_rig.txt")
425
+ pred_rig.save(output_file)
426
+
427
+ print("✓ Processing complete!")
428
+
429
+ # Create info message
430
+ num_joints = len(pred_rig.joint_pos)
431
+ info_msg = f"Successfully generated rig with {num_joints} joints!"
432
+
433
+ return output_file, info_msg
434
 
 
 
 
 
435
  except Exception as e:
436
+ import traceback
437
+ error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
438
+ print(error_msg)
439
+ return None, error_msg
440
 
441
+ # Gradio Interface
442
+ def create_demo():
443
+ """Create Gradio interface"""
444
+
445
+ with gr.Blocks(title="RigNet: Automatic 3D Character Rigging") as demo:
446
+ gr.Markdown("""
447
+ # 🎮 RigNet: Automatic Character Rigging
448
+
449
+ Upload a 3D character model (OBJ format) and automatically generate a skeletal rig with skinning weights.
450
+
451
+ **Based on:** [RigNet: Neural Rigging for Articulated Characters (SIGGRAPH 2020)](https://github.com/zhan-xu/RigNet)
452
+
453
+ ### Instructions:
454
+ 1. Upload your 3D character mesh in OBJ format
455
+ 2. Adjust parameters if needed (default values work for most cases)
456
+ 3. Click "Generate Rig" and wait for processing
457
+ 4. Download the generated rig file
458
+
459
+ **Note:** For best results, simplify your mesh to 1K-5K vertices before uploading.
460
+ """)
461
+
462
+ with gr.Row():
463
+ with gr.Column():
464
+ input_mesh = gr.File(label="Upload 3D Model (.obj)", file_types=[".obj"])
465
+
466
+ with gr.Accordion("Advanced Parameters", open=False):
467
+ bandwidth = gr.Slider(
468
+ minimum=0.02, maximum=0.08, value=0.0429, step=0.001,
469
+ label="Bandwidth (for joint clustering)",
470
+ info="Default: 0.0429. Adjust if joint prediction is too dense/sparse"
471
+ )
472
+ threshold = gr.Slider(
473
+ minimum=0.1e-5, maximum=5e-5, value=1e-5, step=0.1e-5,
474
+ label="Density Threshold",
475
+ info="Default: 1e-5. Higher values = fewer joints"
476
+ )
477
+
478
+ process_btn = gr.Button("🚀 Generate Rig", variant="primary", size="lg")
479
+
480
+ with gr.Column():
481
+ output_file = gr.File(label="Download Rig Output (.txt)")
482
+ status_msg = gr.Textbox(label="Status", lines=3)
483
+
484
+ gr.Markdown("""
485
+ ### Output Format:
486
+ The generated `.txt` file contains:
487
+ - **Joint definitions:** Position of each joint in 3D space
488
+ - **Hierarchy:** Parent-child relationships between joints
489
+ - **Skinning weights:** How each vertex is influenced by nearby joints
490
+
491
+ ### Next Steps:
492
+ - Import the mesh and rig file into animation software (Maya, Blender, etc.)
493
+ - Use provided scripts (e.g., `maya_save_fbx.py`) to convert to FBX format
494
+ - Start animating your character!
495
+
496
+ ---
497
+ **References:**
498
+ - [RigNet Paper](https://arxiv.org/abs/2005.00559)
499
+ - [GitHub Repository](https://github.com/zhan-xu/RigNet)
500
+ - [Project Page](https://zhan-xu.github.io/rig-net/)
501
+ """)
502
+
503
+ # Event handler
504
+ process_btn.click(
505
+ fn=process_model,
506
+ inputs=[input_mesh, bandwidth, threshold],
507
+ outputs=[output_file, status_msg]
508
+ )
509
+
510
+ return demo
511
 
512
  if __name__ == "__main__":
513
+ demo = create_demo()
514
+ demo.launch(share=False)