File size: 13,880 Bytes
24857f8
2df2c23
24857f8
2df2c23
24857f8
 
 
92da831
24857f8
 
2df2c23
24857f8
 
 
9550d40
24857f8
dea278b
24857f8
 
2df2c23
 
 
 
 
9550d40
 
 
 
 
 
 
 
 
 
 
96b4daf
9550d40
24857f8
053c7f6
 
24857f8
2df2c23
24857f8
92da831
2df2c23
24857f8
 
2df2c23
 
24857f8
 
92da831
 
24857f8
2df2c23
9550d40
 
 
96b4daf
24857f8
2df2c23
 
24857f8
2df2c23
24857f8
 
 
9550d40
 
 
 
2df2c23
 
 
 
24857f8
9550d40
2df2c23
24857f8
 
 
 
 
 
 
 
 
 
 
 
2df2c23
 
 
24857f8
 
 
 
 
9550d40
 
 
96b4daf
 
 
 
 
 
 
 
9550d40
2df2c23
 
 
24857f8
2df2c23
9550d40
 
 
 
24857f8
 
2df2c23
24857f8
 
 
2df2c23
24857f8
 
 
 
 
 
 
 
 
 
2df2c23
24857f8
 
2df2c23
 
 
 
 
24857f8
 
2df2c23
24857f8
2df2c23
24857f8
 
 
9550d40
 
 
 
24857f8
9550d40
24857f8
9550d40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d85a56
 
 
 
 
 
 
 
 
9550d40
 
 
 
 
 
 
 
 
 
 
 
 
4183cba
9550d40
4183cba
9550d40
 
4183cba
 
9550d40
4183cba
 
 
053c7f6
4183cba
053c7f6
 
 
 
4183cba
 
9550d40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58e94fa
 
 
9550d40
 
 
 
 
 
 
 
4183cba
 
 
 
 
 
 
24857f8
 
 
9550d40
 
24857f8
9550d40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86ef72b
9550d40
 
 
 
 
 
 
 
 
 
 
 
19d8da0
 
 
9550d40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
"""
FoundationPose model wrapper for inference.

This module wraps the FoundationPose library for 6D object pose estimation.
"""

import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
import cv2

from masks import generate_naive_mask
logger = logging.getLogger(__name__)

# Add FoundationPose to Python path
FOUNDATIONPOSE_ROOT = Path("/app/FoundationPose")
if FOUNDATIONPOSE_ROOT.exists():
    sys.path.insert(0, str(FOUNDATIONPOSE_ROOT))

# Try to import FoundationPose modules
try:
    from estimater import FoundationPose
    from learning.training.predict_score import ScorePredictor
    from learning.training.predict_pose_refine import PoseRefinePredictor
    import nvdiffrast.torch as dr
    import trimesh
    FOUNDATIONPOSE_AVAILABLE = True
except ImportError as e:
    logger.warning(f"FoundationPose modules not available: {e}")
    FOUNDATIONPOSE_AVAILABLE = False
    trimesh = None




class FoundationPoseEstimator:
    """Wrapper for FoundationPose model.

    Registers objects (optionally backed by a 3D mesh) and estimates their
    6D pose in RGB(-D) images.  Heavy components -- the score predictor,
    pose refiner, CUDA rasterizer, and per-object estimators -- are created
    lazily on first use.
    """

    def __init__(self, device: str = "cuda", weights_dir: Optional[str] = None):
        """Initialize FoundationPose estimator.

        Args:
            device: Device to run inference on ('cuda' or 'cpu')
            weights_dir: Directory containing model weights.  Defaults to
                the FOUNDATIONPOSE_WEIGHTS_DIR environment variable, or
                the bundled repository path.

        Raises:
            RuntimeError: If the FoundationPose repository checkout is missing.
        """
        self.device = device
        if weights_dir is None:
            weights_dir = os.environ.get("FOUNDATIONPOSE_WEIGHTS_DIR", "/app/FoundationPose/weights")
        self.weights_dir = Path(weights_dir)
        # object_id -> registration record; see register_object() for keys.
        self.registered_objects: Dict[str, Dict] = {}
        # Shared predictors; created lazily on the first estimate_pose() call.
        self.scorer = None
        self.refiner = None
        self.glctx = None
        self.available = FOUNDATIONPOSE_AVAILABLE

        # Fail fast if the repository checkout is missing entirely.
        if not FOUNDATIONPOSE_ROOT.exists():
            raise RuntimeError(
                f"FoundationPose repository not found at {FOUNDATIONPOSE_ROOT}. "
                "Clone it with: git clone https://github.com/NVlabs/FoundationPose.git"
            )

        if not FOUNDATIONPOSE_AVAILABLE:
            logger.warning("FoundationPose modules not loaded - inference will not work")
            return

        # Warn early if weights are absent; estimation will fail without them.
        if not self.weights_dir.exists() or not any(self.weights_dir.glob("**/*.pth")):
            logger.warning(f"No model weights found in {self.weights_dir}")
            logger.warning("Model will not work without weights")

        # Predictors themselves are lazy-loaded; nothing heavy happens here.
        logger.info(f"FoundationPose estimator initialized (device: {device})")

    def register_object(
        self,
        object_id: str,
        reference_images: List[np.ndarray],
        camera_intrinsics: Optional[Dict] = None,
        mesh_path: Optional[str] = None
    ) -> bool:
        """Register an object for tracking.

        Args:
            object_id: Unique identifier for the object
            reference_images: List of RGB reference images (H, W, 3)
            camera_intrinsics: Camera parameters {fx, fy, cx, cy}
            mesh_path: Optional path to object mesh file

        Returns:
            True if registration successful
        """
        try:
            mesh = None
            if mesh_path and Path(mesh_path).exists():
                if trimesh is None:
                    logger.warning("trimesh not available, skipping mesh load")
                else:
                    try:
                        mesh = trimesh.load(mesh_path)
                        logger.info(f"Loaded mesh for '{object_id}' from {mesh_path}")
                    except Exception as e:
                        # A bad mesh must not block registration; the mesh-less
                        # case is handled (with an error) at estimation time.
                        logger.warning(f"Failed to load mesh: {e}")

            self.registered_objects[object_id] = {
                "num_references": len(reference_images),
                "camera_intrinsics": camera_intrinsics,
                "mesh_path": mesh_path,
                "mesh": mesh,
                "reference_images": reference_images,
                "estimator": None,  # Created lazily in estimate_pose()
                "pose_last": None   # Last pose, used for temporal tracking
            }

            logger.info(f"✓ Registered object '{object_id}' with {len(reference_images)} reference images")
            return True

        except Exception as e:
            logger.error(f"Failed to register object '{object_id}': {e}", exc_info=True)
            return False

    def _ensure_predictors(self) -> None:
        """Lazily create the shared scorer, refiner, and CUDA raster context."""
        if self.scorer is not None:
            return
        logger.info("Initializing score predictor...")
        self.scorer = ScorePredictor()
        logger.info("Initializing pose refiner...")
        self.refiner = PoseRefinePredictor()
        logger.info("Initializing CUDA rasterizer...")
        self.glctx = dr.RasterizeCudaContext()

    def _ensure_estimator(self, object_id: str, obj_data: Dict):
        """Lazily create the per-object FoundationPose estimator.

        Returns the estimator, or None when no mesh is available
        (model-free mode is not implemented).
        """
        if obj_data["estimator"] is not None:
            return obj_data["estimator"]

        logger.info(f"Creating FoundationPose estimator for '{object_id}'...")
        mesh = obj_data["mesh"]
        if mesh is None:
            # Model-free mode would require 3D reconstruction (SfM) from the
            # reference images; this is not implemented.
            logger.error("Model-free mode not yet implemented")
            logger.error("To use FoundationPose, please provide a 3D mesh (.obj, .stl, .ply)")
            logger.error("You can:")
            logger.error("  1. Use CAD-Based initialization with your object's 3D model")
            logger.error("  2. Create a mesh from photos using photogrammetry tools (e.g., Meshroom, COLMAP)")
            logger.error("  3. Scan the object with a 3D scanner")
            return None

        # Model-based mode: use mesh.
        logger.info("Using model-based mode with mesh")
        obj_data["estimator"] = FoundationPose(
            model_pts=mesh.vertices,
            model_normals=mesh.vertex_normals,
            mesh=mesh,
            scorer=self.scorer,
            refiner=self.refiner,
            glctx=self.glctx,
            debug=0
        )
        return obj_data["estimator"]

    def estimate_pose(
        self,
        object_id: str,
        rgb_image: np.ndarray,
        depth_image: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        camera_intrinsics: Optional[Dict] = None
    ) -> Optional[Dict]:
        """Estimate 6D pose of registered object in image.

        Args:
            object_id: ID of object to detect
            rgb_image: RGB query image (H, W, 3)
            depth_image: Optional depth image (H, W)
            mask: Optional object mask (H, W)
            camera_intrinsics: Camera parameters {fx, fy, cx, cy};
                falls back to the intrinsics stored at registration.

        Returns:
            Pose dictionary with position, orientation, confidence or None
        """
        if object_id not in self.registered_objects:
            logger.error(f"Object '{object_id}' not registered")
            return None

        if not FOUNDATIONPOSE_AVAILABLE:
            logger.error("FoundationPose not available")
            return None

        try:
            obj_data = self.registered_objects[object_id]

            self._ensure_predictors()
            estimator = self._ensure_estimator(object_id, obj_data)
            if estimator is None:
                return None

            # Prefer per-call intrinsics; fall back to registration-time ones.
            K = self._get_camera_matrix(camera_intrinsics or obj_data["camera_intrinsics"])
            if K is None:
                logger.error("Camera intrinsics required")
                return None

            # FoundationPose requires depth; synthesize a flat 0.5 m plane
            # when none was supplied (degrades accuracy).
            if depth_image is None:
                depth_image = np.full(
                    (rgb_image.shape[0], rgb_image.shape[1]), 0.5, dtype=np.float32
                )
                logger.warning("Using dummy depth image - for better results, provide actual depth data")

            # Generate a mask automatically when none was supplied.
            mask_was_generated = False
            debug_mask = None
            if mask is None:
                # Brightness-based foreground segmentation; works well for
                # light objects on dark backgrounds.
                logger.info("Generating automatic object mask from image")
                mask, debug_mask, mask_percentage, fallback_full_image = generate_naive_mask(rgb_image)
                logger.info(f"Auto-generated mask covers {mask_percentage:.1f}% of image")
                if fallback_full_image:
                    logger.warning(
                        f"Mask coverage ({mask_percentage:.1f}%) seems unrealistic, using full image"
                    )
                mask_was_generated = True

            # First frame (or lost tracking): full registration.
            # Subsequent frames: cheaper temporal tracking.
            if obj_data["pose_last"] is None:
                logger.info("Running registration (first frame)...")
                pose = estimator.register(
                    K=K,
                    rgb=rgb_image,
                    depth=depth_image,
                    ob_mask=mask,
                    iteration=5  # Number of refinement iterations
                )
            else:
                pose = estimator.track_one(
                    rgb=rgb_image,
                    depth=depth_image,
                    K=K,
                    iteration=2  # Fewer iterations for tracking
                )

            if pose is None:
                # Clear the last pose so the next frame re-registers.
                obj_data["pose_last"] = None
                logger.warning("Pose estimation returned None")
                return None

            # Store pose for next frame (move to CPU if it's a tensor).
            if torch.is_tensor(pose):
                pose = pose.detach().cpu().numpy()
            obj_data["pose_last"] = pose

            # pose is a 4x4 transformation matrix.
            result = self._format_pose_output(pose)

            # Expose the auto-generated mask for debugging/visualization.
            if mask_was_generated and debug_mask is not None:
                result["debug_mask"] = debug_mask

            return result

        except Exception as e:
            # exc_info=True already logs the full traceback; the previous
            # extra traceback.print_exc() was redundant.
            logger.error(f"Pose estimation failed: {e}", exc_info=True)
            return None

    def _get_camera_matrix(self, intrinsics: Optional[Dict]) -> Optional[np.ndarray]:
        """Convert an intrinsics dict {fx, fy, cx, cy} to a 3x3 camera matrix.

        Returns None when intrinsics are missing or incomplete.
        """
        if intrinsics is None:
            return None

        fx = intrinsics.get("fx")
        fy = intrinsics.get("fy")
        cx = intrinsics.get("cx")
        cy = intrinsics.get("cy")

        # Explicit `is None` checks; `None in [...]` relies on __eq__.
        if any(v is None for v in (fx, fy, cx, cy)):
            return None

        return np.array([
            [fx, 0, cx],
            [0, fy, cy],
            [0, 0, 1]
        ], dtype=np.float64)

    def _format_pose_output(self, pose_matrix: np.ndarray) -> Dict:
        """Convert 4x4 pose matrix to output format.

        Args:
            pose_matrix: 4x4 transformation matrix (numpy array or torch tensor)

        Returns:
            Dictionary with position, orientation (quaternion), and confidence
        """
        if torch.is_tensor(pose_matrix):
            pose_matrix = pose_matrix.detach().cpu().numpy()

        # Rigid transform layout: top-left 3x3 is rotation, last column is translation.
        translation = pose_matrix[:3, 3]
        rotation_matrix = pose_matrix[:3, :3]

        # Rotation matrix -> quaternion via Shepperd's method: branch on the
        # largest diagonal term to keep the divisor s well away from zero.
        trace = np.trace(rotation_matrix)

        if trace > 0:
            s = np.sqrt(trace + 1.0) * 2
            w = 0.25 * s
            x = (rotation_matrix[2, 1] - rotation_matrix[1, 2]) / s
            y = (rotation_matrix[0, 2] - rotation_matrix[2, 0]) / s
            z = (rotation_matrix[1, 0] - rotation_matrix[0, 1]) / s
        elif rotation_matrix[0, 0] > rotation_matrix[1, 1] and rotation_matrix[0, 0] > rotation_matrix[2, 2]:
            s = np.sqrt(1.0 + rotation_matrix[0, 0] - rotation_matrix[1, 1] - rotation_matrix[2, 2]) * 2
            w = (rotation_matrix[2, 1] - rotation_matrix[1, 2]) / s
            x = 0.25 * s
            y = (rotation_matrix[0, 1] + rotation_matrix[1, 0]) / s
            z = (rotation_matrix[0, 2] + rotation_matrix[2, 0]) / s
        elif rotation_matrix[1, 1] > rotation_matrix[2, 2]:
            s = np.sqrt(1.0 + rotation_matrix[1, 1] - rotation_matrix[0, 0] - rotation_matrix[2, 2]) * 2
            w = (rotation_matrix[0, 2] - rotation_matrix[2, 0]) / s
            x = (rotation_matrix[0, 1] + rotation_matrix[1, 0]) / s
            y = 0.25 * s
            z = (rotation_matrix[1, 2] + rotation_matrix[2, 1]) / s
        else:
            s = np.sqrt(1.0 + rotation_matrix[2, 2] - rotation_matrix[0, 0] - rotation_matrix[1, 1]) * 2
            w = (rotation_matrix[1, 0] - rotation_matrix[0, 1]) / s
            x = (rotation_matrix[0, 2] + rotation_matrix[2, 0]) / s
            y = (rotation_matrix[1, 2] + rotation_matrix[2, 1]) / s
            z = 0.25 * s

        return {
            "position": {
                "x": float(translation[0]),
                "y": float(translation[1]),
                "z": float(translation[2])
            },
            "orientation": {
                "w": float(w),
                "x": float(x),
                "y": float(y),
                "z": float(z)
            },
            "confidence": 1.0,  # FoundationPose doesn't provide explicit confidence
            "pose_matrix": pose_matrix.tolist()
        }