ckc99u commited on
Commit
790a803
·
verified ·
1 Parent(s): 94574e5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import shutil
4
+ import platform
5
+ import torch
6
+ import numpy as np
7
+ import open3d as o3d
8
+ import gradio as gr
9
+
10
+ # --- 1. Setup Environment & Paths ---
11
+ RIGNET_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "RigNet")
12
+ if RIGNET_PATH not in sys.path:
13
+ sys.path.append(RIGNET_PATH)
14
+
15
+ # Ensure binvox is executable
16
+ BINVOX_SRC = os.path.join(RIGNET_PATH, "binvox")
17
+ BINVOX_DEST = "binvox.exe" if platform.system() == "Windows" else "binvox"
18
+ if platform.system() == "Windows":
19
+ BINVOX_SRC += ".exe"
20
+
21
+ if os.path.exists(BINVOX_SRC):
22
+ shutil.copy(BINVOX_SRC, BINVOX_DEST)
23
+ if platform.system() != "Windows":
24
+ os.system(f"chmod +x {BINVOX_DEST}")
25
+ else:
26
+ print(f"Warning: binvox not found at {BINVOX_SRC}. Inference may fail.")
27
+
28
+ # --- 2. Import RigNet Modules ---
29
+ try:
30
+ from quick_start import (
31
+ create_single_data, predict_joints, predict_skeleton,
32
+ predict_skinning, tranfer_to_ori_mesh
33
+ )
34
+ from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET
35
+ from models.ROOT_GCN import ROOTNET
36
+ from models.PairCls_GCN import PairCls as BONENET
37
+ from models.SKINNING import SKINNET
38
+ except ImportError as e:
39
+ print(f"Error importing RigNet: {e}")
40
+
41
+ # --- 3. Load Models ---
42
+ device = torch.device("cpu") # Force CPU for this demo environment
43
+ print(f"Loading RigNet models on {device}...")
44
+
45
+ def load_checkpoint(model, filename):
46
+ filepath = os.path.join(RIGNET_PATH, "checkpoints", filename)
47
+ checkpoint = torch.load(filepath, map_location=device)
48
+ model.load_state_dict(checkpoint['state_dict'])
49
+ return model
50
+
51
+ # Initialize models (Global)
52
+ jointNet = JOINTNET().to(device); jointNet.eval()
53
+ load_checkpoint(jointNet, 'gcn_meanshift/model_best.pth.tar')
54
+
55
+ rootNet = ROOTNET().to(device); rootNet.eval()
56
+ load_checkpoint(rootNet, 'rootnet/model_best.pth.tar')
57
+
58
+ boneNet = BONENET().to(device); boneNet.eval()
59
+ load_checkpoint(boneNet, 'bonenet/model_best.pth.tar')
60
+
61
+ skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True).to(device); skinNet.eval()
62
+ load_checkpoint(skinNet, 'skinnet/model_best.pth.tar')
63
+
64
+ print("Models loaded.")
65
+
66
+ # --- 4. Inference Pipeline ---
67
+ def rignet_inference(input_mesh_path):
68
+ if not input_mesh_path: return None
69
+
70
+ working_dir = os.path.dirname(input_mesh_path)
71
+ base_name = os.path.basename(input_mesh_path).replace(".obj", "")
72
+ mesh_filename = os.path.join(working_dir, f"{base_name}_remesh.obj")
73
+
74
+ print(f"Processing: {input_mesh_path}")
75
+ try:
76
+ # 1. Simplify
77
+ mesh_ori = o3d.io.read_triangle_mesh(input_mesh_path)
78
+ if len(np.asarray(mesh_ori.vertices)) == 0: raise ValueError("Empty mesh")
79
+ mesh_remesh = mesh_ori.simplify_quadric_decimation(4000)
80
+ o3d.io.write_triangle_mesh(mesh_filename, mesh_remesh)
81
+
82
+ # 2. Data Prep
83
+ data, vox, surface_geodesic, t_norm, s_norm = create_single_data(mesh_filename)
84
+ data = data.to(device)
85
+
86
+ # 3. Predictions
87
+ mesh_norm = mesh_filename.replace("_remesh.obj", "_normalized.obj")
88
+ data = predict_joints(data, vox, jointNet, 1e-5, bandwidth=0.0429, mesh_filename=mesh_norm)
89
+ data = data.to(device)
90
+
91
+ skel = predict_skeleton(data, vox, rootNet, boneNet, mesh_filename=mesh_norm)
92
+ rig = predict_skinning(data, skel, skinNet, surface_geodesic, mesh_norm, subsampling=True)
93
+
94
+ # 4. Export
95
+ rig.normalize(s_norm, -t_norm)
96
+ final_rig = tranfer_to_ori_mesh(input_mesh_path, mesh_filename, rig)
97
+
98
+ out_path = os.path.join(working_dir, f"{base_name}_rig.txt")
99
+ final_rig.save(out_path)
100
+ return out_path
101
+
102
+ except Exception as e:
103
+ raise gr.Error(f"Error: {str(e)}")
104
+
105
+ # --- 5. Launch ---
106
+ iface = gr.Interface(
107
+ fn=rignet_inference,
108
+ inputs=gr.Model3D(label="Input .obj"),
109
+ outputs=gr.File(label="Rig Output .txt"),
110
+ title="RigNet Demo",
111
+ description="Upload a mesh to generate a rig."
112
+ )
113
+
114
+ if __name__ == "__main__":
115
+ # server_name="0.0.0.0" is CRITICAL for Docker Spaces
116
+ iface.launch(server_name="0.0.0.0", server_port=7860)