ckc99u commited on
Commit
f9a9164
·
verified ·
1 Parent(s): bfb9e36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -176
app.py CHANGED
@@ -1,33 +1,124 @@
1
  #!/usr/bin/env python3
2
-
3
- import sys
4
- import os
5
- import gradio as gr
6
- import trimesh
7
- import numpy as np
8
  import os
9
  import sys
10
  import tempfile
11
  import shutil
12
  import traceback
13
  from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  import torch
 
15
 
16
  # Add RigNet to Python path
17
  sys.path.insert(0, '/app/RigNet')
18
 
19
  # Import RigNet modules
 
 
 
 
 
 
20
  from quick_start import (
21
- create_single_data,
22
  predict_joints,
23
  predict_skeleton,
24
  predict_skinning,
25
  normalize_obj
26
  )
27
- from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET
28
- from models.ROOT_GCN import ROOTNET
29
- from models.PairCls_GCN import PairCls as BONENET
30
- from models.SKINNING import SKINNET
31
 
32
  # Global variables for models
33
  device = torch.device("cpu")
@@ -112,11 +203,12 @@ def process_mesh(input_obj_path, bandwidth, threshold, downsample_skinning=True)
112
  shutil.copy(input_obj_path, mesh_filename)
113
 
114
  print(f"\nProcessing: {base_name}")
 
115
 
116
- # Step 1: Create data
117
  print(" [1/4] Creating input data...")
118
  data, vox, surface_geodesic, translation_normalize, scale_normalize = \
119
- create_single_data(mesh_filename)
120
  data.to(device)
121
 
122
  # Step 2: Predict joints
@@ -162,73 +254,32 @@ def process_mesh(input_obj_path, bandwidth, threshold, downsample_skinning=True)
162
 
163
  def rignet_inference(input_obj, bandwidth, threshold):
164
  """
165
- Gradio inference function with extensive debugging
166
  """
167
  print("\n" + "="*60)
168
- print("🔍 DEBUG: rignet_inference CALLED!")
169
  print(f" input_obj type: {type(input_obj)}")
170
- print(f" input_obj value: {input_obj}")
171
- print(f" bandwidth: {bandwidth}")
172
- print(f" threshold: {threshold}")
173
 
174
- # Check if input is None or empty
175
  if input_obj is None:
176
- msg = "⚠️ Please upload an OBJ file first"
177
- print(f" ERROR: {msg}")
178
- print("="*60 + "\n")
179
- return None, msg
180
 
181
  try:
182
- # Ensure models are loaded
183
  load_models()
184
 
185
- # Extract file path - handle multiple Gradio formats
186
  input_path = None
187
-
188
- # Case 1: File object with .name attribute
189
  if hasattr(input_obj, 'name'):
190
  input_path = input_obj.name
191
- print(f" ✓ Got path from .name: {input_path}")
192
-
193
- # Case 2: Already a string path
194
  elif isinstance(input_obj, str):
195
  input_path = input_obj
196
- print(f" ✓ Already a string path: {input_path}")
197
-
198
- # Case 3: Dictionary with 'name' key
199
- elif isinstance(input_obj, dict):
200
- if 'name' in input_obj:
201
- input_path = input_obj['name']
202
- print(f" ✓ Got path from dict['name']: {input_path}")
203
- else:
204
- print(f" ERROR: Dict without 'name' key. Keys: {input_obj.keys()}")
205
-
206
- # Case 4: Unknown type - debug it
207
- else:
208
- print(f" ERROR: Unknown input type!")
209
- print(f" Attributes: {dir(input_obj)}")
210
- if hasattr(input_obj, '__dict__'):
211
- print(f" __dict__: {input_obj.__dict__}")
212
- msg = f"❌ Unexpected file input type: {type(input_obj)}"
213
- print("="*60 + "\n")
214
- return None, msg
215
-
216
- # Validate file path
217
- if not input_path:
218
- msg = "❌ Could not extract file path from input"
219
- print(f" ERROR: {msg}")
220
- print("="*60 + "\n")
221
- return None, msg
222
 
223
- if not os.path.exists(input_path):
224
- msg = f"❌ File does not exist: {input_path}"
225
- print(f" ERROR: {msg}")
226
- print("="*60 + "\n")
227
- return None, msg
228
 
229
- file_size = os.path.getsize(input_path)
230
- print(f" File validated: {input_path}")
231
- print(f" ✓ File size: {file_size:,} bytes")
232
  print("="*60 + "\n")
233
 
234
  # Process the mesh
@@ -239,147 +290,46 @@ def rignet_inference(input_obj, bandwidth, threshold):
239
  downsample_skinning=True
240
  )
241
 
242
- # Validate output
243
  if not os.path.exists(output_rig_path):
244
- msg = "❌ Output file was not created"
245
- print(f"ERROR: {msg}")
246
- return None, msg
247
 
248
  output_size = os.path.getsize(output_rig_path)
249
  status_msg = f"✅ Rigging completed!\n\nFile: {os.path.basename(output_rig_path)}\nSize: {output_size:,} bytes"
250
 
251
- print(f"✓ SUCCESS! Returning output file")
252
  return output_rig_path, status_msg
253
 
254
  except Exception as e:
255
- error_msg = f"❌ Error during processing:\n\n{str(e)}\n\nDetails:\n{traceback.format_exc()}"
256
- print("\n" + "="*60)
257
- print("❌ EXCEPTION CAUGHT:")
258
  print(error_msg)
259
- print("="*60 + "\n")
260
  return None, error_msg
261
- def process_obj_file(file_obj):
262
- """
263
- Process OBJ file and return first 10 lines of analysis results
264
- """
265
- sys.stdout.flush()
266
- print(f"[DEBUG] Processing file: {file_obj.name if file_obj else 'None'}", flush=True)
267
-
268
- if not file_obj:
269
- return "⚠️ No file provided"
270
-
271
- try:
272
- results = []
273
- results.append("="*60)
274
- results.append("OBJ FILE ANALYSIS - First 10 Lines of Results")
275
- results.append("="*60)
276
-
277
- # Read raw OBJ file first 10 lines
278
- results.append("\n📄 RAW OBJ FILE (First 10 Lines):")
279
- results.append("-"*60)
280
- with open(file_obj.name, 'r') as f:
281
- for i, line in enumerate(f):
282
- if i >= 10:
283
- break
284
- results.append(f"Line {i+1}: {line.rstrip()}")
285
-
286
- # Load mesh using trimesh
287
- results.append("\n🔷 MESH ANALYSIS:")
288
- results.append("-"*60)
289
-
290
- mesh = trimesh.load(file_obj.name, force='mesh')
291
-
292
- # Check if it's a Scene or Mesh
293
- if isinstance(mesh, trimesh.Scene):
294
- results.append(f"Type: Scene with {len(mesh.geometry)} geometries")
295
- # Get the first geometry
296
- if len(mesh.geometry) > 0:
297
- first_geom_name = list(mesh.geometry.keys())[0]
298
- mesh = mesh.geometry[first_geom_name]
299
- results.append(f"Using first geometry: {first_geom_name}")
300
-
301
- # Mesh statistics (ensures we don't exceed 10 total result lines)
302
- results.append(f"Vertices: {len(mesh.vertices)}")
303
- results.append(f"Faces: {len(mesh.faces)}")
304
- results.append(f"Is Watertight: {mesh.is_watertight}")
305
- results.append(f"Is Winding Consistent: {mesh.is_winding_consistent}")
306
- results.append(f"Bounds: {mesh.bounds.tolist()}")
307
- results.append(f"Center Mass: {mesh.center_mass.tolist()}")
308
-
309
- # Join results
310
- output = "\n".join(results[:25]) # Limit output
311
-
312
- print("[DEBUG] Processing completed successfully", flush=True)
313
- return output
314
-
315
- except Exception as e:
316
- error_msg = f"❌ Error processing file: {str(e)}\n\nStacktrace:\n{sys.exc_info()}"
317
- print(error_msg, flush=True)
318
- return error_msg
319
 
320
- # Gradio Interface
321
- # demo = gr.Interface(
322
- # fn=process_obj_file,
323
- # inputs=gr.File(
324
- # label="Upload OBJ File",
325
- # file_types=[".obj"],
326
- # type="file"
327
- # ),
328
- # outputs=gr.Textbox(
329
- # label="Analysis Results (First 10 Lines)",
330
- # lines=20,
331
- # max_lines=30
332
- # ),
333
- # title="🔷 OBJ File Analyzer",
334
- # description="Upload a 3D OBJ file to see the first 10 lines of raw content and mesh analysis",
335
- # examples=None,
336
- # cache_examples=False
337
- # )
338
 
339
  if __name__ == "__main__":
340
- print("="*60, flush=True)
341
- print("🚀 Starting OBJ File Analyzer...", flush=True)
342
- print("="*60, flush=True)
 
343
  load_models()
344
-
345
-
346
  demo = gr.Interface(
347
  fn=rignet_inference,
348
  inputs=[
349
  gr.File(label="Upload OBJ File", file_types=[".obj"], type="file"),
350
- gr.Slider(0.02, 0.08, value=0.04, step=0.001, label="Bandwidth", info="Joint clustering density (default: 0.04)"),
351
- gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Threshold (×10⁻⁵)", info="Joint filtering threshold (default: 1.0)")
352
  ],
353
  outputs=[
354
  gr.File(label="Download Rig TXT"),
355
  gr.Textbox(label="Status", lines=5)
356
  ],
357
  title="🎭 RigNet: Neural Rigging for 3D Characters",
358
- description="""
359
- Upload a 3D character mesh (OBJ format) to automatically generate skeletal rig and skinning weights.
360
-
361
- **Recommended:** OBJ files with 1K-5K vertices work best.
362
- **Processing time:** 1-3 minutes on CPU depending on mesh complexity.
363
- """,
364
- article="""
365
- ### 📚 About the Output
366
-
367
- The generated `*_rig.txt` file contains:
368
- - **joints**: 3D positions of skeletal joints
369
- - **root**: Root joint of the hierarchy
370
- - **hier**: Parent-child relationships (skeleton hierarchy)
371
- - **skin**: Skinning weights for each vertex
372
-
373
- This format can be imported into 3D animation software.
374
-
375
- **Reference:** [RigNet: Neural Rigging for Articulated Characters (SIGGRAPH 2020)](https://arxiv.org/abs/2005.00559)
376
- """,
377
  allow_flagging="never"
378
  )
 
379
  demo.launch(
380
  server_name="0.0.0.0",
381
  server_port=7860,
382
- share=False,
383
  show_error=True,
384
  debug=True
385
  )
 
1
  #!/usr/bin/env python3
2
+ """
3
+ RigNet Gradio Demo for Hugging Face Spaces (CPU)
4
+ """
 
 
 
5
  import os
6
  import sys
7
  import tempfile
8
  import shutil
9
  import traceback
10
  from pathlib import Path
11
+
12
+ # CRITICAL: Patch quick_start.py BEFORE importing
13
+ # Fix binvox path issue - must happen before RigNet imports
14
+ import_patch_code = '''
15
+ import os
16
+ from sys import platform
17
+
18
+ # Monkey-patch quick_start.create_single_data to fix binvox paths
19
+ original_create_single_data = None
20
+
21
+ def patched_create_single_data(mesh_filename):
22
+ """Patched version that fixes binvox path issues"""
23
+ import open3d as o3d
24
+ import numpy as np
25
+ import torch
26
+ from torch_geometric.data import Data
27
+ from torch_geometric.utils import add_self_loops
28
+ sys.path.insert(0, '/app/RigNet')
29
+ from utils import binvox_rw
30
+ from gen_dataset import get_tpl_edges, get_geo_edges
31
+ from geometric_proc.common_ops import calc_surface_geodesic
32
+ from quick_start import normalize_obj
33
+
34
+ # Load and normalize mesh
35
+ mesh = o3d.io.read_triangle_mesh(mesh_filename)
36
+ mesh.compute_vertex_normals()
37
+ mesh_v = np.asarray(mesh.vertices)
38
+ mesh_vn = np.asarray(mesh.vertex_normals)
39
+ mesh_f = np.asarray(mesh.triangles)
40
+ mesh_v, translation_normalize, scale_normalize = normalize_obj(mesh_v)
41
+
42
+ # Save normalized mesh
43
+ mesh_normalized = o3d.geometry.TriangleMesh(
44
+ vertices=o3d.utility.Vector3dVector(mesh_v),
45
+ triangles=o3d.utility.Vector3iVector(mesh_f)
46
+ )
47
+ normalized_obj = mesh_filename.replace("_remesh.obj", "_normalized.obj")
48
+ o3d.io.write_triangle_mesh(normalized_obj, mesh_normalized)
49
+
50
+ # Prepare data
51
+ v = np.concatenate((mesh_v, mesh_vn), axis=1)
52
+ v = torch.from_numpy(v).float()
53
+
54
+ # Topology edges
55
+ print(" gathering topological edges.")
56
+ tpl_e = get_tpl_edges(mesh_v, mesh_f).T
57
+ tpl_e = torch.from_numpy(tpl_e).long()
58
+ tpl_e, _ = add_self_loops(tpl_e, num_nodes=v.size(0))
59
+
60
+ # Surface geodesic
61
+ print(" calculating surface geodesic matrix.")
62
+ surface_geodesic = calc_surface_geodesic(mesh)
63
+
64
+ # Geodesic edges
65
+ print(" gathering geodesic edges.")
66
+ geo_e = get_geo_edges(surface_geodesic, mesh_v).T
67
+ geo_e = torch.from_numpy(geo_e).long()
68
+ geo_e, _ = add_self_loops(geo_e, num_nodes=v.size(0))
69
+
70
+ # Batch
71
+ batch = torch.zeros(len(v), dtype=torch.long)
72
+
73
+ # Voxelization - FIX THE PATH ISSUE HERE
74
+ binvox_file = normalized_obj.replace('.obj', '.binvox')
75
+
76
+ if not os.path.exists(binvox_file):
77
+ print(f" Creating voxel file: {binvox_file}")
78
+ # Use full path to binvox
79
+ cmd = f"/usr/local/bin/binvox -d 88 -pb {normalized_obj}"
80
+ print(f" Running: {cmd}")
81
+ ret = os.system(cmd)
82
+ if ret != 0:
83
+ raise RuntimeError(f"binvox failed with return code {ret}")
84
+
85
+ # Verify binvox file was created
86
+ if not os.path.exists(binvox_file):
87
+ raise FileNotFoundError(f"Binvox file not created: {binvox_file}")
88
+
89
+ # Load voxel data
90
+ with open(binvox_file, 'rb') as fvox:
91
+ vox = binvox_rw.read_as_3d_array(fvox)
92
+
93
+ data = Data(x=v[:, 3:6], pos=v[:, 0:3], tpl_edge_index=tpl_e,
94
+ geo_edge_index=geo_e, batch=batch)
95
+
96
+ return data, vox, surface_geodesic, translation_normalize, scale_normalize
97
+ '''
98
+
99
+ # Execute the patch
100
+ exec(import_patch_code)
101
+
102
+ import gradio as gr
103
  import torch
104
+ import numpy as np
105
 
106
  # Add RigNet to Python path
107
  sys.path.insert(0, '/app/RigNet')
108
 
109
  # Import RigNet modules
110
+ from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET
111
+ from models.ROOT_GCN import ROOTNET
112
+ from models.PairCls_GCN import PairCls as BONENET
113
+ from models.SKINNING import SKINNET
114
+
115
+ # Import other functions from quick_start (we'll use our patched version)
116
  from quick_start import (
 
117
  predict_joints,
118
  predict_skeleton,
119
  predict_skinning,
120
  normalize_obj
121
  )
 
 
 
 
122
 
123
  # Global variables for models
124
  device = torch.device("cpu")
 
203
  shutil.copy(input_obj_path, mesh_filename)
204
 
205
  print(f"\nProcessing: {base_name}")
206
+ print(f"Working directory: {work_dir}")
207
 
208
+ # Step 1: Create data - USE PATCHED VERSION
209
  print(" [1/4] Creating input data...")
210
  data, vox, surface_geodesic, translation_normalize, scale_normalize = \
211
+ patched_create_single_data(mesh_filename)
212
  data.to(device)
213
 
214
  # Step 2: Predict joints
 
254
 
255
  def rignet_inference(input_obj, bandwidth, threshold):
256
  """
257
+ Gradio inference function
258
  """
259
  print("\n" + "="*60)
260
+ print("🔍 rignet_inference CALLED!")
261
  print(f" input_obj type: {type(input_obj)}")
 
 
 
262
 
 
263
  if input_obj is None:
264
+ return None, "⚠️ Please upload an OBJ file first"
 
 
 
265
 
266
  try:
 
267
  load_models()
268
 
269
+ # Extract file path
270
  input_path = None
 
 
271
  if hasattr(input_obj, 'name'):
272
  input_path = input_obj.name
 
 
 
273
  elif isinstance(input_obj, str):
274
  input_path = input_obj
275
+ elif isinstance(input_obj, dict) and 'name' in input_obj:
276
+ input_path = input_obj['name']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
+ if not input_path or not os.path.exists(input_path):
279
+ return None, f"❌ Invalid file path: {input_path}"
 
 
 
280
 
281
+ print(f" Processing: {input_path}")
282
+ print(f" File size: {os.path.getsize(input_path):,} bytes")
 
283
  print("="*60 + "\n")
284
 
285
  # Process the mesh
 
290
  downsample_skinning=True
291
  )
292
 
 
293
  if not os.path.exists(output_rig_path):
294
+ return None, "❌ Output file was not created"
 
 
295
 
296
  output_size = os.path.getsize(output_rig_path)
297
  status_msg = f"✅ Rigging completed!\n\nFile: {os.path.basename(output_rig_path)}\nSize: {output_size:,} bytes"
298
 
 
299
  return output_rig_path, status_msg
300
 
301
  except Exception as e:
302
+ error_msg = f"❌ Error:\n\n{str(e)}\n\n{traceback.format_exc()}"
 
 
303
  print(error_msg)
 
304
  return None, error_msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  if __name__ == "__main__":
308
+ print("="*60)
309
+ print("RigNet Gradio Demo - Starting...")
310
+ print("="*60)
311
+
312
  load_models()
313
+
 
314
  demo = gr.Interface(
315
  fn=rignet_inference,
316
  inputs=[
317
  gr.File(label="Upload OBJ File", file_types=[".obj"], type="file"),
318
+ gr.Slider(0.02, 0.08, value=0.04, step=0.001, label="Bandwidth"),
319
+ gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Threshold (×10⁻⁵)")
320
  ],
321
  outputs=[
322
  gr.File(label="Download Rig TXT"),
323
  gr.Textbox(label="Status", lines=5)
324
  ],
325
  title="🎭 RigNet: Neural Rigging for 3D Characters",
326
+ description="Upload OBJ (1K-5K vertices recommended). Processing takes 1-3 minutes on CPU.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  allow_flagging="never"
328
  )
329
+
330
  demo.launch(
331
  server_name="0.0.0.0",
332
  server_port=7860,
 
333
  show_error=True,
334
  debug=True
335
  )