|
|
|
|
|
|
|
|
import sys |
|
|
import os |
|
|
import gradio as gr |
|
|
import trimesh |
|
|
import numpy as np |
|
|
import os |
|
|
import sys |
|
|
import tempfile |
|
|
import shutil |
|
|
import traceback |
|
|
from pathlib import Path |
|
|
import torch |
|
|
|
|
|
|
|
|
# Make the vendored RigNet checkout importable (provides quick_start and models.*).
sys.path.insert(0, '/app/RigNet')
|
|
|
|
|
|
|
|
from quick_start import ( |
|
|
create_single_data, |
|
|
predict_joints, |
|
|
predict_skeleton, |
|
|
predict_skinning, |
|
|
normalize_obj |
|
|
) |
|
|
from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET |
|
|
from models.ROOT_GCN import ROOTNET |
|
|
from models.PairCls_GCN import PairCls as BONENET |
|
|
from models.SKINNING import SKINNET |
|
|
|
|
|
|
|
|
# All inference runs on CPU in this deployment.
device = torch.device("cpu")

# Guard so load_models() performs the expensive checkpoint loading only once.
models_loaded = False

# Model singletons, populated by load_models():
jointNet = None  # joint position prediction (JOINTNET_MASKNET_MEANSHIFT)
rootNet = None  # root joint selection network
boneNet = None  # bone connectivity (pair classification) network
skinNet = None  # skinning weight prediction network
|
|
|
|
|
|
|
|
def load_models():
    """Load all RigNet models once at startup.

    Populates the module-level ``jointNet``/``rootNet``/``boneNet``/``skinNet``
    singletons from checkpoints under ``/app/RigNet/checkpoints`` and sets
    ``models_loaded`` so repeated calls are no-ops.
    """
    global jointNet, rootNet, boneNet, skinNet, models_loaded

    if models_loaded:
        return

    print("Loading RigNet models...")
    checkpoint_dir = '/app/RigNet/checkpoints'

    jointNet = _load_net(
        JOINTNET(), f'{checkpoint_dir}/gcn_meanshift/model_best.pth.tar')
    print("✓ Joint prediction network loaded")

    rootNet = _load_net(
        ROOTNET(), f'{checkpoint_dir}/rootnet/model_best.pth.tar')
    print("✓ Root prediction network loaded")

    boneNet = _load_net(
        BONENET(), f'{checkpoint_dir}/bonenet/model_best.pth.tar')
    print("✓ Connectivity prediction network loaded")

    skinNet = _load_net(
        SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True),
        f'{checkpoint_dir}/skinnet/model_best.pth.tar')
    print("✓ Skinning prediction network loaded")

    models_loaded = True
    print("All models loaded successfully!\n")


def _load_net(net, checkpoint_path):
    """Restore *net* from *checkpoint_path*, move it to `device`, set eval mode.

    NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    checkpoint files (these ship inside the container image).
    """
    # map_location=device keeps GPU-trained checkpoints loadable on CPU.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['state_dict'])
    net.to(device)
    net.eval()
    return net
|
|
|
|
|
|
|
|
def process_mesh(input_obj_path, bandwidth, threshold, downsample_skinning=True):
    """
    Process a single mesh through the RigNet pipeline.

    Args:
        input_obj_path: Path to the input OBJ mesh.
        bandwidth: Mean-shift bandwidth controlling joint clustering density.
        threshold: Joint filtering threshold (already scaled by the caller).
        downsample_skinning: Whether to subsample during skinning prediction.

    Returns:
        Path to the generated ``<name>_rig.txt`` inside a temp work dir.

    Raises:
        Re-raises any exception from the pipeline after logging it.
    """
    global jointNet, rootNet, boneNet, skinNet

    # NOTE(review): the temp dir is never removed; the returned rig file lives
    # inside it, so cleanup must happen after the caller has consumed it.
    work_dir = tempfile.mkdtemp(prefix='rignet_')

    try:
        base_name = Path(input_obj_path).stem
        # quick_start expects the "*_remesh.obj" filename convention.
        mesh_filename = os.path.join(work_dir, f'{base_name}_remesh.obj')
        shutil.copy(input_obj_path, mesh_filename)
        # Normalized mesh path used by the prediction steps — presumably
        # written by create_single_data; confirm in quick_start.
        normalized_filename = mesh_filename.replace("_remesh.obj", "_normalized.obj")

        print(f"\nProcessing: {base_name}")

        print(" [1/4] Creating input data...")
        data, vox, surface_geodesic, translation_normalize, scale_normalize = \
            create_single_data(mesh_filename)
        data.to(device)

        print(" [2/4] Predicting joints...")
        data = predict_joints(
            data, vox, jointNet, threshold,
            bandwidth=bandwidth,
            mesh_filename=normalized_filename
        )
        data.to(device)

        print(" [3/4] Predicting skeleton connectivity...")
        pred_skeleton = predict_skeleton(
            data, vox, rootNet, boneNet,
            mesh_filename=normalized_filename
        )

        print(" [4/4] Predicting skinning weights...")
        pred_rig = predict_skinning(
            data, pred_skeleton, skinNet, surface_geodesic,
            normalized_filename,
            subsampling=downsample_skinning
        )

        # Undo the normalization applied during data creation so the rig
        # lines up with the original mesh coordinates.
        pred_rig.normalize(scale_normalize, -translation_normalize)

        output_rig_path = os.path.join(work_dir, f'{base_name}_rig.txt')
        pred_rig.save(output_rig_path)

        print(f"✓ Successfully generated rig: {base_name}_rig.txt\n")

        return output_rig_path

    except Exception as e:
        print(f"ERROR in process_mesh: {str(e)}")
        traceback.print_exc()
        # Fix: bare `raise` preserves the original traceback (was `raise e`).
        raise
|
|
|
|
|
|
|
|
def rignet_inference(input_obj, bandwidth, threshold):
    """
    Gradio inference function with extensive debugging.

    Args:
        input_obj: Upload from gr.File. Depending on the Gradio version this
            arrives as a tempfile wrapper (has ``.name``), a plain path
            string, or a dict with a ``name`` key — all are handled.
        bandwidth: Joint clustering bandwidth (slider value).
        threshold: Joint filtering threshold in units of 1e-5; scaled here
            before being passed to process_mesh.

    Returns:
        Tuple of (rig_file_path_or_None, status_message).
    """
    print("\n" + "="*60)
    # NOTE(review): "🔠" looks mojibake-damaged — confirm the intended glyph.
    print("🔠DEBUG: rignet_inference CALLED!")
    print(f" input_obj type: {type(input_obj)}")
    print(f" input_obj value: {input_obj}")
    print(f" bandwidth: {bandwidth}")
    print(f" threshold: {threshold}")

    if input_obj is None:
        msg = "⚠️ Please upload an OBJ file first"  # fixed mojibake ("âš ï¸")
        print(f" ERROR: {msg}")
        print("="*60 + "\n")
        return None, msg

    try:
        load_models()

        input_path = None

        # Gradio changed the gr.File payload type across versions; probe
        # each known shape in turn.
        if hasattr(input_obj, 'name'):
            input_path = input_obj.name
            print(f" ✓ Got path from .name: {input_path}")

        elif isinstance(input_obj, str):
            input_path = input_obj
            print(f" ✓ Already a string path: {input_path}")

        elif isinstance(input_obj, dict):
            if 'name' in input_obj:
                input_path = input_obj['name']
                print(f" ✓ Got path from dict['name']: {input_path}")
            else:
                print(f" ERROR: Dict without 'name' key. Keys: {input_obj.keys()}")

        else:
            print(f" ERROR: Unknown input type!")
            print(f" Attributes: {dir(input_obj)}")
            if hasattr(input_obj, '__dict__'):
                print(f" __dict__: {input_obj.__dict__}")
            msg = f"❌ Unexpected file input type: {type(input_obj)}"  # fixed mojibake
            print("="*60 + "\n")
            return None, msg

        if not input_path:
            msg = "❌ Could not extract file path from input"
            print(f" ERROR: {msg}")
            print("="*60 + "\n")
            return None, msg

        if not os.path.exists(input_path):
            msg = f"❌ File does not exist: {input_path}"
            print(f" ERROR: {msg}")
            print("="*60 + "\n")
            return None, msg

        file_size = os.path.getsize(input_path)
        print(f" ✓ File validated: {input_path}")
        print(f" ✓ File size: {file_size:,} bytes")
        print("="*60 + "\n")

        # Slider is expressed in units of 1e-5 for usability; scale here.
        output_rig_path = process_mesh(
            input_path,
            bandwidth=bandwidth,
            threshold=threshold * 1e-5,
            downsample_skinning=True
        )

        if not os.path.exists(output_rig_path):
            msg = "❌ Output file was not created"
            print(f"ERROR: {msg}")
            return None, msg

        output_size = os.path.getsize(output_rig_path)
        status_msg = f"✅ Rigging completed!\n\nFile: {os.path.basename(output_rig_path)}\nSize: {output_size:,} bytes"

        print(f"✓ SUCCESS! Returning output file")
        return output_rig_path, status_msg

    except Exception as e:
        error_msg = f"❌ Error during processing:\n\n{str(e)}\n\nDetails:\n{traceback.format_exc()}"
        print("\n" + "="*60)
        print("❌ EXCEPTION CAUGHT:")
        print(error_msg)
        print("="*60 + "\n")
        return None, error_msg
|
|
def process_obj_file(file_obj):
    """
    Analyze an uploaded OBJ file and return a short text report.

    The report contains the first 10 raw lines of the file plus basic
    trimesh statistics, truncated to the first 25 report lines.

    Args:
        file_obj: Gradio file wrapper exposing the temp path via ``.name``.

    Returns:
        Report string, or an error message on failure.
    """
    sys.stdout.flush()
    print(f"[DEBUG] Processing file: {file_obj.name if file_obj else 'None'}", flush=True)

    if not file_obj:
        return "⚠️ No file provided"  # fixed mojibake ("âš ï¸")

    try:
        results = []
        results.append("="*60)
        results.append("OBJ FILE ANALYSIS - First 10 Lines of Results")
        results.append("="*60)

        results.append("\n📄 RAW OBJ FILE (First 10 Lines):")
        results.append("-"*60)
        with open(file_obj.name, 'r') as f:
            for i, line in enumerate(f):
                if i >= 10:
                    break
                results.append(f"Line {i+1}: {line.rstrip()}")

        results.append("\n🔷 MESH ANALYSIS:")
        results.append("-"*60)

        mesh = trimesh.load(file_obj.name, force='mesh')

        # force='mesh' can still yield a Scene for multi-geometry files;
        # unwrap the first geometry so the stats below apply to a Trimesh.
        if isinstance(mesh, trimesh.Scene):
            results.append(f"Type: Scene with {len(mesh.geometry)} geometries")

            if len(mesh.geometry) > 0:
                first_geom_name = list(mesh.geometry.keys())[0]
                mesh = mesh.geometry[first_geom_name]
                results.append(f"Using first geometry: {first_geom_name}")

        results.append(f"Vertices: {len(mesh.vertices)}")
        results.append(f"Faces: {len(mesh.faces)}")
        results.append(f"Is Watertight: {mesh.is_watertight}")
        results.append(f"Is Winding Consistent: {mesh.is_winding_consistent}")
        results.append(f"Bounds: {mesh.bounds.tolist()}")
        results.append(f"Center Mass: {mesh.center_mass.tolist()}")

        output = "\n".join(results[:25])

        print("[DEBUG] Processing completed successfully", flush=True)
        return output

    except Exception as e:
        # Fix: sys.exc_info() embedded a raw tuple repr in the message;
        # traceback.format_exc() gives the actual stack trace.
        error_msg = f"❌ Error processing file: {str(e)}\n\nStacktrace:\n{traceback.format_exc()}"
        print(error_msg, flush=True)
        return error_msg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    print("="*60, flush=True)
    print("🚀 Starting OBJ File Analyzer...", flush=True)
    print("="*60, flush=True)
    # Load all networks up front so the first request is not slow.
    load_models()

    demo = gr.Interface(
        fn=rignet_inference,
        inputs=[
            # NOTE(review): type="file" and allow_flagging are Gradio 3.x
            # APIs (removed in 4.x) — confirm the pinned Gradio version.
            gr.File(label="Upload OBJ File", file_types=[".obj"], type="file"),
            gr.Slider(0.02, 0.08, value=0.04, step=0.001, label="Bandwidth", info="Joint clustering density (default: 0.04)"),
            # Label superscript fixed (was mojibake "×10â»âµ").
            gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Threshold (×10⁻⁵)", info="Joint filtering threshold (default: 1.0)")
        ],
        outputs=[
            gr.File(label="Download Rig TXT"),
            gr.Textbox(label="Status", lines=5)
        ],
        title="🎠RigNet: Neural Rigging for 3D Characters",
        description="""
    Upload a 3D character mesh (OBJ format) to automatically generate skeletal rig and skinning weights.

    **Recommended:** OBJ files with 1K-5K vertices work best.
    **Processing time:** 1-3 minutes on CPU depending on mesh complexity.
    """,
        article="""
    ### 📚 About the Output

    The generated `*_rig.txt` file contains:
    - **joints**: 3D positions of skeletal joints
    - **root**: Root joint of the hierarchy
    - **hier**: Parent-child relationships (skeleton hierarchy)
    - **skin**: Skinning weights for each vertex

    This format can be imported into 3D animation software.

    **Reference:** [RigNet: Neural Rigging for Articulated Characters (SIGGRAPH 2020)](https://arxiv.org/abs/2005.00559)
    """,
        allow_flagging="never"
    )
    # Bind on all interfaces for container use; port 7860 is the Gradio default.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        debug=True
    )