ckc99u commited on
Commit
2295ea6
·
verified ·
1 Parent(s): 9879dc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -83
app.py CHANGED
@@ -6,34 +6,50 @@ import torch
6
  import numpy as np
7
  import open3d as o3d
8
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # --- 1. Setup Environment & Paths ---
11
- # Add RigNet submodule to python path so internal imports (e.g., 'from models import...') work
12
  RIGNET_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "RigNet")
13
  if RIGNET_PATH not in sys.path:
14
  sys.path.append(RIGNET_PATH)
15
 
16
- # Ensure binvox executable is available in the root (required by RigNet's os.system call)
17
  BINVOX_SRC = os.path.join(RIGNET_PATH, "binvox")
18
- if platform.system() == "Windows":
19
- BINVOX_SRC += ".exe"
20
- BINVOX_DEST = "binvox.exe"
21
- else:
22
- BINVOX_DEST = "binvox"
23
 
24
  if os.path.exists(BINVOX_SRC):
25
- # Copy to root so ./binvox works
26
  shutil.copy(BINVOX_SRC, BINVOX_DEST)
27
- # Make executable
28
- if platform.system() != "Windows":
29
- os.system(f"chmod +x {BINVOX_DEST}")
30
  else:
31
- print(f"Warning: binvox not found at {BINVOX_SRC}. Inference may fail.")
32
 
33
  # --- 2. Import RigNet Modules ---
34
  try:
 
 
35
  from quick_start import (
36
- create_single_data, predict_joints, predict_skeleton,
 
37
  predict_skinning, tranfer_to_ori_mesh
38
  )
39
  from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET
@@ -41,122 +57,101 @@ try:
41
  from models.PairCls_GCN import PairCls as BONENET
42
  from models.SKINNING import SKINNET
43
  except ImportError as e:
44
- print("Error importing RigNet modules. Ensure the 'RigNet' folder is correctly placed.")
45
- raise e
46
 
47
- # --- 3. Load Models Globally ---
48
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
49
  print(f"Loading RigNet models on {device}...")
50
 
51
  def load_checkpoint(model, filename):
52
- # Checkpoints are located inside RigNet/checkpoints/
53
  filepath = os.path.join(RIGNET_PATH, "checkpoints", filename)
54
- if not os.path.exists(filepath):
55
- raise FileNotFoundError(f"Checkpoint not found: {filepath}")
56
  checkpoint = torch.load(filepath, map_location=device)
57
  model.load_state_dict(checkpoint['state_dict'])
58
  return model
59
 
60
  # Initialize models
61
- jointNet = JOINTNET().to(device)
62
- jointNet.eval()
63
  load_checkpoint(jointNet, 'gcn_meanshift/model_best.pth.tar')
64
 
65
- rootNet = ROOTNET().to(device)
66
- rootNet.eval()
67
  load_checkpoint(rootNet, 'rootnet/model_best.pth.tar')
68
 
69
- boneNet = BONENET().to(device)
70
- boneNet.eval()
71
  load_checkpoint(boneNet, 'bonenet/model_best.pth.tar')
72
 
73
- skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True).to(device)
74
- skinNet.eval()
75
  load_checkpoint(skinNet, 'skinnet/model_best.pth.tar')
76
 
77
- print("Models loaded successfully.")
78
 
79
- # --- 4. Inference Pipeline ---
80
- def rignet_inference(input_mesh_path):
81
- """
82
- Main pipeline adapted from quick_start.py logic.
83
- """
84
- if not input_mesh_path:
85
- return None
 
 
 
 
 
 
 
86
 
87
- # Work in a temp folder derived from the input path
 
 
 
88
  working_dir = os.path.dirname(input_mesh_path)
89
  base_name = os.path.basename(input_mesh_path).replace(".obj", "")
90
-
91
- # Prepare filenames
92
  mesh_filename = os.path.join(working_dir, f"{base_name}_remesh.obj")
93
 
94
  print(f"Processing: {input_mesh_path}")
95
-
96
  try:
97
- # 1. Preprocessing / Decimation
98
  mesh_ori = o3d.io.read_triangle_mesh(input_mesh_path)
99
- if len(np.asarray(mesh_ori.vertices)) == 0:
100
- raise ValueError("Empty mesh uploaded.")
101
-
102
- # Simplify mesh (approx 4k vertices for inference)
103
  mesh_remesh = mesh_ori.simplify_quadric_decimation(4000)
104
  o3d.io.write_triangle_mesh(mesh_filename, mesh_remesh)
105
 
106
- # 2. Create Data (Voxelization + Geodesic calculation)
107
- data, vox, surface_geodesic, translation_normalize, scale_normalize = create_single_data(mesh_filename)
108
  data = data.to(device)
109
 
110
- # 3. Pipeline Predictions
111
- # Default parameters from quick_start
112
- bandwidth = 0.0429
113
- threshold = 1e-5
114
-
115
- mesh_norm_path = mesh_filename.replace("_remesh.obj", "_normalized.obj")
116
 
117
  print("Predicting joints...")
118
- data = predict_joints(
119
- data, vox, jointNet, threshold, bandwidth=bandwidth,
120
- mesh_filename=mesh_norm_path
121
- )
122
  data = data.to(device)
123
 
124
  print("Predicting connectivity...")
125
- pred_skeleton = predict_skeleton(
126
- data, vox, rootNet, boneNet,
127
- mesh_filename=mesh_norm_path
128
- )
129
 
130
  print("Predicting skinning...")
131
- pred_rig = predict_skinning(
132
- data, pred_skeleton, skinNet, surface_geodesic,
133
- mesh_norm_path, subsampling=True
134
- )
135
-
136
- # 4. Post-processing
137
- pred_rig.normalize(scale_normalize, -translation_normalize)
138
- final_rig = tranfer_to_ori_mesh(input_mesh_path, mesh_filename, pred_rig)
139
-
140
- # Save output to a text file
141
- output_path = os.path.join(working_dir, f"{base_name}_rig.txt")
142
- final_rig.save(output_path)
143
 
144
- return output_path
 
 
145
 
146
  except Exception as e:
147
- print(f"Error during inference: {e}")
148
  raise gr.Error(f"Processing failed: {str(e)}")
149
 
150
- # --- 5. Gradio Interface ---
151
- title = "RigNet: Neural Rigging"
152
- description = "Upload an .obj file. The model will generate a skeleton and skinning weights, returned as a text file."
153
-
154
  iface = gr.Interface(
155
  fn=rignet_inference,
156
- inputs=gr.Model3D(label="Input Mesh (.obj)"),
157
- outputs=gr.File(label="Download Rig (.txt)"),
158
- title=title,
159
- description=description
160
  )
161
 
162
  if __name__ == "__main__":
 
6
  import numpy as np
7
  import open3d as o3d
8
  import gradio as gr
9
+ from torch_geometric.data import Data
10
+
11
+ # --- 0. Compatibility Patch for PyG 2.0+ ---
12
+ # RigNet uses older PyG where you could assign arbitrary attributes (data.joints = ...).
13
+ # In PyG 2.0+, 'Data' stores everything in a store. We subclass to allow dot-notation access.
14
+ class RigNetData(Data):
15
+ def __setattr__(self, key, value):
16
+ if key in ['joints', 'pairs', 'pair_attr', 'joints_batch', 'pairs_batch', 'skin_input']:
17
+ self[key] = value
18
+ else:
19
+ super().__setattr__(key, value)
20
+
21
+ def __getattr__(self, key):
22
+ # Fallback to getting from the dictionary if standard attribute fails
23
+ if key in self.keys:
24
+ return self[key]
25
+ return super().__getattr__(key)
26
+
27
+ # Monkey-patch RigNet's create_single_data later to use this class,
28
+ # or just convert the object after creation.
29
 
30
  # --- 1. Setup Environment & Paths ---
 
31
  RIGNET_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "RigNet")
32
  if RIGNET_PATH not in sys.path:
33
  sys.path.append(RIGNET_PATH)
34
 
35
+ # Ensure binvox executable
36
  BINVOX_SRC = os.path.join(RIGNET_PATH, "binvox")
37
+ BINVOX_DEST = "binvox.exe" if platform.system() == "Windows" else "binvox"
38
+ if platform.system() == "Windows": BINVOX_SRC += ".exe"
 
 
 
39
 
40
  if os.path.exists(BINVOX_SRC):
 
41
  shutil.copy(BINVOX_SRC, BINVOX_DEST)
42
+ if platform.system() != "Windows": os.system(f"chmod +x {BINVOX_DEST}")
 
 
43
  else:
44
+ print(f"Warning: binvox not found at {BINVOX_SRC}")
45
 
46
  # --- 2. Import RigNet Modules ---
47
  try:
48
+ # We need to intercept imports to inject our patched Data class if possible,
49
+ # but easier to just wrap the result of create_single_data.
50
  from quick_start import (
51
+ create_single_data as original_create_single_data, # Rename to wrap
52
+ predict_joints, predict_skeleton,
53
  predict_skinning, tranfer_to_ori_mesh
54
  )
55
  from models.GCN import JOINTNET_MASKNET_MEANSHIFT as JOINTNET
 
57
  from models.PairCls_GCN import PairCls as BONENET
58
  from models.SKINNING import SKINNET
59
  except ImportError as e:
60
+ print(f"Error importing RigNet: {e}")
 
61
 
62
+ # --- 3. Load Models ---
63
+ device = torch.device("cpu")
64
  print(f"Loading RigNet models on {device}...")
65
 
66
  def load_checkpoint(model, filename):
 
67
  filepath = os.path.join(RIGNET_PATH, "checkpoints", filename)
68
+ # map_location=device is crucial for CPU-only spaces
 
69
  checkpoint = torch.load(filepath, map_location=device)
70
  model.load_state_dict(checkpoint['state_dict'])
71
  return model
72
 
73
  # Initialize models
74
+ jointNet = JOINTNET().to(device); jointNet.eval()
 
75
  load_checkpoint(jointNet, 'gcn_meanshift/model_best.pth.tar')
76
 
77
+ rootNet = ROOTNET().to(device); rootNet.eval()
 
78
  load_checkpoint(rootNet, 'rootnet/model_best.pth.tar')
79
 
80
+ boneNet = BONENET().to(device); boneNet.eval()
 
81
  load_checkpoint(boneNet, 'bonenet/model_best.pth.tar')
82
 
83
+ skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True).to(device); skinNet.eval()
 
84
  load_checkpoint(skinNet, 'skinnet/model_best.pth.tar')
85
 
86
+ print("Models loaded.")
87
 
88
+ # --- 4. Wrapper to fix Data object ---
89
+ def create_single_data_patched(mesh_filename):
90
+ # Call original function
91
+ data, vox, surf_geo, t_norm, s_norm = original_create_single_data(mesh_filename)
92
+
93
+ # Convert to our flexible class
94
+ new_data = RigNetData(
95
+ x=data.x,
96
+ pos=data.pos,
97
+ tpl_edge_index=data.tpl_edge_index,
98
+ geo_edge_index=data.geo_edge_index,
99
+ batch=data.batch
100
+ )
101
+ return new_data, vox, surf_geo, t_norm, s_norm
102
 
103
+ # --- 5. Inference Pipeline ---
104
+ def rignet_inference(input_mesh_path):
105
+ if not input_mesh_path: return None
106
+
107
  working_dir = os.path.dirname(input_mesh_path)
108
  base_name = os.path.basename(input_mesh_path).replace(".obj", "")
 
 
109
  mesh_filename = os.path.join(working_dir, f"{base_name}_remesh.obj")
110
 
111
  print(f"Processing: {input_mesh_path}")
 
112
  try:
113
+ # 1. Preprocess
114
  mesh_ori = o3d.io.read_triangle_mesh(input_mesh_path)
115
+ if len(np.asarray(mesh_ori.vertices)) == 0: raise ValueError("Empty mesh")
 
 
 
116
  mesh_remesh = mesh_ori.simplify_quadric_decimation(4000)
117
  o3d.io.write_triangle_mesh(mesh_filename, mesh_remesh)
118
 
119
+ # 2. Create Data (Patched)
120
+ data, vox, surface_geodesic, t_norm, s_norm = create_single_data_patched(mesh_filename)
121
  data = data.to(device)
122
 
123
+ # 3. Inference
124
+ mesh_norm = mesh_filename.replace("_remesh.obj", "_normalized.obj")
 
 
 
 
125
 
126
  print("Predicting joints...")
127
+ data = predict_joints(data, vox, jointNet, 1e-5, bandwidth=0.0429, mesh_filename=mesh_norm)
 
 
 
128
  data = data.to(device)
129
 
130
  print("Predicting connectivity...")
131
+ skel = predict_skeleton(data, vox, rootNet, boneNet, mesh_filename=mesh_norm)
 
 
 
132
 
133
  print("Predicting skinning...")
134
+ rig = predict_skinning(data, skel, skinNet, surface_geodesic, mesh_norm, subsampling=True)
135
+
136
+ # 4. Export
137
+ rig.normalize(s_norm, -t_norm)
138
+ final_rig = tranfer_to_ori_mesh(input_mesh_path, mesh_filename, rig)
 
 
 
 
 
 
 
139
 
140
+ out_path = os.path.join(working_dir, f"{base_name}_rig.txt")
141
+ final_rig.save(out_path)
142
+ return out_path
143
 
144
  except Exception as e:
145
+ print(f"Error: {e}")
146
  raise gr.Error(f"Processing failed: {str(e)}")
147
 
148
+ # --- 6. Launch ---
 
 
 
149
  iface = gr.Interface(
150
  fn=rignet_inference,
151
+ inputs=gr.Model3D(label="Input .obj"),
152
+ outputs=gr.File(label="Rig Output .txt"),
153
+ title="RigNet Demo",
154
+ description="Upload a mesh to generate a rig. (Uses PyTorch 2.1 CPU)"
155
  )
156
 
157
  if __name__ == "__main__":