File size: 12,939 Bytes
91ba680
fab76dd
 
 
 
 
 
36003a9
 
 
 
 
 
 
 
 
 
 
 
 
fab76dd
36003a9
 
 
 
 
fab76dd
 
 
 
36003a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fab76dd
36003a9
 
 
 
 
 
 
 
 
 
fab76dd
36003a9
 
 
 
 
 
 
 
 
 
fab76dd
36003a9
 
 
 
 
 
 
 
 
 
fab76dd
36003a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fab76dd
36003a9
 
fab76dd
36003a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fab76dd
36003a9
 
 
 
 
 
 
 
c0eb7ea
36003a9
 
fab76dd
36003a9
 
fab76dd
36003a9
fab76dd
 
 
36003a9
fab76dd
36003a9
fab76dd
 
 
 
36003a9
 
fab76dd
36003a9
 
fab76dd
36003a9
fab76dd
 
36003a9
 
fab76dd
 
 
36003a9
 
fab76dd
36003a9
fab76dd
 
 
 
 
 
 
36003a9
fab76dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36003a9
 
 
 
 
 
 
 
 
 
fab76dd
36003a9
fab76dd
 
 
36003a9
 
fab76dd
36003a9
fab76dd
36003a9
 
 
fab76dd
 
 
36003a9
fab76dd
36003a9
fab76dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d429735
fab76dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a38ff0f
9db0ebd
fab76dd
 
 
0fb16ad
fab76dd
 
0fb16ad
 
 
 
fab76dd
 
0fb16ad
 
 
 
 
fab76dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fb16ad
 
91ba680
 
 
fab76dd
2addbd2
0e45a5a
fab76dd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
#!/usr/bin/env python3

import sys
import os
import gradio as gr
import trimesh
import numpy as np
import os
import sys
import tempfile
import shutil
import traceback
from pathlib import Path
import torch

# Add RigNet to Python path
sys.path.insert(0, '/app/RigNet')

# Import RigNet modules
from quick_start import (
    create_single_data,
    predict_joints,
    predict_skeleton,
    predict_skinning,
    normalize_obj
)
from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET
from models.ROOT_GCN import ROOTNET
from models.PairCls_GCN import PairCls as BONENET
from models.SKINNING import SKINNET

# Module-level model state. Inference is pinned to the CPU; the four network
# slots start empty and load_models() populates them and sets models_loaded.
device = torch.device("cpu")
models_loaded = False
jointNet = rootNet = boneNet = skinNet = None


def _load_network(net, checkpoint_path):
    """Load ``checkpoint_path['state_dict']`` into ``net``, move it to ``device`` and set eval mode."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['state_dict'])
    net.to(device)
    net.eval()
    return net


def load_models():
    """Load all four RigNet networks into the module globals, once.

    Populates ``jointNet``, ``rootNet``, ``boneNet`` and ``skinNet`` from the
    ``model_best.pth.tar`` checkpoints under ``/app/RigNet/checkpoints`` and
    sets ``models_loaded``. Subsequent calls return immediately.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing or
        incompatible checkpoint. In that case no global is modified, so a
        later call retries the full load.
    """
    global jointNet, rootNet, boneNet, skinNet, models_loaded

    if models_loaded:
        return

    print("Loading RigNet models...")
    checkpoint_dir = '/app/RigNet/checkpoints'

    # Build into locals first: if any checkpoint fails to load, the globals
    # keep their previous values instead of ending up half-updated.
    joint_net = _load_network(
        JOINTNET(), f'{checkpoint_dir}/gcn_meanshift/model_best.pth.tar'
    )
    print("✓ Joint prediction network loaded")

    root_net = _load_network(
        ROOTNET(), f'{checkpoint_dir}/rootnet/model_best.pth.tar'
    )
    print("✓ Root prediction network loaded")

    bone_net = _load_network(
        BONENET(), f'{checkpoint_dir}/bonenet/model_best.pth.tar'
    )
    print("✓ Connectivity prediction network loaded")

    skin_net = _load_network(
        SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True),
        f'{checkpoint_dir}/skinnet/model_best.pth.tar',
    )
    print("✓ Skinning prediction network loaded")

    jointNet, rootNet, boneNet, skinNet = joint_net, root_net, bone_net, skin_net
    models_loaded = True
    print("All models loaded successfully!\n")


def process_mesh(input_obj_path, bandwidth, threshold, downsample_skinning=True):
    """Run a single OBJ mesh through the RigNet pipeline and save the rig.

    The input is copied into a fresh temporary directory as
    ``<stem>_remesh.obj``. The downstream RigNet helpers are handed the
    sibling ``<stem>_normalized.obj`` path (presumably written by
    ``create_single_data`` — confirm in quick_start).

    Args:
        input_obj_path: Path to the source OBJ file.
        bandwidth: Bandwidth forwarded to ``predict_joints``.
        threshold: Threshold forwarded to ``predict_joints`` as-is (the
            caller is responsible for any scaling).
        downsample_skinning: Forwarded as ``subsampling`` to
            ``predict_skinning``.

    Returns:
        Path to ``<stem>_rig.txt`` inside the temporary directory. On success
        the directory is deliberately left in place so the caller can serve
        the file.

    Raises:
        Re-raises any exception from the pipeline after logging it and
        removing the temporary directory.
    """
    work_dir = tempfile.mkdtemp(prefix='rignet_')

    try:
        # Copy and rename input file to the "_remesh.obj" name the pipeline expects
        base_name = Path(input_obj_path).stem
        mesh_filename = os.path.join(work_dir, f'{base_name}_remesh.obj')
        normalized_filename = mesh_filename.replace("_remesh.obj", "_normalized.obj")
        shutil.copy(input_obj_path, mesh_filename)

        print(f"\nProcessing: {base_name}")

        # Step 1: Create data
        print("  [1/4] Creating input data...")
        data, vox, surface_geodesic, translation_normalize, scale_normalize = \
            create_single_data(mesh_filename)
        data.to(device)

        # Step 2: Predict joints
        print("  [2/4] Predicting joints...")
        data = predict_joints(
            data, vox, jointNet, threshold,
            bandwidth=bandwidth,
            mesh_filename=normalized_filename
        )
        data.to(device)

        # Step 3: Predict skeleton structure
        print("  [3/4] Predicting skeleton connectivity...")
        pred_skeleton = predict_skeleton(
            data, vox, rootNet, boneNet,
            mesh_filename=normalized_filename
        )

        # Step 4: Predict skinning weights
        print("  [4/4] Predicting skinning weights...")
        pred_rig = predict_skinning(
            data, pred_skeleton, skinNet, surface_geodesic,
            normalized_filename,
            subsampling=downsample_skinning
        )

        # Undo the normalization applied in step 1 so the rig matches the input mesh
        pred_rig.normalize(scale_normalize, -translation_normalize)

        output_rig_path = os.path.join(work_dir, f'{base_name}_rig.txt')
        pred_rig.save(output_rig_path)

        print(f"✓ Successfully generated rig: {base_name}_rig.txt\n")

        return output_rig_path

    except Exception as e:
        print(f"ERROR in process_mesh: {e}")
        traceback.print_exc()
        # A failed run produces nothing the caller can use; don't leak the
        # temporary directory (and the copied mesh) on every error.
        shutil.rmtree(work_dir, ignore_errors=True)
        raise


def rignet_inference(input_obj, bandwidth, threshold):
    """Gradio handler: rig an uploaded OBJ mesh with RigNet.

    Accepts the upload as a path (``str`` / ``os.PathLike``), a dict with a
    ``'path'`` or ``'name'`` entry, or a file wrapper exposing ``.name``.
    Logs verbose debug output to stdout.

    Args:
        input_obj: The uploaded file in one of the shapes above, or ``None``.
        bandwidth: Joint clustering bandwidth, passed to ``process_mesh``.
        threshold: Joint filtering threshold in units of 1e-5 (the UI slider
            scale); multiplied by ``1e-5`` before use.

    Returns:
        ``(rig_path, status_message)`` on success, otherwise
        ``(None, error_message)``. Exceptions are caught and reported in the
        message (with traceback) rather than raised.
    """
    separator = "=" * 60
    print("\n" + separator)
    print("🔍 DEBUG: rignet_inference CALLED!")
    print(f"   input_obj type: {type(input_obj)}")
    print(f"   input_obj value: {input_obj}")
    print(f"   bandwidth: {bandwidth}")
    print(f"   threshold: {threshold}")

    if input_obj is None:
        msg = "⚠️ Please upload an OBJ file first"
        print(f"   ERROR: {msg}")
        print(separator + "\n")
        return None, msg

    try:
        input_path = None

        # Paths first: pathlib.Path also has a ``.name`` attribute, but that
        # is only the basename, so it must not reach the ``.name`` branch.
        if isinstance(input_obj, (str, os.PathLike)):
            input_path = os.fspath(input_obj)
            print(f"   ✓ Already a string path: {input_path}")

        elif isinstance(input_obj, dict):
            # NOTE(review): 'path' is preferred over 'name' on the assumption
            # that newer Gradio payloads carry the server path there — confirm
            # against the installed Gradio version.
            for key in ('path', 'name'):
                if input_obj.get(key):
                    input_path = input_obj[key]
                    print(f"   ✓ Got path from dict['{key}']: {input_path}")
                    break
            else:
                print(f"   ERROR: Dict without 'path'/'name' key. Keys: {input_obj.keys()}")

        elif hasattr(input_obj, 'name'):
            input_path = input_obj.name
            print(f"   ✓ Got path from .name: {input_path}")

        else:
            print("   ERROR: Unknown input type!")
            print(f"   Attributes: {dir(input_obj)}")
            if hasattr(input_obj, '__dict__'):
                print(f"   __dict__: {input_obj.__dict__}")
            msg = f"❌ Unexpected file input type: {type(input_obj)}"
            print(separator + "\n")
            return None, msg

        if not input_path:
            msg = "❌ Could not extract file path from input"
            print(f"   ERROR: {msg}")
            print(separator + "\n")
            return None, msg

        if not os.path.exists(input_path):
            msg = f"❌ File does not exist: {input_path}"
            print(f"   ERROR: {msg}")
            print(separator + "\n")
            return None, msg

        file_size = os.path.getsize(input_path)
        print(f"   ✓ File validated: {input_path}")
        print(f"   ✓ File size: {file_size:,} bytes")
        print(separator + "\n")

        # Load models only once the input is known to be usable, so a bad
        # upload is rejected without first paying for checkpoint loading.
        load_models()

        output_rig_path = process_mesh(
            input_path,
            bandwidth=bandwidth,
            threshold=threshold * 1e-5,
            downsample_skinning=True
        )

        if not os.path.exists(output_rig_path):
            msg = "❌ Output file was not created"
            print(f"ERROR: {msg}")
            return None, msg

        output_size = os.path.getsize(output_rig_path)
        status_msg = f"✅ Rigging completed!\n\nFile: {os.path.basename(output_rig_path)}\nSize: {output_size:,} bytes"

        print("✓ SUCCESS! Returning output file")
        return output_rig_path, status_msg

    except Exception as e:
        error_msg = f"❌ Error during processing:\n\n{str(e)}\n\nDetails:\n{traceback.format_exc()}"
        print("\n" + separator)
        print("❌ EXCEPTION CAUGHT:")
        print(error_msg)
        print(separator + "\n")
        return None, error_msg


def process_obj_file(file_obj):
    """Build a short text report for an uploaded OBJ file.

    The report contains the first 10 raw lines of the file followed by mesh
    statistics computed with trimesh (vertex/face counts, watertightness,
    winding consistency, bounds, center of mass).

    Args:
        file_obj: The upload — a path (``str`` / ``os.PathLike``) or a file
            wrapper exposing ``.name``. Falsy means nothing was uploaded.

    Returns:
        The report (at most 25 lines), ``"⚠️ No file provided"`` for a falsy
        input, or an error message including the formatted traceback if
        reading or analysing the file fails.
    """
    sys.stdout.flush()

    # A plain str path has no ``.name``; resolve every accepted shape to a path.
    if not file_obj:
        path = None
    elif isinstance(file_obj, (str, os.PathLike)):
        path = os.fspath(file_obj)
    else:
        path = file_obj.name
    print(f"[DEBUG] Processing file: {path}", flush=True)

    if not file_obj:
        return "⚠️ No file provided"

    try:
        separator = "=" * 60
        rule = "-" * 60
        results = [
            separator,
            "OBJ FILE ANALYSIS - First 10 Lines of Results",
            separator,
        ]

        # Raw OBJ file, first 10 lines
        results.append("\n📄 RAW OBJ FILE (First 10 Lines):")
        results.append(rule)
        # errors='replace': comments/material names in OBJ files are not
        # guaranteed to decode cleanly; don't abort the whole report for it.
        with open(path, 'r', errors='replace') as f:
            for i, line in enumerate(f):
                if i >= 10:
                    break
                results.append(f"Line {i+1}: {line.rstrip()}")

        results.append("\n🔷 MESH ANALYSIS:")
        results.append(rule)

        mesh = trimesh.load(path, force='mesh')

        # Fallback in case a Scene is returned anyway: analyse its first geometry.
        if isinstance(mesh, trimesh.Scene):
            results.append(f"Type: Scene with {len(mesh.geometry)} geometries")
            if len(mesh.geometry) > 0:
                first_geom_name = list(mesh.geometry.keys())[0]
                mesh = mesh.geometry[first_geom_name]
                results.append(f"Using first geometry: {first_geom_name}")

        # Mesh statistics
        results.append(f"Vertices: {len(mesh.vertices)}")
        results.append(f"Faces: {len(mesh.faces)}")
        results.append(f"Is Watertight: {mesh.is_watertight}")
        results.append(f"Is Winding Consistent: {mesh.is_winding_consistent}")
        results.append(f"Bounds: {mesh.bounds.tolist()}")
        results.append(f"Center Mass: {mesh.center_mass.tolist()}")

        output = "\n".join(results[:25])  # Limit output

        print("[DEBUG] Processing completed successfully", flush=True)
        return output

    except Exception as e:
        error_msg = f"❌ Error processing file: {e}\n\nStacktrace:\n{traceback.format_exc()}"
        print(error_msg, flush=True)
        return error_msg

# Gradio Interface
# demo = gr.Interface(
#     fn=process_obj_file,
#     inputs=gr.File(
#         label="Upload OBJ File",
#         file_types=[".obj"],
#         type="file"
#     ),
#     outputs=gr.Textbox(
#         label="Analysis Results (First 10 Lines)",
#         lines=20,
#         max_lines=30
#     ),
#     title="🔷 OBJ File Analyzer",
#     description="Upload a 3D OBJ file to see the first 10 lines of raw content and mesh analysis",
#     examples=None,
#     cache_examples=False
# )

if __name__ == "__main__":
    banner = "=" * 60
    print(banner, flush=True)
    print("🚀 Starting RigNet Neural Rigging app...", flush=True)
    print(banner, flush=True)
    # Load checkpoints eagerly so the first request does not pay for it.
    load_models()

    # NOTE(review): Gradio 4 removed type="file" for gr.File (use "filepath").
    # rignet_inference already accepts plain paths, so switching is safe when
    # upgrading Gradio — confirm against the installed version.
    demo = gr.Interface(
        fn=rignet_inference,
        inputs=[
            gr.File(label="Upload OBJ File", file_types=[".obj"], type="file"),
            gr.Slider(0.02, 0.08, value=0.04, step=0.001, label="Bandwidth", info="Joint clustering density (default: 0.04)"),
            gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Threshold (×10⁻⁵)", info="Joint filtering threshold (default: 1.0)")
        ],
        outputs=[
            gr.File(label="Download Rig TXT"),
            gr.Textbox(label="Status", lines=5)
        ],
        title="🎭 RigNet: Neural Rigging for 3D Characters",
        description="""
Upload a 3D character mesh (OBJ format) to automatically generate skeletal rig and skinning weights.

**Recommended:** OBJ files with 1K-5K vertices work best.
**Processing time:** 1-3 minutes on CPU depending on mesh complexity.
        """,
        article="""
### 📚 About the Output

The generated `*_rig.txt` file contains:
- **joints**: 3D positions of skeletal joints
- **root**: Root joint of the hierarchy  
- **hier**: Parent-child relationships (skeleton hierarchy)
- **skin**: Skinning weights for each vertex

This format can be imported into 3D animation software.

**Reference:** [RigNet: Neural Rigging for Articulated Characters (SIGGRAPH 2020)](https://arxiv.org/abs/2005.00559)
        """,
        allow_flagging="never"
    )
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        debug=True
    )