"""
FoundationPose model wrapper for inference.

This module wraps the FoundationPose library for 6D object pose estimation.
"""

import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
import cv2

from masks import generate_naive_mask

logger = logging.getLogger(__name__)

# Add FoundationPose to Python path
FOUNDATIONPOSE_ROOT = Path("/app/FoundationPose")
if FOUNDATIONPOSE_ROOT.exists():
    sys.path.insert(0, str(FOUNDATIONPOSE_ROOT))

# Try to import FoundationPose modules. On failure the module still imports,
# but FOUNDATIONPOSE_AVAILABLE is False and estimate_pose() refuses to run.
try:
    from estimater import FoundationPose
    from learning.training.predict_score import ScorePredictor
    from learning.training.predict_pose_refine import PoseRefinePredictor
    import nvdiffrast.torch as dr
    import trimesh
    FOUNDATIONPOSE_AVAILABLE = True
except ImportError as e:
    logger.warning(f"FoundationPose modules not available: {e}")
    FOUNDATIONPOSE_AVAILABLE = False
    trimesh = None


class FoundationPoseEstimator:
    """Wrapper for FoundationPose model.

    Objects are registered once with :meth:`register_object` (optionally with a
    mesh) and then queried frame by frame with :meth:`estimate_pose`. The first
    query for an object, or the first after a failure, runs full registration.
    Later queries track from the previous pose.
    """

    def __init__(self, device: str = "cuda", weights_dir: str | None = None):
        """Initialize FoundationPose estimator.

        Heavy components (score predictor, pose refiner, CUDA rasterizer) are
        created lazily on the first call to :meth:`estimate_pose`.

        Args:
            device: Device to run inference on ('cuda' or 'cpu').
                NOTE(review): currently only stored and logged; it is not
                forwarded to the FoundationPose predictors.
            weights_dir: Directory containing model weights. Defaults to
                $FOUNDATIONPOSE_WEIGHTS_DIR or /app/FoundationPose/weights.
                Only checked for the presence of ``*.pth`` files here.

        Raises:
            RuntimeError: If the FoundationPose repository is not present at
                FOUNDATIONPOSE_ROOT.
        """
        self.device = device
        if weights_dir is None:
            weights_dir = os.environ.get("FOUNDATIONPOSE_WEIGHTS_DIR", "/app/FoundationPose/weights")
        self.weights_dir = Path(weights_dir)
        # object_id -> registration record (see register_object)
        self.registered_objects: Dict[str, Dict] = {}
        self.scorer = None
        self.refiner = None
        self.glctx = None
        self.available = FOUNDATIONPOSE_AVAILABLE

        # Check if FoundationPose is available
        if not FOUNDATIONPOSE_ROOT.exists():
            raise RuntimeError(
                f"FoundationPose repository not found at {FOUNDATIONPOSE_ROOT}. "
                "Clone it with: git clone https://github.com/NVlabs/FoundationPose.git"
            )

        if not FOUNDATIONPOSE_AVAILABLE:
            logger.warning("FoundationPose modules not loaded - inference will not work")
            return

        # Check if weights exist
        if not self.weights_dir.exists() or not any(self.weights_dir.glob("**/*.pth")):
            logger.warning(f"No model weights found in {self.weights_dir}")
            logger.warning("Model will not work without weights")

        # Predictors are initialized lazily (see _ensure_predictors)
        logger.info(f"FoundationPose estimator initialized (device: {device})")

    def register_object(
        self,
        object_id: str,
        reference_images: List[np.ndarray],
        camera_intrinsics: Optional[Dict] = None,
        mesh_path: Optional[str] = None
    ) -> bool:
        """Register an object for tracking.

        Re-registering an existing ``object_id`` replaces its record, which
        also discards its estimator and tracking state.

        Args:
            object_id: Unique identifier for the object
            reference_images: List of RGB reference images (H, W, 3). They are
                stored, but estimate_pose() does not use them, because
                model-free mode is not implemented.
            camera_intrinsics: Camera parameters {fx, fy, cx, cy}, used when
                estimate_pose() is not given its own intrinsics.
            mesh_path: Optional path to object mesh file (loaded with trimesh).
                A mesh is required for pose estimation.

        Returns:
            True if registration successful
        """
        try:
            # Load mesh if provided
            mesh = None
            if mesh_path:
                if not Path(mesh_path).exists():
                    logger.warning(f"Mesh file not found for '{object_id}': {mesh_path}")
                elif trimesh is None:
                    logger.warning("trimesh not available, skipping mesh load")
                else:
                    try:
                        mesh = trimesh.load(mesh_path)
                        logger.info(f"Loaded mesh for '{object_id}' from {mesh_path}")
                    except Exception as e:
                        logger.warning(f"Failed to load mesh: {e}")

            # Store object registration
            self.registered_objects[object_id] = {
                "num_references": len(reference_images),
                "camera_intrinsics": camera_intrinsics,
                "mesh_path": mesh_path,
                "mesh": mesh,
                "reference_images": reference_images,
                "estimator": None,  # Created lazily by _get_or_create_estimator
                "pose_last": None,  # Last pose; None means "register on next frame"
            }

            logger.info(f"✓ Registered object '{object_id}' with {len(reference_images)} reference images")
            return True

        except Exception as e:
            logger.error(f"Failed to register object '{object_id}': {e}", exc_info=True)
            return False

    def _ensure_predictors(self) -> None:
        """Lazily create the shared scorer, refiner and CUDA rasterizer context.

        Each component is checked on its own. If one constructor raises, the
        components that are still None are retried on the next call, and a
        half-initialized ``None`` is never handed to FoundationPose.
        """
        if self.scorer is None:
            logger.info("Initializing score predictor...")
            self.scorer = ScorePredictor()
        if self.refiner is None:
            logger.info("Initializing pose refiner...")
            self.refiner = PoseRefinePredictor()
        if self.glctx is None:
            logger.info("Initializing CUDA rasterizer...")
            self.glctx = dr.RasterizeCudaContext()

    def _get_or_create_estimator(self, object_id: str, obj_data: Dict):
        """Return the per-object FoundationPose estimator, creating it on first use.

        Returns:
            The estimator, or None when the object has no mesh (model-free mode
            is not implemented).
        """
        if obj_data["estimator"] is not None:
            return obj_data["estimator"]

        logger.info(f"Creating FoundationPose estimator for '{object_id}'...")
        mesh = obj_data["mesh"]
        if mesh is None:
            # Model-free mode would need a 3D reconstruction (e.g. SfM) built
            # from the reference images; this is not implemented.
            logger.error("Model-free mode not yet implemented")
            logger.error("To use FoundationPose, please provide a 3D mesh (.obj, .stl, .ply)")
            logger.error("You can:")
            logger.error("  1. Use CAD-Based initialization with your object's 3D model")
            logger.error("  2. Create a mesh from photos using photogrammetry tools (e.g., Meshroom, COLMAP)")
            logger.error("  3. Scan the object with a 3D scanner")
            return None

        # Model-based mode: use mesh
        logger.info("Using model-based mode with mesh")
        obj_data["estimator"] = FoundationPose(
            model_pts=mesh.vertices,
            model_normals=mesh.vertex_normals,
            mesh=mesh,
            scorer=self.scorer,
            refiner=self.refiner,
            glctx=self.glctx,
            debug=0
        )
        return obj_data["estimator"]

    def estimate_pose(
        self,
        object_id: str,
        rgb_image: np.ndarray,
        depth_image: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        camera_intrinsics: Optional[Dict] = None
    ) -> Optional[Dict]:
        """Estimate 6D pose of registered object in image.

        The first call for an object runs ``estimator.register``, which uses
        the mask. Later calls run ``estimator.track_one``, which does not take
        a mask. After any exception the tracking state is cleared, so the next
        call registers again.

        Args:
            object_id: ID of object to detect
            rgb_image: RGB query image (H, W, 3)
            depth_image: Optional depth image (H, W). If omitted, a constant
                0.5 plane is substituted (presumably metres; confirm against
                FoundationPose's expected depth units).
            mask: Optional object mask (H, W). If omitted, one is generated
                with masks.generate_naive_mask.
            camera_intrinsics: Camera parameters {fx, fy, cx, cy}. Falls back
                to the intrinsics given at registration.

        Returns:
            Pose dictionary with position, orientation, confidence and
            pose_matrix (plus debug_mask when the mask was auto-generated),
            or None on any failure.
        """
        if object_id not in self.registered_objects:
            logger.error(f"Object '{object_id}' not registered")
            return None

        if not FOUNDATIONPOSE_AVAILABLE:
            logger.error("FoundationPose not available")
            return None

        obj_data = self.registered_objects[object_id]

        try:
            self._ensure_predictors()

            estimator = self._get_or_create_estimator(object_id, obj_data)
            if estimator is None:
                return None

            # Prepare camera intrinsics matrix
            K = self._get_camera_matrix(camera_intrinsics or obj_data["camera_intrinsics"])
            if K is None:
                logger.error("Camera intrinsics required")
                return None

            # Generate or use depth if not provided
            if depth_image is None:
                # Fallback: a flat, constant-depth plane at 0.5 (no variation).
                # Pose quality will be poor; real depth is strongly preferred.
                depth_image = np.ones((rgb_image.shape[0], rgb_image.shape[1]), dtype=np.float32) * 0.5
                logger.warning("Using dummy depth image - for better results, provide actual depth data")

            # Generate mask if not provided
            mask_was_generated = False
            debug_mask = None
            if mask is None:
                # Naive automatic foreground segmentation (see masks.generate_naive_mask).
                # NOTE: it is regenerated every frame, but only register() consumes
                # it; in tracking mode it only feeds the returned debug_mask.
                logger.info("Generating automatic object mask from image")
                mask, debug_mask, mask_percentage, fallback_full_image = generate_naive_mask(rgb_image)

                logger.info(f"Auto-generated mask covers {mask_percentage:.1f}% of image")
                if fallback_full_image:
                    logger.warning(
                        f"Mask coverage ({mask_percentage:.1f}%) seems unrealistic, using full image"
                    )

                mask_was_generated = True

            if obj_data["pose_last"] is None:
                # First frame or lost tracking: register
                logger.info("Running registration (first frame)...")
                pose = estimator.register(
                    K=K,
                    rgb=rgb_image,
                    depth=depth_image,
                    ob_mask=mask,
                    iteration=5  # Number of refinement iterations
                )
            else:
                # Subsequent frames: track from the previous pose
                pose = estimator.track_one(
                    rgb=rgb_image,
                    depth=depth_image,
                    K=K,
                    iteration=2  # Fewer iterations for tracking
                )

            # Store pose for next frame (move to CPU if it's a tensor). A None
            # pose forces registration on the next frame.
            if torch.is_tensor(pose):
                pose = pose.detach().cpu().numpy()
            obj_data["pose_last"] = pose

            if pose is None:
                logger.warning("Pose estimation returned None")
                return None

            # pose is a 4x4 transformation matrix
            result = self._format_pose_output(pose)

            # Add debug mask if it was auto-generated
            if mask_was_generated and debug_mask is not None:
                result["debug_mask"] = debug_mask

            return result

        except Exception as e:
            # exc_info=True already records the traceback; no separate print.
            logger.error(f"Pose estimation failed: {e}", exc_info=True)
            # Tracking state may now be stale; force re-registration next frame.
            obj_data["pose_last"] = None
            return None

    def _get_camera_matrix(self, intrinsics: Optional[Dict]) -> Optional[np.ndarray]:
        """Convert an intrinsics dict {fx, fy, cx, cy} to a 3x3 camera matrix.

        Returns:
            A float64 3x3 matrix, or None if ``intrinsics`` is None or any of
            the four keys is missing/None.
        """
        if intrinsics is None:
            return None

        fx = intrinsics.get("fx")
        fy = intrinsics.get("fy")
        cx = intrinsics.get("cx")
        cy = intrinsics.get("cy")

        # Identity checks: `None in [...]` would call __eq__ on each value.
        if any(v is None for v in (fx, fy, cx, cy)):
            return None

        return np.array([
            [fx, 0, cx],
            [0, fy, cy],
            [0, 0, 1]
        ], dtype=np.float64)

    @staticmethod
    def _rotation_matrix_to_quaternion(m: np.ndarray) -> tuple[float, float, float, float]:
        """Convert a 3x3 rotation matrix to a unit quaternion (w, x, y, z).

        Uses Shepperd's method: the branch is chosen by the largest of the
        trace and the diagonal entries, which keeps the divisor ``s`` away from
        zero. The result is normalized, so slightly non-orthonormal input still
        yields a unit quaternion.
        """
        trace = m[0, 0] + m[1, 1] + m[2, 2]
        if trace > 0:
            s = np.sqrt(trace + 1.0) * 2  # s == 4 * w
            w = 0.25 * s
            x = (m[2, 1] - m[1, 2]) / s
            y = (m[0, 2] - m[2, 0]) / s
            z = (m[1, 0] - m[0, 1]) / s
        elif m[0, 0] > m[1, 1] and m[0, 0] > m[2, 2]:
            s = np.sqrt(1.0 + m[0, 0] - m[1, 1] - m[2, 2]) * 2  # s == 4 * x
            w = (m[2, 1] - m[1, 2]) / s
            x = 0.25 * s
            y = (m[0, 1] + m[1, 0]) / s
            z = (m[0, 2] + m[2, 0]) / s
        elif m[1, 1] > m[2, 2]:
            s = np.sqrt(1.0 + m[1, 1] - m[0, 0] - m[2, 2]) * 2  # s == 4 * y
            w = (m[0, 2] - m[2, 0]) / s
            x = (m[0, 1] + m[1, 0]) / s
            y = 0.25 * s
            z = (m[1, 2] + m[2, 1]) / s
        else:
            s = np.sqrt(1.0 + m[2, 2] - m[0, 0] - m[1, 1]) * 2  # s == 4 * z
            w = (m[1, 0] - m[0, 1]) / s
            x = (m[0, 2] + m[2, 0]) / s
            y = (m[1, 2] + m[2, 1]) / s
            z = 0.25 * s

        quat = np.array([w, x, y, z], dtype=np.float64)
        norm = np.linalg.norm(quat)
        if norm > 0:
            quat = quat / norm
        return float(quat[0]), float(quat[1]), float(quat[2]), float(quat[3])

    def _format_pose_output(self, pose_matrix: np.ndarray) -> Dict:
        """Convert 4x4 pose matrix to output format.

        Args:
            pose_matrix: 4x4 transformation matrix (ndarray, nested list or
                torch tensor).

        Returns:
            Dictionary with position, orientation (quaternion w/x/y/z),
            confidence and the pose matrix as nested lists.

        Raises:
            ValueError: If the input is not 4x4.
        """
        if torch.is_tensor(pose_matrix):
            pose_matrix = pose_matrix.detach().cpu().numpy()
        pose_matrix = np.asarray(pose_matrix, dtype=np.float64)
        if pose_matrix.shape != (4, 4):
            raise ValueError(f"Expected a 4x4 pose matrix, got shape {pose_matrix.shape}")

        translation = pose_matrix[:3, 3]
        w, x, y, z = self._rotation_matrix_to_quaternion(pose_matrix[:3, :3])

        return {
            "position": {
                "x": float(translation[0]),
                "y": float(translation[1]),
                "z": float(translation[2])
            },
            "orientation": {
                "w": w,
                "x": x,
                "y": y,
                "z": z
            },
            "confidence": 1.0,  # FoundationPose doesn't provide explicit confidence
            "pose_matrix": pose_matrix.tolist()
        }