Chang, Kai-Ching commited on
Commit
2ce0171
·
1 Parent(s): 058fe77
Files changed (2) hide show
  1. app.py +63 -73
  2. requirements.txt +6 -23
app.py CHANGED
@@ -14,7 +14,6 @@ if RIGNET_PATH not in sys.path:
14
  sys.path.append(RIGNET_PATH)
15
 
16
  # Ensure binvox executable is available in the root (required by RigNet's os.system call)
17
- # RigNet code expects './binvox' in the current directory.
18
  BINVOX_SRC = os.path.join(RIGNET_PATH, "binvox")
19
  if platform.system() == "Windows":
20
  BINVOX_SRC += ".exe"
@@ -29,10 +28,9 @@ if os.path.exists(BINVOX_SRC):
29
  if platform.system() != "Windows":
30
  os.system(f"chmod +x {BINVOX_DEST}")
31
  else:
32
- print(f"Warning: binvox not found at {BINVOX_SRC}. Voxelization may fail.")
33
 
34
  # --- 2. Import RigNet Modules ---
35
- # These imports must happen AFTER adding RigNet to sys.path
36
  try:
37
  from quick_start import (
38
  create_single_data, predict_joints, predict_skeleton,
@@ -43,7 +41,7 @@ try:
43
  from models.PairCls_GCN import PairCls as BONENET
44
  from models.SKINNING import SKINNET
45
  except ImportError as e:
46
- print("Error importing RigNet modules. Ensure the 'RigNet' folder is present and contains the code.")
47
  raise e
48
 
49
  # --- 3. Load Models Globally ---
@@ -53,6 +51,8 @@ print(f"Loading RigNet models on {device}...")
53
  def load_checkpoint(model, filename):
54
  # Checkpoints are located inside RigNet/checkpoints/
55
  filepath = os.path.join(RIGNET_PATH, "checkpoints", filename)
 
 
56
  checkpoint = torch.load(filepath, map_location=device)
57
  model.load_state_dict(checkpoint['state_dict'])
58
  return model
@@ -84,89 +84,79 @@ def rignet_inference(input_mesh_path):
84
  if not input_mesh_path:
85
  return None
86
 
87
- # Setup temporary paths
88
- # Gradio stores inputs in a specific hash folder.
89
- # We work in that folder or create a clean temp one to avoid side effects.
90
  working_dir = os.path.dirname(input_mesh_path)
91
  base_name = os.path.basename(input_mesh_path).replace(".obj", "")
92
 
93
- # Prepare filenames (mimicking quick_start naming convention)
94
  mesh_filename = os.path.join(working_dir, f"{base_name}_remesh.obj")
95
- mesh_original_path = input_mesh_path
96
 
97
  print(f"Processing: {input_mesh_path}")
98
 
99
- # 1. Preprocessing / Decimation
100
- mesh_ori = o3d.io.read_triangle_mesh(mesh_original_path)
101
- if len(np.asarray(mesh_ori.vertices)) == 0:
102
- raise ValueError("Empty mesh uploaded.")
103
-
104
- # Simplify mesh as per RigNet requirement (approx 4k vertices for inference)
105
- mesh_remesh = mesh_ori.simplify_quadric_decimation(4000)
106
- o3d.io.write_triangle_mesh(mesh_filename, mesh_remesh)
107
-
108
- # 2. Create Data (Voxelization + Geodesic calculation)
109
- # create_single_data calls ./binvox internally. We ensured it exists in root.
110
- data, vox, surface_geodesic, translation_normalize, scale_normalize = create_single_data(mesh_filename)
111
- data.to(device)
112
-
113
- # 3. Pipeline Predictions
114
- # Config (using default learned parameters from quick_start)
115
- bandwidth = 0.0429
116
- threshold = 1e-5
117
-
118
- print("Predicting joints...")
119
- # Note: We pass the normalized obj path which create_single_data generated
120
- mesh_norm_path = mesh_filename.replace("_remesh.obj", "_normalized.obj")
121
-
122
- data = predict_joints(
123
- data, vox, jointNet, threshold, bandwidth=bandwidth,
124
- mesh_filename=mesh_norm_path
125
- )
126
- data.to(device)
127
-
128
- print("Predicting connectivity...")
129
- pred_skeleton = predict_skeleton(
130
- data, vox, rootNet, boneNet,
131
- mesh_filename=mesh_norm_path
132
- )
133
-
134
- print("Predicting skinning...")
135
- # downsample_skinning=True for speed (as per quick_start default)
136
- pred_rig = predict_skinning(
137
- data, pred_skeleton, skinNet, surface_geodesic,
138
- mesh_norm_path, subsampling=True
139
- )
140
-
141
- # 4. Post-processing
142
- # Reverse normalization
143
- pred_rig.normalize(scale_normalize, -translation_normalize)
144
-
145
- # Transfer rig to original high-res mesh
146
- final_rig = tranfer_to_ori_mesh(mesh_original_path, mesh_filename, pred_rig)
147
-
148
- # Save output
149
- output_path = os.path.join(working_dir, f"{base_name}_rig.txt")
150
- final_rig.save(output_path)
151
-
152
- # Optional: Return the simplified mesh for visualization if needed,
153
- # but here we just return the RigNet text file output.
154
- return output_path
155
 
156
  # --- 5. Gradio Interface ---
157
- title = "RigNet: Neural Rigging for Articulated Characters"
158
- description = (
159
- "Upload a 3D model (.obj) to automatically generate a skeleton and skinning weights. "
160
- "The output is a text file containing the rig information."
161
- )
162
 
163
  iface = gr.Interface(
164
  fn=rignet_inference,
165
  inputs=gr.Model3D(label="Input Mesh (.obj)"),
166
- outputs=gr.File(label="Download Rig Output (.txt)"),
167
  title=title,
168
- description=description,
169
- examples=[], # Add examples if you have sample .obj files in a folder
170
  )
171
 
172
  if __name__ == "__main__":
 
14
  sys.path.append(RIGNET_PATH)
15
 
16
  # Ensure binvox executable is available in the root (required by RigNet's os.system call)
 
17
  BINVOX_SRC = os.path.join(RIGNET_PATH, "binvox")
18
  if platform.system() == "Windows":
19
  BINVOX_SRC += ".exe"
 
28
  if platform.system() != "Windows":
29
  os.system(f"chmod +x {BINVOX_DEST}")
30
  else:
31
+ print(f"Warning: binvox not found at {BINVOX_SRC}. Inference may fail.")
32
 
33
  # --- 2. Import RigNet Modules ---
 
34
  try:
35
  from quick_start import (
36
  create_single_data, predict_joints, predict_skeleton,
 
41
  from models.PairCls_GCN import PairCls as BONENET
42
  from models.SKINNING import SKINNET
43
  except ImportError as e:
44
+ print("Error importing RigNet modules. Ensure the 'RigNet' folder is correctly placed.")
45
  raise e
46
 
47
  # --- 3. Load Models Globally ---
 
51
  def load_checkpoint(model, filename):
52
  # Checkpoints are located inside RigNet/checkpoints/
53
  filepath = os.path.join(RIGNET_PATH, "checkpoints", filename)
54
+ if not os.path.exists(filepath):
55
+ raise FileNotFoundError(f"Checkpoint not found: {filepath}")
56
  checkpoint = torch.load(filepath, map_location=device)
57
  model.load_state_dict(checkpoint['state_dict'])
58
  return model
 
84
  if not input_mesh_path:
85
  return None
86
 
87
+ # Work in a temp folder derived from the input path
 
 
88
  working_dir = os.path.dirname(input_mesh_path)
89
  base_name = os.path.basename(input_mesh_path).replace(".obj", "")
90
 
91
+ # Prepare filenames
92
  mesh_filename = os.path.join(working_dir, f"{base_name}_remesh.obj")
 
93
 
94
  print(f"Processing: {input_mesh_path}")
95
 
96
+ try:
97
+ # 1. Preprocessing / Decimation
98
+ mesh_ori = o3d.io.read_triangle_mesh(input_mesh_path)
99
+ if len(np.asarray(mesh_ori.vertices)) == 0:
100
+ raise ValueError("Empty mesh uploaded.")
101
+
102
+ # Simplify mesh (approx 4k vertices for inference)
103
+ mesh_remesh = mesh_ori.simplify_quadric_decimation(4000)
104
+ o3d.io.write_triangle_mesh(mesh_filename, mesh_remesh)
105
+
106
+ # 2. Create Data (Voxelization + Geodesic calculation)
107
+ data, vox, surface_geodesic, translation_normalize, scale_normalize = create_single_data(mesh_filename)
108
+ data = data.to(device)
109
+
110
+ # 3. Pipeline Predictions
111
+ # Default parameters from quick_start
112
+ bandwidth = 0.0429
113
+ threshold = 1e-5
114
+
115
+ mesh_norm_path = mesh_filename.replace("_remesh.obj", "_normalized.obj")
116
+
117
+ print("Predicting joints...")
118
+ data = predict_joints(
119
+ data, vox, jointNet, threshold, bandwidth=bandwidth,
120
+ mesh_filename=mesh_norm_path
121
+ )
122
+ data = data.to(device)
123
+
124
+ print("Predicting connectivity...")
125
+ pred_skeleton = predict_skeleton(
126
+ data, vox, rootNet, boneNet,
127
+ mesh_filename=mesh_norm_path
128
+ )
129
+
130
+ print("Predicting skinning...")
131
+ pred_rig = predict_skinning(
132
+ data, pred_skeleton, skinNet, surface_geodesic,
133
+ mesh_norm_path, subsampling=True
134
+ )
135
+
136
+ # 4. Post-processing
137
+ pred_rig.normalize(scale_normalize, -translation_normalize)
138
+ final_rig = tranfer_to_ori_mesh(input_mesh_path, mesh_filename, pred_rig)
139
+
140
+ # Save output to a text file
141
+ output_path = os.path.join(working_dir, f"{base_name}_rig.txt")
142
+ final_rig.save(output_path)
143
+
144
+ return output_path
145
+
146
+ except Exception as e:
147
+ print(f"Error during inference: {e}")
148
+ raise gr.Error(f"Processing failed: {str(e)}")
 
 
 
149
 
150
  # --- 5. Gradio Interface ---
151
+ title = "RigNet: Neural Rigging"
152
+ description = "Upload an .obj file. The model will generate a skeleton and skinning weights, returned as a text file."
 
 
 
153
 
154
  iface = gr.Interface(
155
  fn=rignet_inference,
156
  inputs=gr.Model3D(label="Input Mesh (.obj)"),
157
+ outputs=gr.File(label="Download Rig (.txt)"),
158
  title=title,
159
+ description=description
 
160
  )
161
 
162
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,30 +1,13 @@
1
- # Core dependencies
2
- gradio==4.16.0
3
  torch==1.12.0
4
  torchvision==0.13.0
5
-
6
- # PyTorch Geometric and related packages
 
 
 
7
  torch-geometric==1.7.2
8
  torch-scatter
9
  torch-sparse
10
  torch-cluster
11
  torch-spline-conv
12
-
13
- # Scientific computing
14
- numpy==1.21.6
15
- scipy==1.7.3
16
- scikit-learn
17
-
18
- # 3D processing
19
- trimesh[easy]==3.21.7
20
- open3d==0.9.0
21
- rtree>=0.8,<0.9
22
-
23
- # Visualization and utilities
24
- matplotlib==3.5.3
25
- opencv-python==4.7.0.72
26
- tensorboard==2.11.2
27
- Pillow==9.5.0
28
-
29
- # File handling
30
- pathlib
 
1
+ --find-links https://data.pyg.org/whl/torch-1.12.0+cpu.html
 
2
  torch==1.12.0
3
  torchvision==0.13.0
4
+ gradio
5
+ numpy
6
+ scipy
7
+ trimesh
8
+ open3d
9
  torch-geometric==1.7.2
10
  torch-scatter
11
  torch-sparse
12
  torch-cluster
13
  torch-spline-conv