ckc99u commited on
Commit
dd0269a
·
verified ·
1 Parent(s): 23ce2fb

Update RigNet/quick_start.py

Browse files
Files changed (1) hide show
  1. RigNet/quick_start.py +22 -11
RigNet/quick_start.py CHANGED
@@ -110,18 +110,26 @@ def create_single_data(mesh_filename):
110
  # Create voxel grid
111
  voxel_grid = mesh_tri.voxelized(pitch=pitch)
112
 
113
- # Ensure it's exactly 88x88x88 by padding/cropping
114
  vox_matrix = voxel_grid.matrix
115
  current_shape = vox_matrix.shape
116
 
117
- # Pad to 88x88x88 if smaller
118
- if any(s < 88 for s in current_shape):
119
- padded = np.zeros((88, 88, 88), dtype=bool)
120
- padded[:current_shape[0], :current_shape[1], :current_shape[2]] = vox_matrix
121
- vox_matrix = padded
122
- # Crop to 88x88x88 if larger
123
- elif any(s > 88 for s in current_shape):
124
- vox_matrix = vox_matrix[:88, :88, :88]
 
 
 
 
 
 
 
 
125
 
126
  # Create binvox-compatible object with ALL required attributes
127
  class Voxels:
@@ -130,14 +138,14 @@ def create_single_data(mesh_filename):
130
  self.dims = dims
131
  self.translate = translate
132
  self.scale = scale
133
- self.axis_order = axis_order # Required by binvox_rw
134
 
135
  vox_obj = Voxels(
136
  data=vox_matrix,
137
  dims=[88, 88, 88],
138
  translate=[0.0, 0.0, 0.0],
139
  scale=1.0,
140
- axis_order='xyz' # Standard axis order
141
  )
142
 
143
  # Save as binvox format for caching
@@ -148,6 +156,8 @@ def create_single_data(mesh_filename):
148
 
149
  except Exception as e:
150
  print(f" ERROR: Trimesh voxelization failed: {e}")
 
 
151
  raise Exception(f"Voxelization failed: {e}")
152
 
153
  # Load voxel data
@@ -157,6 +167,7 @@ def create_single_data(mesh_filename):
157
  data = Data(x=v[:, 3:6], pos=v[:, 0:3], tpl_edge_index=tpl_e, geo_edge_index=geo_e, batch=batch)
158
  return data, vox, surface_geodesic, translation_normalize, scale_normalize
159
 
 
160
  # def create_single_data(mesh_filename):
161
  # """
162
  # create input data for the network. The data is wrapped by Data structure in pytorch-geometric library
 
110
  # Create voxel grid
111
  voxel_grid = mesh_tri.voxelized(pitch=pitch)
112
 
113
+ # Get current voxel matrix
114
  vox_matrix = voxel_grid.matrix
115
  current_shape = vox_matrix.shape
116
 
117
+ print(f" Original voxel shape: {current_shape}")
118
+
119
+ # Resize to exactly 88x88x88 by padding/cropping each dimension
120
+ target_shape = (88, 88, 88)
121
+ resized = np.zeros(target_shape, dtype=bool)
122
+
123
+ # Calculate how much to copy in each dimension
124
+ x_size = min(current_shape[0], target_shape[0])
125
+ y_size = min(current_shape[1], target_shape[1])
126
+ z_size = min(current_shape[2], target_shape[2])
127
+
128
+ # Copy the overlapping region
129
+ resized[:x_size, :y_size, :z_size] = vox_matrix[:x_size, :y_size, :z_size]
130
+
131
+ vox_matrix = resized
132
+ print(f" Resized voxel shape: {vox_matrix.shape}")
133
 
134
  # Create binvox-compatible object with ALL required attributes
135
  class Voxels:
 
138
  self.dims = dims
139
  self.translate = translate
140
  self.scale = scale
141
+ self.axis_order = axis_order
142
 
143
  vox_obj = Voxels(
144
  data=vox_matrix,
145
  dims=[88, 88, 88],
146
  translate=[0.0, 0.0, 0.0],
147
  scale=1.0,
148
+ axis_order='xyz'
149
  )
150
 
151
  # Save as binvox format for caching
 
156
 
157
  except Exception as e:
158
  print(f" ERROR: Trimesh voxelization failed: {e}")
159
+ import traceback
160
+ traceback.print_exc()
161
  raise Exception(f"Voxelization failed: {e}")
162
 
163
  # Load voxel data
 
167
  data = Data(x=v[:, 3:6], pos=v[:, 0:3], tpl_edge_index=tpl_e, geo_edge_index=geo_e, batch=batch)
168
  return data, vox, surface_geodesic, translation_normalize, scale_normalize
169
 
170
+
171
  # def create_single_data(mesh_filename):
172
  # """
173
  # create input data for the network. The data is wrapped by Data structure in pytorch-geometric library