Georg committed on
Commit
9550d40
·
1 Parent(s): f592ee6

Implement real FoundationPose inference with model-based pose estimation

Browse files
Files changed (1) hide show
  1. estimator.py +201 -9
estimator.py CHANGED
@@ -11,6 +11,7 @@ from typing import Dict, List, Optional
11
 
12
  import numpy as np
13
  import torch
 
14
 
15
  logger = logging.getLogger(__name__)
16
 
@@ -19,6 +20,18 @@ FOUNDATIONPOSE_ROOT = Path("/app/FoundationPose")
19
  if FOUNDATIONPOSE_ROOT.exists():
20
  sys.path.insert(0, str(FOUNDATIONPOSE_ROOT))
21
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  class FoundationPoseEstimator:
24
  """Wrapper for FoundationPose model."""
@@ -32,8 +45,10 @@ class FoundationPoseEstimator:
32
  """
33
  self.device = device
34
  self.weights_dir = Path(weights_dir)
35
- self.model = None
36
  self.registered_objects = {}
 
 
 
37
 
38
  # Check if FoundationPose is available
39
  if not FOUNDATIONPOSE_ROOT.exists():
@@ -42,11 +57,16 @@ class FoundationPoseEstimator:
42
  "Clone it with: git clone https://github.com/NVlabs/FoundationPose.git"
43
  )
44
 
 
 
 
 
45
  # Check if weights exist
46
  if not self.weights_dir.exists() or not any(self.weights_dir.glob("**/*.pth")):
47
  logger.warning(f"No model weights found in {self.weights_dir}")
48
  logger.warning("Model will not work without weights")
49
 
 
50
  logger.info(f"FoundationPose estimator initialized (device: {device})")
51
 
52
  def register_object(
@@ -68,12 +88,24 @@ class FoundationPoseEstimator:
68
  True if registration successful
69
  """
70
  try:
 
 
 
 
 
 
 
 
 
71
  # Store object registration
72
  self.registered_objects[object_id] = {
73
  "num_references": len(reference_images),
74
  "camera_intrinsics": camera_intrinsics,
75
  "mesh_path": mesh_path,
76
- "reference_images": reference_images # Keep for now
 
 
 
77
  }
78
 
79
  logger.info(f"✓ Registered object '{object_id}' with {len(reference_images)} reference images")
@@ -107,16 +139,176 @@ class FoundationPoseEstimator:
107
  logger.error(f"Object '{object_id}' not registered")
108
  return None
109
 
 
 
 
 
110
  try:
111
- # TODO: Implement actual FoundationPose inference
112
- # This is a placeholder that would need to:
113
- # 1. Load the FoundationPose model if not loaded
114
- # 2. Run pose estimation on the query image
115
- # 3. Return the estimated pose
116
 
117
- logger.warning("FoundationPose inference not yet implemented - returning None")
118
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  except Exception as e:
121
  logger.error(f"Pose estimation failed: {e}", exc_info=True)
 
 
122
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  import numpy as np
13
  import torch
14
+ import cv2
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
20
  if FOUNDATIONPOSE_ROOT.exists():
21
  sys.path.insert(0, str(FOUNDATIONPOSE_ROOT))
22
 
23
+ # Try to import FoundationPose modules
24
+ try:
25
+ from estimater import FoundationPose
26
+ from learning.training.predict_score import ScorePredictor
27
+ from learning.training.predict_pose_refine import PoseRefinePredictor
28
+ import nvdiffrast.torch as dr
29
+ import trimesh
30
+ FOUNDATIONPOSE_AVAILABLE = True
31
+ except ImportError as e:
32
+ logger.warning(f"FoundationPose modules not available: {e}")
33
+ FOUNDATIONPOSE_AVAILABLE = False
34
+
35
 
36
  class FoundationPoseEstimator:
37
  """Wrapper for FoundationPose model."""
 
45
  """
46
  self.device = device
47
  self.weights_dir = Path(weights_dir)
 
48
  self.registered_objects = {}
49
+ self.scorer = None
50
+ self.refiner = None
51
+ self.glctx = None
52
 
53
  # Check if FoundationPose is available
54
  if not FOUNDATIONPOSE_ROOT.exists():
 
57
  "Clone it with: git clone https://github.com/NVlabs/FoundationPose.git"
58
  )
59
 
60
+ if not FOUNDATIONPOSE_AVAILABLE:
61
+ logger.warning("FoundationPose modules not loaded - inference will not work")
62
+ return
63
+
64
  # Check if weights exist
65
  if not self.weights_dir.exists() or not any(self.weights_dir.glob("**/*.pth")):
66
  logger.warning(f"No model weights found in {self.weights_dir}")
67
  logger.warning("Model will not work without weights")
68
 
69
+ # Initialize predictors (lazy loading - only when needed)
70
  logger.info(f"FoundationPose estimator initialized (device: {device})")
71
 
72
  def register_object(
 
88
  True if registration successful
89
  """
90
  try:
91
+ # Load mesh if provided
92
+ mesh = None
93
+ if mesh_path and Path(mesh_path).exists():
94
+ try:
95
+ mesh = trimesh.load(mesh_path)
96
+ logger.info(f"Loaded mesh for '{object_id}' from {mesh_path}")
97
+ except Exception as e:
98
+ logger.warning(f"Failed to load mesh: {e}")
99
+
100
  # Store object registration
101
  self.registered_objects[object_id] = {
102
  "num_references": len(reference_images),
103
  "camera_intrinsics": camera_intrinsics,
104
  "mesh_path": mesh_path,
105
+ "mesh": mesh,
106
+ "reference_images": reference_images,
107
+ "estimator": None, # Will be created lazily
108
+ "pose_last": None # Track last pose for temporal tracking
109
  }
110
 
111
  logger.info(f"✓ Registered object '{object_id}' with {len(reference_images)} reference images")
 
139
  logger.error(f"Object '{object_id}' not registered")
140
  return None
141
 
142
+ if not FOUNDATIONPOSE_AVAILABLE:
143
+ logger.error("FoundationPose not available")
144
+ return None
145
+
146
  try:
147
+ obj_data = self.registered_objects[object_id]
 
 
 
 
148
 
149
+ # Initialize predictors if not done yet
150
+ if self.scorer is None:
151
+ logger.info("Initializing score predictor...")
152
+ self.scorer = ScorePredictor()
153
+ logger.info("Initializing pose refiner...")
154
+ self.refiner = PoseRefinePredictor()
155
+ logger.info("Initializing CUDA rasterizer...")
156
+ self.glctx = dr.RasterizeCudaContext()
157
+
158
+ # Initialize object-specific estimator if not done yet
159
+ if obj_data["estimator"] is None:
160
+ logger.info(f"Creating FoundationPose estimator for '{object_id}'...")
161
+
162
+ mesh = obj_data["mesh"]
163
+ if mesh is not None:
164
+ # Model-based mode: use mesh
165
+ logger.info("Using model-based mode with mesh")
166
+ obj_data["estimator"] = FoundationPose(
167
+ model_pts=mesh.vertices,
168
+ model_normals=mesh.vertex_normals,
169
+ mesh=mesh,
170
+ scorer=self.scorer,
171
+ refiner=self.refiner,
172
+ glctx=self.glctx,
173
+ debug=0
174
+ )
175
+ else:
176
+ # Model-free mode: would need reference-based initialization
177
+ # For now, return error
178
+ logger.error("Model-free mode not yet implemented - mesh required")
179
+ return None
180
+
181
+ estimator = obj_data["estimator"]
182
+
183
+ # Prepare camera intrinsics matrix
184
+ K = self._get_camera_matrix(camera_intrinsics or obj_data["camera_intrinsics"])
185
+ if K is None:
186
+ logger.error("Camera intrinsics required")
187
+ return None
188
+
189
+ # Generate or use depth if not provided
190
+ if depth_image is None:
191
+ # Create dummy depth for model-based case
192
+ depth_image = np.ones((rgb_image.shape[0], rgb_image.shape[1]), dtype=np.float32) * 0.5
193
+
194
+ # Generate mask if not provided
195
+ if mask is None:
196
+ # Use simple foreground detection or full image
197
+ mask = np.ones((rgb_image.shape[0], rgb_image.shape[1]), dtype=bool)
198
+
199
+ # First frame or lost tracking: register
200
+ if obj_data["pose_last"] is None:
201
+ logger.info("Running registration (first frame)...")
202
+ pose = estimator.register(
203
+ K=K,
204
+ rgb=rgb_image,
205
+ depth=depth_image,
206
+ ob_mask=mask,
207
+ iteration=5 # Number of refinement iterations
208
+ )
209
+ else:
210
+ # Subsequent frames: track
211
+ pose = estimator.track_one(
212
+ rgb=rgb_image,
213
+ depth=depth_image,
214
+ K=K,
215
+ iteration=2 # Fewer iterations for tracking
216
+ )
217
+
218
+ # Store pose for next frame
219
+ obj_data["pose_last"] = pose
220
+
221
+ if pose is None:
222
+ logger.warning("Pose estimation returned None")
223
+ return None
224
+
225
+ # Convert pose to our format
226
+ # pose is a 4x4 transformation matrix
227
+ return self._format_pose_output(pose)
228
 
229
  except Exception as e:
230
  logger.error(f"Pose estimation failed: {e}", exc_info=True)
231
+ import traceback
232
+ traceback.print_exc()
233
  return None
234
+
235
+ def _get_camera_matrix(self, intrinsics: Optional[Dict]) -> Optional[np.ndarray]:
236
+ """Convert intrinsics dict to camera matrix."""
237
+ if intrinsics is None:
238
+ return None
239
+
240
+ fx = intrinsics.get("fx")
241
+ fy = intrinsics.get("fy")
242
+ cx = intrinsics.get("cx")
243
+ cy = intrinsics.get("cy")
244
+
245
+ if None in [fx, fy, cx, cy]:
246
+ return None
247
+
248
+ K = np.array([
249
+ [fx, 0, cx],
250
+ [0, fy, cy],
251
+ [0, 0, 1]
252
+ ], dtype=np.float32)
253
+
254
+ return K
255
+
256
+ def _format_pose_output(self, pose_matrix: np.ndarray) -> Dict:
257
+ """Convert 4x4 pose matrix to output format.
258
+
259
+ Args:
260
+ pose_matrix: 4x4 transformation matrix
261
+
262
+ Returns:
263
+ Dictionary with position, orientation (quaternion), and confidence
264
+ """
265
+ # Extract translation
266
+ translation = pose_matrix[:3, 3]
267
+
268
+ # Extract rotation matrix
269
+ rotation_matrix = pose_matrix[:3, :3]
270
+
271
+ # Convert rotation matrix to quaternion
272
+ # Using Shepperd's method for numerical stability
273
+ trace = np.trace(rotation_matrix)
274
+
275
+ if trace > 0:
276
+ s = np.sqrt(trace + 1.0) * 2
277
+ w = 0.25 * s
278
+ x = (rotation_matrix[2, 1] - rotation_matrix[1, 2]) / s
279
+ y = (rotation_matrix[0, 2] - rotation_matrix[2, 0]) / s
280
+ z = (rotation_matrix[1, 0] - rotation_matrix[0, 1]) / s
281
+ elif rotation_matrix[0, 0] > rotation_matrix[1, 1] and rotation_matrix[0, 0] > rotation_matrix[2, 2]:
282
+ s = np.sqrt(1.0 + rotation_matrix[0, 0] - rotation_matrix[1, 1] - rotation_matrix[2, 2]) * 2
283
+ w = (rotation_matrix[2, 1] - rotation_matrix[1, 2]) / s
284
+ x = 0.25 * s
285
+ y = (rotation_matrix[0, 1] + rotation_matrix[1, 0]) / s
286
+ z = (rotation_matrix[0, 2] + rotation_matrix[2, 0]) / s
287
+ elif rotation_matrix[1, 1] > rotation_matrix[2, 2]:
288
+ s = np.sqrt(1.0 + rotation_matrix[1, 1] - rotation_matrix[0, 0] - rotation_matrix[2, 2]) * 2
289
+ w = (rotation_matrix[0, 2] - rotation_matrix[2, 0]) / s
290
+ x = (rotation_matrix[0, 1] + rotation_matrix[1, 0]) / s
291
+ y = 0.25 * s
292
+ z = (rotation_matrix[1, 2] + rotation_matrix[2, 1]) / s
293
+ else:
294
+ s = np.sqrt(1.0 + rotation_matrix[2, 2] - rotation_matrix[0, 0] - rotation_matrix[1, 1]) * 2
295
+ w = (rotation_matrix[1, 0] - rotation_matrix[0, 1]) / s
296
+ x = (rotation_matrix[0, 2] + rotation_matrix[2, 0]) / s
297
+ y = (rotation_matrix[1, 2] + rotation_matrix[2, 1]) / s
298
+ z = 0.25 * s
299
+
300
+ return {
301
+ "position": {
302
+ "x": float(translation[0]),
303
+ "y": float(translation[1]),
304
+ "z": float(translation[2])
305
+ },
306
+ "orientation": {
307
+ "w": float(w),
308
+ "x": float(x),
309
+ "y": float(y),
310
+ "z": float(z)
311
+ },
312
+ "confidence": 1.0, # FoundationPose doesn't provide explicit confidence
313
+ "pose_matrix": pose_matrix.tolist()
314
+ }