#!/usr/bin/env python3
"""Gradio front-end for RigNet automatic character rigging.

An uploaded OBJ mesh is run through the RigNet joint, connectivity and
skinning networks (on CPU) and the resulting ``*_rig.txt`` file is returned
for download.
"""
import itertools
import os
import shutil
import sys
import tempfile
import traceback
from pathlib import Path

import gradio as gr
import numpy as np  # noqa: F401  (kept: may be relied on elsewhere)
import torch
import trimesh

# Add RigNet to Python path (must happen before the RigNet imports below).
sys.path.insert(0, '/app/RigNet')

# Import RigNet modules
from quick_start import (  # noqa: E402
    create_single_data,
    predict_joints,
    predict_skeleton,
    predict_skinning,
    normalize_obj,  # noqa: F401
)
from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET  # noqa: E402
from models.ROOT_GCN import ROOTNET  # noqa: E402
from models.PairCls_GCN import PairCls as BONENET  # noqa: E402
from models.SKINNING import SKINNET  # noqa: E402

# Each network's checkpoint lives at CHECKPOINT_DIR/<subdir>/CHECKPOINT_FILE.
CHECKPOINT_DIR = '/app/RigNet/checkpoints'
CHECKPOINT_FILE = 'model_best.pth.tar'

_BANNER = "=" * 60

# Global variables for models
device = torch.device("cpu")
models_loaded = False
jointNet = None
rootNet = None
boneNet = None
skinNet = None


def _load_network(net, subdir):
    """Load ``CHECKPOINT_DIR/<subdir>/model_best.pth.tar`` into *net*.

    The checkpoint is read with ``map_location=device`` and its
    ``'state_dict'`` entry is loaded into *net*. The network is then moved to
    ``device`` and put into eval mode.

    Returns:
        The same *net* instance, so it can be assigned directly.
    """
    checkpoint = torch.load(
        os.path.join(CHECKPOINT_DIR, subdir, CHECKPOINT_FILE),
        map_location=device,
    )
    net.load_state_dict(checkpoint['state_dict'])
    net.to(device)
    net.eval()
    return net


def load_models():
    """Load all four RigNet networks once; subsequent calls are no-ops.

    The module-level globals are only assigned after *every* network has
    loaded. A failure part-way through (e.g. a missing checkpoint) therefore
    leaves the module in the "not loaded" state, and the next call retries
    from scratch instead of running with a half-initialised set of models.
    """
    global jointNet, rootNet, boneNet, skinNet, models_loaded
    if models_loaded:
        return

    print("Loading RigNet models...")

    joint_net = _load_network(JOINTNET(), 'gcn_meanshift')
    print("✓ Joint prediction network loaded")

    root_net = _load_network(ROOTNET(), 'rootnet')
    print("✓ Root prediction network loaded")

    bone_net = _load_network(BONENET(), 'bonenet')
    print("✓ Connectivity prediction network loaded")

    skin_net = _load_network(
        SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True), 'skinnet'
    )
    print("✓ Skinning prediction network loaded")

    jointNet, rootNet, boneNet, skinNet = joint_net, root_net, bone_net, skin_net
    models_loaded = True
    print("All models loaded successfully!\n")


def process_mesh(input_obj_path, bandwidth, threshold, downsample_skinning=True):
    """Run the full RigNet pipeline on a single OBJ mesh.

    Args:
        input_obj_path: Path to the input ``.obj`` file.
        bandwidth: Forwarded to ``predict_joints`` as ``bandwidth``.
        threshold: Forwarded to ``predict_joints`` as-is (an absolute value;
            ``rignet_inference`` scales the UI value before calling).
        downsample_skinning: Forwarded to ``predict_skinning`` as ``subsampling``.

    Returns:
        Path of the generated ``<stem>_rig.txt`` inside a fresh temporary
        directory. On success that directory is deliberately kept, because the
        caller hands the rig file back to Gradio for download.

    Raises:
        Re-raises any exception from the RigNet helpers after removing the
        temporary working directory.
    """
    work_dir = tempfile.mkdtemp(prefix='rignet_')
    try:
        base_name = Path(input_obj_path).stem
        # The RigNet helpers derive sibling file names (e.g. the normalized
        # mesh) from the "_remesh.obj" suffix, so the copy must use it.
        mesh_filename = os.path.join(work_dir, f'{base_name}_remesh.obj')
        normalized_filename = mesh_filename.replace("_remesh.obj", "_normalized.obj")
        shutil.copy(input_obj_path, mesh_filename)
        print(f"\nProcessing: {base_name}")

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            print("  [1/4] Creating input data...")
            data, vox, surface_geodesic, translation_normalize, scale_normalize = \
                create_single_data(mesh_filename)
            data = data.to(device)

            print("  [2/4] Predicting joints...")
            data = predict_joints(
                data, vox, jointNet, threshold,
                bandwidth=bandwidth,
                mesh_filename=normalized_filename,
            )
            data = data.to(device)

            print("  [3/4] Predicting skeleton connectivity...")
            pred_skeleton = predict_skeleton(
                data, vox, rootNet, boneNet,
                mesh_filename=normalized_filename,
            )

            print("  [4/4] Predicting skinning weights...")
            pred_rig = predict_skinning(
                data, pred_skeleton, skinNet, surface_geodesic,
                normalized_filename,
                subsampling=downsample_skinning,
            )

        # Undo the normalization so the rig matches the input mesh's frame.
        pred_rig.normalize(scale_normalize, -translation_normalize)

        output_rig_path = os.path.join(work_dir, f'{base_name}_rig.txt')
        pred_rig.save(output_rig_path)
        print(f"✓ Successfully generated rig: {base_name}_rig.txt\n")
        return output_rig_path

    except Exception as e:
        print(f"ERROR in process_mesh: {str(e)}")
        traceback.print_exc()
        # Nothing from a failed run is returned, so don't leak the temp dir.
        shutil.rmtree(work_dir, ignore_errors=True)
        raise


def _resolve_upload_path(file_obj):
    """Extract a filesystem path from a Gradio file-upload value.

    Different Gradio versions pass different shapes: ``str`` /
    ``os.PathLike`` paths, dicts with a ``'name'`` key, or tempfile-like
    objects exposing ``.name``. ``str``/PathLike are tested first because
    ``pathlib.Path.name`` is only the basename, not the full path.
    Diagnostics are printed for every case.

    Returns:
        ``(path, error)``. ``path`` is the extracted path, or ``None`` if
        none was found. ``error`` is a user-facing message when the value's
        type is not recognised at all; otherwise it is ``None``.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        path = os.fspath(file_obj)
        print(f"  ✓ Already a string path: {path}")
        return path, None

    if isinstance(file_obj, dict):
        if 'name' in file_obj:
            print(f"  ✓ Got path from dict['name']: {file_obj['name']}")
            return file_obj['name'], None
        print(f"  ERROR: Dict without 'name' key. Keys: {file_obj.keys()}")
        return None, None

    if hasattr(file_obj, 'name'):
        print(f"  ✓ Got path from .name: {file_obj.name}")
        return file_obj.name, None

    print("  ERROR: Unknown input type!")
    print(f"  Attributes: {dir(file_obj)}")
    if hasattr(file_obj, '__dict__'):
        print(f"  __dict__: {file_obj.__dict__}")
    return None, f"❌ Unexpected file input type: {type(file_obj)}"


def _report_failure(msg, close_banner=True):
    """Log *msg* and return the ``(None, msg)`` pair the Gradio outputs expect."""
    print(f"  ERROR: {msg}")
    if close_banner:
        print(_BANNER + "\n")
    return None, msg


def rignet_inference(input_obj, bandwidth, threshold):
    """Gradio callback: rig an uploaded OBJ mesh.

    Args:
        input_obj: Value from the ``gr.File`` input (its shape depends on the
            Gradio version; see ``_resolve_upload_path``).
        bandwidth: Joint clustering bandwidth from the UI slider.
        threshold: Joint filtering threshold in units of 1e-5, as shown on
            the UI slider; it is scaled by 1e-5 before use.

    Returns:
        ``(rig_path_or_None, status_message)``. Never raises: any error is
        reported in the status message (including the traceback).
    """
    print("\n" + _BANNER)
    print("🔍 DEBUG: rignet_inference CALLED!")
    print(f"  input_obj type: {type(input_obj)}")
    print(f"  input_obj value: {input_obj}")
    print(f"  bandwidth: {bandwidth}")
    print(f"  threshold: {threshold}")

    if input_obj is None:
        return _report_failure("⚠️ Please upload an OBJ file first")

    try:
        load_models()

        input_path, type_error = _resolve_upload_path(input_obj)
        if type_error:
            return _report_failure(type_error)
        if not input_path:
            return _report_failure("❌ Could not extract file path from input")
        # isfile (not exists): a directory path must be rejected here too.
        if not os.path.isfile(input_path):
            return _report_failure(f"❌ File does not exist: {input_path}")

        print(f"  ✓ File validated: {input_path}")
        print(f"  ✓ File size: {os.path.getsize(input_path):,} bytes")
        print(_BANNER + "\n")

        output_rig_path = process_mesh(
            input_path,
            bandwidth=bandwidth,
            threshold=threshold * 1e-5,  # slider value is in units of 1e-5
            downsample_skinning=True,
        )

        if not os.path.exists(output_rig_path):
            return _report_failure("❌ Output file was not created", close_banner=False)

        output_size = os.path.getsize(output_rig_path)
        status_msg = (
            "✅ Rigging completed!\n\n"
            f"File: {os.path.basename(output_rig_path)}\n"
            f"Size: {output_size:,} bytes"
        )
        print("✓ SUCCESS! Returning output file")
        return output_rig_path, status_msg

    except Exception as e:
        error_msg = (
            f"❌ Error during processing:\n\n{str(e)}\n\n"
            f"Details:\n{traceback.format_exc()}"
        )
        print("\n" + _BANNER)
        print("❌ EXCEPTION CAUGHT:")
        print(error_msg)
        print(_BANNER + "\n")
        return None, error_msg


def process_obj_file(file_obj):
    """Return a text report on an OBJ file: its first 10 raw lines plus trimesh stats.

    This is a standalone diagnostic helper; it is not wired into the interface
    launched below. It never raises — errors are returned as a message string.
    """
    sys.stdout.flush()
    if not file_obj:
        print("[DEBUG] Processing file: None", flush=True)
        return "⚠️ No file provided"

    path, type_error = _resolve_upload_path(file_obj)
    print(f"[DEBUG] Processing file: {path}", flush=True)
    if type_error:
        return type_error
    if not path:
        return "❌ Could not extract file path from input"

    try:
        results = [
            "=" * 60,
            "OBJ FILE ANALYSIS - First 10 Lines of Results",
            "=" * 60,
            "\n📄 RAW OBJ FILE (First 10 Lines):",
            "-" * 60,
        ]
        # errors="replace": OBJ files in the wild are not always valid UTF-8.
        with open(path, 'r', encoding='utf-8', errors='replace') as f:
            for i, line in enumerate(itertools.islice(f, 10), start=1):
                results.append(f"Line {i}: {line.rstrip()}")

        results.append("\n🔷 MESH ANALYSIS:")
        results.append("-" * 60)
        mesh = trimesh.load(path, force='mesh')

        # A Scene can still come back (e.g. multi-object files); analyse its first geometry.
        if isinstance(mesh, trimesh.Scene):
            results.append(f"Type: Scene with {len(mesh.geometry)} geometries")
            if not mesh.geometry:
                results.append("No geometry found in file")
                return "\n".join(results)
            first_geom_name = next(iter(mesh.geometry))
            mesh = mesh.geometry[first_geom_name]
            results.append(f"Using first geometry: {first_geom_name}")

        results.extend([
            f"Vertices: {len(mesh.vertices)}",
            f"Faces: {len(mesh.faces)}",
            f"Is Watertight: {mesh.is_watertight}",
            f"Is Winding Consistent: {mesh.is_winding_consistent}",
            f"Bounds: {mesh.bounds.tolist()}",
            f"Center Mass: {mesh.center_mass.tolist()}",
        ])

        output = "\n".join(results[:25])  # hard cap on report length
        print("[DEBUG] Processing completed successfully", flush=True)
        return output

    except Exception as e:
        error_msg = (
            f"❌ Error processing file: {str(e)}\n\n"
            f"Stacktrace:\n{traceback.format_exc()}"
        )
        print(error_msg, flush=True)
        return error_msg


# Markdown shown in the UI. Kept at column 0 so that indentation cannot turn
# lines into Markdown code blocks.
_DESCRIPTION = """
Upload a 3D character mesh (OBJ format) to automatically generate skeletal rig and skinning weights.

**Recommended:** OBJ files with 1K-5K vertices work best.

**Processing time:** 1-3 minutes on CPU depending on mesh complexity.
"""

_ARTICLE = """
### 📚 About the Output

The generated `*_rig.txt` file contains:

- **joints**: 3D positions of skeletal joints
- **root**: Root joint of the hierarchy
- **hier**: Parent-child relationships (skeleton hierarchy)
- **skin**: Skinning weights for each vertex

This format can be imported into 3D animation software.

**Reference:** [RigNet: Neural Rigging for Articulated Characters (SIGGRAPH 2020)](https://arxiv.org/abs/2005.00559)
"""


def _build_interface():
    """Construct (but do not launch) the RigNet Gradio interface."""
    return gr.Interface(
        fn=rignet_inference,
        inputs=[
            # NOTE(review): type="file" is the older Gradio API; newer Gradio
            # releases expect "filepath" — confirm against the pinned version.
            # rignet_inference accepts either form.
            gr.File(label="Upload OBJ File", file_types=[".obj"], type="file"),
            gr.Slider(0.02, 0.08, value=0.04, step=0.001, label="Bandwidth",
                      info="Joint clustering density (default: 0.04)"),
            gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Threshold (×10⁻⁵)",
                      info="Joint filtering threshold (default: 1.0)"),
        ],
        outputs=[
            gr.File(label="Download Rig TXT"),
            gr.Textbox(label="Status", lines=5),
        ],
        title="🎭 RigNet: Neural Rigging for 3D Characters",
        description=_DESCRIPTION,
        article=_ARTICLE,
        allow_flagging="never",
    )


if __name__ == "__main__":
    print(_BANNER, flush=True)
    print("🚀 Starting RigNet web interface...", flush=True)
    print(_BANNER, flush=True)

    # Load once up front so the first request doesn't pay the start-up cost.
    load_models()

    demo = _build_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        debug=True,
    )