ckc99u commited on
Commit
fab76dd
·
verified ·
1 Parent(s): f9a9164

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -134
app.py CHANGED
@@ -1,124 +1,33 @@
1
  #!/usr/bin/env python3
2
- """
3
- RigNet Gradio Demo for Hugging Face Spaces (CPU)
4
- """
 
 
 
5
  import os
6
  import sys
7
  import tempfile
8
  import shutil
9
  import traceback
10
  from pathlib import Path
11
-
12
- # CRITICAL: Patch quick_start.py BEFORE importing
13
- # Fix binvox path issue - must happen before RigNet imports
14
- import_patch_code = '''
15
- import os
16
- from sys import platform
17
-
18
- # Monkey-patch quick_start.create_single_data to fix binvox paths
19
- original_create_single_data = None
20
-
21
- def patched_create_single_data(mesh_filename):
22
- """Patched version that fixes binvox path issues"""
23
- import open3d as o3d
24
- import numpy as np
25
- import torch
26
- from torch_geometric.data import Data
27
- from torch_geometric.utils import add_self_loops
28
- sys.path.insert(0, '/app/RigNet')
29
- from utils import binvox_rw
30
- from gen_dataset import get_tpl_edges, get_geo_edges
31
- from geometric_proc.common_ops import calc_surface_geodesic
32
- from quick_start import normalize_obj
33
-
34
- # Load and normalize mesh
35
- mesh = o3d.io.read_triangle_mesh(mesh_filename)
36
- mesh.compute_vertex_normals()
37
- mesh_v = np.asarray(mesh.vertices)
38
- mesh_vn = np.asarray(mesh.vertex_normals)
39
- mesh_f = np.asarray(mesh.triangles)
40
- mesh_v, translation_normalize, scale_normalize = normalize_obj(mesh_v)
41
-
42
- # Save normalized mesh
43
- mesh_normalized = o3d.geometry.TriangleMesh(
44
- vertices=o3d.utility.Vector3dVector(mesh_v),
45
- triangles=o3d.utility.Vector3iVector(mesh_f)
46
- )
47
- normalized_obj = mesh_filename.replace("_remesh.obj", "_normalized.obj")
48
- o3d.io.write_triangle_mesh(normalized_obj, mesh_normalized)
49
-
50
- # Prepare data
51
- v = np.concatenate((mesh_v, mesh_vn), axis=1)
52
- v = torch.from_numpy(v).float()
53
-
54
- # Topology edges
55
- print(" gathering topological edges.")
56
- tpl_e = get_tpl_edges(mesh_v, mesh_f).T
57
- tpl_e = torch.from_numpy(tpl_e).long()
58
- tpl_e, _ = add_self_loops(tpl_e, num_nodes=v.size(0))
59
-
60
- # Surface geodesic
61
- print(" calculating surface geodesic matrix.")
62
- surface_geodesic = calc_surface_geodesic(mesh)
63
-
64
- # Geodesic edges
65
- print(" gathering geodesic edges.")
66
- geo_e = get_geo_edges(surface_geodesic, mesh_v).T
67
- geo_e = torch.from_numpy(geo_e).long()
68
- geo_e, _ = add_self_loops(geo_e, num_nodes=v.size(0))
69
-
70
- # Batch
71
- batch = torch.zeros(len(v), dtype=torch.long)
72
-
73
- # Voxelization - FIX THE PATH ISSUE HERE
74
- binvox_file = normalized_obj.replace('.obj', '.binvox')
75
-
76
- if not os.path.exists(binvox_file):
77
- print(f" Creating voxel file: {binvox_file}")
78
- # Use full path to binvox
79
- cmd = f"/usr/local/bin/binvox -d 88 -pb {normalized_obj}"
80
- print(f" Running: {cmd}")
81
- ret = os.system(cmd)
82
- if ret != 0:
83
- raise RuntimeError(f"binvox failed with return code {ret}")
84
-
85
- # Verify binvox file was created
86
- if not os.path.exists(binvox_file):
87
- raise FileNotFoundError(f"Binvox file not created: {binvox_file}")
88
-
89
- # Load voxel data
90
- with open(binvox_file, 'rb') as fvox:
91
- vox = binvox_rw.read_as_3d_array(fvox)
92
-
93
- data = Data(x=v[:, 3:6], pos=v[:, 0:3], tpl_edge_index=tpl_e,
94
- geo_edge_index=geo_e, batch=batch)
95
-
96
- return data, vox, surface_geodesic, translation_normalize, scale_normalize
97
- '''
98
-
99
- # Execute the patch
100
- exec(import_patch_code)
101
-
102
- import gradio as gr
103
  import torch
104
- import numpy as np
105
 
106
  # Add RigNet to Python path
107
  sys.path.insert(0, '/app/RigNet')
108
 
109
  # Import RigNet modules
110
- from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET
111
- from models.ROOT_GCN import ROOTNET
112
- from models.PairCls_GCN import PairCls as BONENET
113
- from models.SKINNING import SKINNET
114
-
115
- # Import other functions from quick_start (we'll use our patched version)
116
  from quick_start import (
 
117
  predict_joints,
118
  predict_skeleton,
119
  predict_skinning,
120
  normalize_obj
121
  )
 
 
 
 
122
 
123
  # Global variables for models
124
  device = torch.device("cpu")
@@ -148,7 +57,7 @@ def load_models():
148
  map_location=device
149
  )
150
  jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
151
- print(" Joint prediction network loaded")
152
 
153
  # Root prediction network
154
  rootNet = ROOTNET()
@@ -159,7 +68,7 @@ def load_models():
159
  map_location=device
160
  )
161
  rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
162
- print(" Root prediction network loaded")
163
 
164
  # Bone connection network
165
  boneNet = BONENET()
@@ -170,7 +79,7 @@ def load_models():
170
  map_location=device
171
  )
172
  boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
173
- print(" Connectivity prediction network loaded")
174
 
175
  # Skinning network
176
  skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
@@ -181,7 +90,7 @@ def load_models():
181
  skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
182
  skinNet.to(device)
183
  skinNet.eval()
184
- print(" Skinning prediction network loaded")
185
 
186
  models_loaded = True
187
  print("All models loaded successfully!\n")
@@ -203,12 +112,11 @@ def process_mesh(input_obj_path, bandwidth, threshold, downsample_skinning=True)
203
  shutil.copy(input_obj_path, mesh_filename)
204
 
205
  print(f"\nProcessing: {base_name}")
206
- print(f"Working directory: {work_dir}")
207
 
208
- # Step 1: Create data - USE PATCHED VERSION
209
  print(" [1/4] Creating input data...")
210
  data, vox, surface_geodesic, translation_normalize, scale_normalize = \
211
- patched_create_single_data(mesh_filename)
212
  data.to(device)
213
 
214
  # Step 2: Predict joints
@@ -242,7 +150,7 @@ def process_mesh(input_obj_path, bandwidth, threshold, downsample_skinning=True)
242
  output_rig_path = os.path.join(work_dir, f'{base_name}_rig.txt')
243
  pred_rig.save(output_rig_path)
244
 
245
- print(f" Successfully generated rig: {base_name}_rig.txt\n")
246
 
247
  return output_rig_path
248
 
@@ -254,32 +162,73 @@ def process_mesh(input_obj_path, bandwidth, threshold, downsample_skinning=True)
254
 
255
  def rignet_inference(input_obj, bandwidth, threshold):
256
  """
257
- Gradio inference function
258
  """
259
  print("\n" + "="*60)
260
- print("🔍 rignet_inference CALLED!")
261
  print(f" input_obj type: {type(input_obj)}")
 
 
 
262
 
 
263
  if input_obj is None:
264
- return None, "⚠️ Please upload an OBJ file first"
 
 
 
265
 
266
  try:
 
267
  load_models()
268
 
269
- # Extract file path
270
  input_path = None
 
 
271
  if hasattr(input_obj, 'name'):
272
  input_path = input_obj.name
 
 
 
273
  elif isinstance(input_obj, str):
274
  input_path = input_obj
275
- elif isinstance(input_obj, dict) and 'name' in input_obj:
276
- input_path = input_obj['name']
277
 
278
- if not input_path or not os.path.exists(input_path):
279
- return None, f"❌ Invalid file path: {input_path}"
 
 
 
 
 
280
 
281
- print(f" Processing: {input_path}")
282
- print(f" File size: {os.path.getsize(input_path):,} bytes")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  print("="*60 + "\n")
284
 
285
  # Process the mesh
@@ -290,46 +239,147 @@ def rignet_inference(input_obj, bandwidth, threshold):
290
  downsample_skinning=True
291
  )
292
 
 
293
  if not os.path.exists(output_rig_path):
294
- return None, " Output file was not created"
 
 
295
 
296
  output_size = os.path.getsize(output_rig_path)
297
- status_msg = f" Rigging completed!\n\nFile: {os.path.basename(output_rig_path)}\nSize: {output_size:,} bytes"
298
 
 
299
  return output_rig_path, status_msg
300
 
301
  except Exception as e:
302
- error_msg = f" Error:\n\n{str(e)}\n\n{traceback.format_exc()}"
 
 
303
  print(error_msg)
 
304
  return None, error_msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  if __name__ == "__main__":
308
- print("="*60)
309
- print("RigNet Gradio Demo - Starting...")
310
- print("="*60)
311
-
312
  load_models()
313
-
 
314
  demo = gr.Interface(
315
  fn=rignet_inference,
316
  inputs=[
317
  gr.File(label="Upload OBJ File", file_types=[".obj"], type="file"),
318
- gr.Slider(0.02, 0.08, value=0.04, step=0.001, label="Bandwidth"),
319
- gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Threshold (×10⁻⁵)")
320
  ],
321
  outputs=[
322
  gr.File(label="Download Rig TXT"),
323
  gr.Textbox(label="Status", lines=5)
324
  ],
325
- title="🎭 RigNet: Neural Rigging for 3D Characters",
326
- description="Upload OBJ (1K-5K vertices recommended). Processing takes 1-3 minutes on CPU.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  allow_flagging="never"
328
  )
329
-
330
  demo.launch(
331
  server_name="0.0.0.0",
332
  server_port=7860,
 
333
  show_error=True,
334
  debug=True
335
- )
 
1
  #!/usr/bin/env python3
2
+
3
+ import sys
4
+ import os
5
+ import gradio as gr
6
+ import trimesh
7
+ import numpy as np
8
  import os
9
  import sys
10
  import tempfile
11
  import shutil
12
  import traceback
13
  from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  import torch
 
15
 
16
  # Add RigNet to Python path
17
  sys.path.insert(0, '/app/RigNet')
18
 
19
  # Import RigNet modules
 
 
 
 
 
 
20
  from quick_start import (
21
+ create_single_data,
22
  predict_joints,
23
  predict_skeleton,
24
  predict_skinning,
25
  normalize_obj
26
  )
27
+ from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET
28
+ from models.ROOT_GCN import ROOTNET
29
+ from models.PairCls_GCN import PairCls as BONENET
30
+ from models.SKINNING import SKINNET
31
 
32
  # Global variables for models
33
  device = torch.device("cpu")
 
57
  map_location=device
58
  )
59
  jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
60
+ print("✓ Joint prediction network loaded")
61
 
62
  # Root prediction network
63
  rootNet = ROOTNET()
 
68
  map_location=device
69
  )
70
  rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
71
+ print("✓ Root prediction network loaded")
72
 
73
  # Bone connection network
74
  boneNet = BONENET()
 
79
  map_location=device
80
  )
81
  boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
82
+ print("✓ Connectivity prediction network loaded")
83
 
84
  # Skinning network
85
  skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
 
90
  skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
91
  skinNet.to(device)
92
  skinNet.eval()
93
+ print("✓ Skinning prediction network loaded")
94
 
95
  models_loaded = True
96
  print("All models loaded successfully!\n")
 
112
  shutil.copy(input_obj_path, mesh_filename)
113
 
114
  print(f"\nProcessing: {base_name}")
 
115
 
116
+ # Step 1: Create data
117
  print(" [1/4] Creating input data...")
118
  data, vox, surface_geodesic, translation_normalize, scale_normalize = \
119
+ create_single_data(mesh_filename)
120
  data.to(device)
121
 
122
  # Step 2: Predict joints
 
150
  output_rig_path = os.path.join(work_dir, f'{base_name}_rig.txt')
151
  pred_rig.save(output_rig_path)
152
 
153
+ print(f"✓ Successfully generated rig: {base_name}_rig.txt\n")
154
 
155
  return output_rig_path
156
 
 
162
 
163
  def rignet_inference(input_obj, bandwidth, threshold):
164
  """
165
+ Gradio inference function with extensive debugging
166
  """
167
  print("\n" + "="*60)
168
+ print("🔍 DEBUG: rignet_inference CALLED!")
169
  print(f" input_obj type: {type(input_obj)}")
170
+ print(f" input_obj value: {input_obj}")
171
+ print(f" bandwidth: {bandwidth}")
172
+ print(f" threshold: {threshold}")
173
 
174
+ # Check if input is None or empty
175
  if input_obj is None:
176
+ msg = "⚠️ Please upload an OBJ file first"
177
+ print(f" ERROR: {msg}")
178
+ print("="*60 + "\n")
179
+ return None, msg
180
 
181
  try:
182
+ # Ensure models are loaded
183
  load_models()
184
 
185
+ # Extract file path - handle multiple Gradio formats
186
  input_path = None
187
+
188
+ # Case 1: File object with .name attribute
189
  if hasattr(input_obj, 'name'):
190
  input_path = input_obj.name
191
+ print(f" ✓ Got path from .name: {input_path}")
192
+
193
+ # Case 2: Already a string path
194
  elif isinstance(input_obj, str):
195
  input_path = input_obj
196
+ print(f" ✓ Already a string path: {input_path}")
 
197
 
198
+ # Case 3: Dictionary with 'name' key
199
+ elif isinstance(input_obj, dict):
200
+ if 'name' in input_obj:
201
+ input_path = input_obj['name']
202
+ print(f" ✓ Got path from dict['name']: {input_path}")
203
+ else:
204
+ print(f" ERROR: Dict without 'name' key. Keys: {input_obj.keys()}")
205
 
206
+ # Case 4: Unknown type - debug it
207
+ else:
208
+ print(f" ERROR: Unknown input type!")
209
+ print(f" Attributes: {dir(input_obj)}")
210
+ if hasattr(input_obj, '__dict__'):
211
+ print(f" __dict__: {input_obj.__dict__}")
212
+ msg = f"❌ Unexpected file input type: {type(input_obj)}"
213
+ print("="*60 + "\n")
214
+ return None, msg
215
+
216
+ # Validate file path
217
+ if not input_path:
218
+ msg = "❌ Could not extract file path from input"
219
+ print(f" ERROR: {msg}")
220
+ print("="*60 + "\n")
221
+ return None, msg
222
+
223
+ if not os.path.exists(input_path):
224
+ msg = f"❌ File does not exist: {input_path}"
225
+ print(f" ERROR: {msg}")
226
+ print("="*60 + "\n")
227
+ return None, msg
228
+
229
+ file_size = os.path.getsize(input_path)
230
+ print(f" ✓ File validated: {input_path}")
231
+ print(f" ✓ File size: {file_size:,} bytes")
232
  print("="*60 + "\n")
233
 
234
  # Process the mesh
 
239
  downsample_skinning=True
240
  )
241
 
242
+ # Validate output
243
  if not os.path.exists(output_rig_path):
244
+ msg = "❌ Output file was not created"
245
+ print(f"ERROR: {msg}")
246
+ return None, msg
247
 
248
  output_size = os.path.getsize(output_rig_path)
249
+ status_msg = f"✅ Rigging completed!\n\nFile: {os.path.basename(output_rig_path)}\nSize: {output_size:,} bytes"
250
 
251
+ print(f"✓ SUCCESS! Returning output file")
252
  return output_rig_path, status_msg
253
 
254
  except Exception as e:
255
+ error_msg = f"❌ Error during processing:\n\n{str(e)}\n\nDetails:\n{traceback.format_exc()}"
256
+ print("\n" + "="*60)
257
+ print("❌ EXCEPTION CAUGHT:")
258
  print(error_msg)
259
+ print("="*60 + "\n")
260
  return None, error_msg
261
+ def process_obj_file(file_obj):
262
+ """
263
+ Process OBJ file and return first 10 lines of analysis results
264
+ """
265
+ sys.stdout.flush()
266
+ print(f"[DEBUG] Processing file: {file_obj.name if file_obj else 'None'}", flush=True)
267
+
268
+ if not file_obj:
269
+ return "⚠️ No file provided"
270
+
271
+ try:
272
+ results = []
273
+ results.append("="*60)
274
+ results.append("OBJ FILE ANALYSIS - First 10 Lines of Results")
275
+ results.append("="*60)
276
+
277
+ # Read raw OBJ file first 10 lines
278
+ results.append("\n📄 RAW OBJ FILE (First 10 Lines):")
279
+ results.append("-"*60)
280
+ with open(file_obj.name, 'r') as f:
281
+ for i, line in enumerate(f):
282
+ if i >= 10:
283
+ break
284
+ results.append(f"Line {i+1}: {line.rstrip()}")
285
+
286
+ # Load mesh using trimesh
287
+ results.append("\n🔷 MESH ANALYSIS:")
288
+ results.append("-"*60)
289
+
290
+ mesh = trimesh.load(file_obj.name, force='mesh')
291
+
292
+ # Check if it's a Scene or Mesh
293
+ if isinstance(mesh, trimesh.Scene):
294
+ results.append(f"Type: Scene with {len(mesh.geometry)} geometries")
295
+ # Get the first geometry
296
+ if len(mesh.geometry) > 0:
297
+ first_geom_name = list(mesh.geometry.keys())[0]
298
+ mesh = mesh.geometry[first_geom_name]
299
+ results.append(f"Using first geometry: {first_geom_name}")
300
+
301
+ # Mesh statistics (ensures we don't exceed 10 total result lines)
302
+ results.append(f"Vertices: {len(mesh.vertices)}")
303
+ results.append(f"Faces: {len(mesh.faces)}")
304
+ results.append(f"Is Watertight: {mesh.is_watertight}")
305
+ results.append(f"Is Winding Consistent: {mesh.is_winding_consistent}")
306
+ results.append(f"Bounds: {mesh.bounds.tolist()}")
307
+ results.append(f"Center Mass: {mesh.center_mass.tolist()}")
308
+
309
+ # Join results
310
+ output = "\n".join(results[:25]) # Limit output
311
+
312
+ print("[DEBUG] Processing completed successfully", flush=True)
313
+ return output
314
+
315
+ except Exception as e:
316
+ error_msg = f"❌ Error processing file: {str(e)}\n\nStacktrace:\n{sys.exc_info()}"
317
+ print(error_msg, flush=True)
318
+ return error_msg
319
 
320
+ # Gradio Interface
321
+ # demo = gr.Interface(
322
+ # fn=process_obj_file,
323
+ # inputs=gr.File(
324
+ # label="Upload OBJ File",
325
+ # file_types=[".obj"],
326
+ # type="file"
327
+ # ),
328
+ # outputs=gr.Textbox(
329
+ # label="Analysis Results (First 10 Lines)",
330
+ # lines=20,
331
+ # max_lines=30
332
+ # ),
333
+ # title="🔷 OBJ File Analyzer",
334
+ # description="Upload a 3D OBJ file to see the first 10 lines of raw content and mesh analysis",
335
+ # examples=None,
336
+ # cache_examples=False
337
+ # )
338
 
339
  if __name__ == "__main__":
340
+ print("="*60, flush=True)
341
+ print("🚀 Starting OBJ File Analyzer...", flush=True)
342
+ print("="*60, flush=True)
 
343
  load_models()
344
+
345
+
346
  demo = gr.Interface(
347
  fn=rignet_inference,
348
  inputs=[
349
  gr.File(label="Upload OBJ File", file_types=[".obj"], type="file"),
350
+ gr.Slider(0.02, 0.08, value=0.04, step=0.001, label="Bandwidth", info="Joint clustering density (default: 0.04)"),
351
+ gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Threshold (×10⁻⁵)", info="Joint filtering threshold (default: 1.0)")
352
  ],
353
  outputs=[
354
  gr.File(label="Download Rig TXT"),
355
  gr.Textbox(label="Status", lines=5)
356
  ],
357
+ title="🎭 RigNet: Neural Rigging for 3D Characters",
358
+ description="""
359
+ Upload a 3D character mesh (OBJ format) to automatically generate skeletal rig and skinning weights.
360
+
361
+ **Recommended:** OBJ files with 1K-5K vertices work best.
362
+ **Processing time:** 1-3 minutes on CPU depending on mesh complexity.
363
+ """,
364
+ article="""
365
+ ### 📚 About the Output
366
+
367
+ The generated `*_rig.txt` file contains:
368
+ - **joints**: 3D positions of skeletal joints
369
+ - **root**: Root joint of the hierarchy
370
+ - **hier**: Parent-child relationships (skeleton hierarchy)
371
+ - **skin**: Skinning weights for each vertex
372
+
373
+ This format can be imported into 3D animation software.
374
+
375
+ **Reference:** [RigNet: Neural Rigging for Articulated Characters (SIGGRAPH 2020)](https://arxiv.org/abs/2005.00559)
376
+ """,
377
  allow_flagging="never"
378
  )
 
379
  demo.launch(
380
  server_name="0.0.0.0",
381
  server_port=7860,
382
+ share=False,
383
  show_error=True,
384
  debug=True
385
+ )