""" BA Validator: Uses Bundle Adjustment as an oracle teacher to validate model predictions. """ import logging import shutil from pathlib import Path from typing import Dict, List, Optional, Tuple import h5py import numpy as np try: from ..utils.profiler import profile, profile_context HAS_PROFILER = True except ImportError: HAS_PROFILER = False def profile(*args, **kwargs): def decorator(func): return func return decorator class profile_context: def __init__(self, *args, **kwargs): pass def __enter__(self): return self def __exit__(self, *args): return False try: import pycolmap from hloc import extract_features, match_features HAS_BA_DEPS = True except ImportError: logging.warning("hloc or pycolmap not installed. BA validation will not work.") pycolmap = None HAS_BA_DEPS = False logger = logging.getLogger(__name__) class BAValidator: """ Validates model predictions using Bundle Adjustment. Uses BA as an oracle teacher to identify model failures and generate pseudo-labels for fine-tuning. """ def __init__( self, accept_threshold: float = 2.0, # degrees reject_threshold: float = 30.0, # degrees feature_conf: str = "superpoint_max", matcher_conf: str = "superpoint+lightglue", work_dir: Optional[Path] = None, match_num_workers: int = 5, # Number of workers for parallel pair loading ): """ Args: accept_threshold: Maximum rotation error (degrees) to accept model prediction reject_threshold: Maximum rotation error (degrees) before considering outlier feature_conf: Feature extraction config (superpoint_max, etc.) matcher_conf: Matcher config (lightglue, superglue, etc.) work_dir: Working directory for temporary files match_num_workers: Number of workers for parallel pair loading (default: 5) """ self.accept_threshold = accept_threshold self.reject_threshold = reject_threshold self.feature_conf = feature_conf self.matcher_conf = matcher_conf self.work_dir = work_dir or Path("/tmp/ylff_ba") self.work_dir.mkdir(parents=True, exist_ok=True) self.match_num_workers = match_num_workers # Feature cache directory self.feature_cache_dir = self.work_dir / "feature_cache" self.feature_cache_dir.mkdir(exist_ok=True) if not HAS_BA_DEPS: raise ImportError( "pycolmap and hloc are required for BA validation. " "Install with: pip install pycolmap git+https://github.com/cvg/Hierarchical-Localization.git" ) def validate( self, images: List[np.ndarray], poses_model: np.ndarray, intrinsics: Optional[np.ndarray] = None, ) -> Dict: """ Validate model poses using Bundle Adjustment. Args: images: List of RGB images (H, W, 3) uint8 poses_model: Model-predicted poses (N, 3, 4) or (N, 4, 4) intrinsics: Camera intrinsics (N, 3, 3), optional Returns: Dictionary with validation results: - status: 'accepted', 'rejected_learnable', 'rejected_outlier', or 'ba_failed' - error: Maximum rotation error in degrees - poses_ba: BA-refined poses (if successful) - reprojection_error: Average reprojection error """ N = len(images) # Convert poses to 4x4 if needed if poses_model.shape[1] == 3: poses_4x4 = np.eye(4, dtype=poses_model.dtype)[None, :, :].repeat(N, axis=0) poses_4x4[:, :3, :] = poses_model poses_model = poses_4x4 # Save images temporarily image_dir = self.work_dir / "images" image_dir.mkdir(exist_ok=True) image_paths = [] for i, img in enumerate(images): path = image_dir / f"frame_{i:06d}.jpg" import cv2 cv2.imwrite(str(path), cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) image_paths.append(str(path)) try: # 1. Extract features features = self._extract_features(image_paths) # 2. Match features (with smart pairing if poses available) matches = self._match_features( image_paths, features, poses=poses_model, smart_pairing=True, # Enable smart pairing by default ) # 3. Run COLMAP BA, initialized from model poses ba_result = self._run_colmap_ba( image_paths=image_paths, features=features, matches=matches, initial_poses=poses_model, intrinsics=intrinsics, ) if not ba_result["success"]: return { "status": "ba_failed", "error": None, "poses_ba": None, "reprojection_error": None, } poses_ba = ba_result["poses"] reproj_error = ba_result["reprojection_error"] # 4. Compare poses error_metrics = self._compute_pose_error(poses_model, poses_ba) max_rot_error = error_metrics["max_rotation_error_deg"] # 5. Categorize if max_rot_error < self.accept_threshold: return { "status": "accepted", "error": max_rot_error, "poses_ba": poses_ba, "reprojection_error": reproj_error, "error_metrics": error_metrics, } elif max_rot_error < self.reject_threshold: return { "status": "rejected_learnable", "error": max_rot_error, "poses_ba": poses_ba, # Pseudo-label! "reprojection_error": reproj_error, "error_metrics": error_metrics, } else: return { "status": "rejected_outlier", "error": max_rot_error, "poses_ba": poses_ba, "reprojection_error": reproj_error, "error_metrics": error_metrics, } except Exception as e: logger.error(f"BA validation failed: {e}") return { "status": "ba_failed", "error": str(e), "poses_ba": None, "reprojection_error": None, } def _get_image_hash(self, image_path: str) -> str: """Generate hash from image file for caching.""" import hashlib with open(image_path, "rb") as f: img_hash = hashlib.md5(f.read()).hexdigest() return img_hash def _get_cache_key(self, image_path: str) -> str: """Generate cache key from image path and feature config.""" img_hash = self._get_image_hash(image_path) return f"{self.feature_conf}_{img_hash}" @profile(stage="gpu", operation="feature_extraction") def _extract_features(self, image_paths: List[str], use_cache: bool = True) -> Path: """ Extract features using hloc with optional caching. Args: image_paths: List of image file paths use_cache: If True, use cached features when available Returns: Path to features HDF5 file """ feature_path = self.work_dir / "features.h5" if use_cache: # Check cache for existing features cached_features = {} uncached_paths = [] logger.info(f"Checking feature cache for {len(image_paths)} images...") # Load cached features cache_hits = 0 for img_path in image_paths: cache_key = self._get_cache_key(img_path) cache_file = self.feature_cache_dir / f"{cache_key}.h5" if cache_file.exists(): try: # Copy cached features to main feature file with h5py.File(cache_file, "r") as cache_f: with h5py.File(feature_path, "a") as main_f: img_name = Path(img_path).name if img_name not in main_f: # Copy the cached group cache_f.copy(img_name, main_f) cached_features[img_path] = cache_key cache_hits += 1 except Exception as e: logger.warning(f"Failed to load cached features for {img_path}: {e}") uncached_paths.append(img_path) else: uncached_paths.append(img_path) if cache_hits > 0: logger.info(f" ✓ Cache hits: {cache_hits}/{len(image_paths)} images") if len(uncached_paths) == 0: logger.info(f"✓ All features loaded from cache: {feature_path}") return feature_path logger.info(f" Extracting features for {len(uncached_paths)} uncached images...") # Extract features for uncached images # Create temporary directory with only uncached images temp_image_dir = self.work_dir / "temp_images" temp_image_dir.mkdir(exist_ok=True) for img_path in uncached_paths: img_name = Path(img_path).name temp_path = temp_image_dir / img_name shutil.copy2(img_path, temp_path) # Extract features for uncached images temp_feature_path = self.work_dir / "temp_features.h5" extract_features.main( conf=extract_features.confs[self.feature_conf], image_dir=temp_image_dir, feature_path=temp_feature_path, ) # Merge temp features into main feature file and cache with h5py.File(temp_feature_path, "r") as temp_f: with h5py.File(feature_path, "a") as main_f: for img_path in uncached_paths: img_name = Path(img_path).name if img_name in temp_f: # Copy to main file if img_name in main_f: del main_f[img_name] temp_f.copy(img_name, main_f) # Save to cache cache_key = self._get_cache_key(img_path) cache_file = self.feature_cache_dir / f"{cache_key}.h5" with h5py.File(cache_file, "w") as cache_f: temp_f.copy(img_name, cache_f) # Cleanup temp files temp_feature_path.unlink(missing_ok=True) shutil.rmtree(temp_image_dir, ignore_errors=True) logger.info(f"✓ Features extracted and cached: {feature_path}") logger.info(f" - Cached: {cache_hits}, Extracted: {len(uncached_paths)}") else: # No caching - extract all features logger.info(f"Extracting features from {len(image_paths)} images...") logger.info(f" Using feature extractor: {self.feature_conf}") extract_features.main( conf=extract_features.confs[self.feature_conf], image_dir=Path(image_paths[0]).parent, feature_path=feature_path, ) logger.info(f"✓ Features extracted: {feature_path}") return feature_path def _generate_smart_pairs( self, image_paths: List[str], poses: Optional[np.ndarray] = None, max_baseline: Optional[float] = None, min_baseline: float = 0.05, sequential_only: bool = False, max_pairs_per_image: int = 10, ) -> List[Tuple[str, str]]: """ Generate smart pairs based on spatial proximity or sequential ordering. Args: image_paths: List of image paths poses: Optional poses (N, 3, 4) to compute baselines max_baseline: Maximum translation distance (if None, use sequential) min_baseline: Minimum translation distance sequential_only: If True, only match consecutive frames max_pairs_per_image: Maximum number of pairs per image Returns: List of (image1, image2) pairs """ pairs = [] if sequential_only: # Only match consecutive frames (N-1 pairs) for i in range(len(image_paths) - 1): pairs.append((Path(image_paths[i]).name, Path(image_paths[i + 1]).name)) logger.info(f"Generated {len(pairs)} sequential pairs") return pairs if poses is not None and max_baseline is not None: # Spatial selection based on poses for i in range(len(image_paths)): image_pairs = [] t_i = poses[i][:3, 3] for j in range(i + 1, len(image_paths)): t_j = poses[j][:3, 3] baseline = np.linalg.norm(t_i - t_j) if min_baseline <= baseline <= max_baseline: image_pairs.append((baseline, j)) # Sort by baseline and take closest max_pairs_per_image image_pairs.sort(key=lambda x: x[0]) for _, j in image_pairs[:max_pairs_per_image]: pairs.append((Path(image_paths[i]).name, Path(image_paths[j]).name)) logger.info( f"Generated {len(pairs)} spatial pairs " f"(baseline: {min_baseline:.2f}-{max_baseline:.2f})" ) return pairs # Fallback: exhaustive matching (original behavior) for i in range(len(image_paths)): for j in range(i + 1, len(image_paths)): pairs.append((Path(image_paths[i]).name, Path(image_paths[j]).name)) logger.info(f"Generated {len(pairs)} exhaustive pairs") return pairs @profile(stage="gpu", operation="feature_matching") def _match_features( self, image_paths: List[str], features: Path, poses: Optional[np.ndarray] = None, smart_pairing: bool = True, ) -> Path: """ Match features using hloc. Args: image_paths: List of image paths features: Path to features file poses: Optional poses for smart pairing smart_pairing: If True, use smart pair selection """ pairs_path = self.work_dir / "pairs.txt" matches_path = self.work_dir / "matches.h5" # Generate pairs if smart_pairing and poses is not None: # Use smart pairing with spatial selection pairs = self._generate_smart_pairs( image_paths, poses=poses, max_baseline=0.5, # Reasonable baseline for video min_baseline=0.05, max_pairs_per_image=10, ) elif smart_pairing: # Use sequential pairing (no poses needed) pairs = self._generate_smart_pairs( image_paths, sequential_only=True, ) else: # Exhaustive matching pairs = self._generate_smart_pairs(image_paths) num_pairs = len(pairs) logger.info(f"Generating {num_pairs} image pairs for matching...") # Write pairs file with open(pairs_path, "w") as f: for img1, img2 in pairs: f.write(f"{img1} {img2}\n") logger.info(f"✓ Pairs file created: {pairs_path}") logger.info(f"Matching features using {self.matcher_conf}...") try: match_conf = match_features.confs[self.matcher_conf] except KeyError: available = list(match_features.confs.keys()) logger.error( f"Matcher config '{self.matcher_conf}' not found. " f"Available: {available}" ) raise match_features.main( conf=match_conf, pairs=pairs_path, features=features, matches=matches_path, ) logger.info(f"✓ Features matched: {matches_path}") return matches_path @profile(stage="cpu", operation="colmap_ba") def _run_colmap_ba( self, image_paths: List[str], features: Path, matches: Path, initial_poses: np.ndarray, intrinsics: Optional[np.ndarray] = None, ) -> Dict: """ Run COLMAP Bundle Adjustment using hloc's reconstruction pipeline. Uses hloc.reconstruction.main to: 1. Create COLMAP database from features and matches 2. Run incremental SfM with bundle adjustment 3. Extract refined poses Returns: Dictionary with 'success', 'poses', 'reprojection_error' """ try: from hloc import reconstruction except ImportError: logger.warning("hloc reconstruction module not available. Using simplified BA.") return self._run_simplified_ba(image_paths, initial_poses, intrinsics) sfm_dir = self.work_dir / "sfm" sfm_dir.mkdir(exist_ok=True) image_dir = Path(image_paths[0]).parent # Create pairs file (all pairs for exhaustive matching) pairs_path = self.work_dir / "pairs.txt" if not pairs_path.exists(): with open(pairs_path, "w") as f: for i in range(len(image_paths)): for j in range(i + 1, len(image_paths)): f.write(f"{Path(image_paths[i]).name} {Path(image_paths[j]).name}\n") # Determine camera mode if intrinsics is not None: # Check if all intrinsics are the same first_K = intrinsics[0] all_same = all(np.allclose(K, first_K) for K in intrinsics) camera_mode = ( pycolmap.CameraMode.SINGLE_CAMERA if all_same else pycolmap.CameraMode.PER_IMAGE ) else: camera_mode = pycolmap.CameraMode.SINGLE_CAMERA logger.info(f"Running COLMAP reconstruction with camera_mode={camera_mode}...") try: # Run hloc's reconstruction pipeline # This will create a database, import features/matches, and run incremental SfM with BA ba_reconstruction = reconstruction.main( sfm_dir=sfm_dir, image_dir=image_dir, pairs=pairs_path, features=features, matches=matches, camera_mode=camera_mode, verbose=False, ) # reconstruction.main returns the Reconstruction object directly # But it may also write to disk - check both if ba_reconstruction is None: # Try loading from disk # hloc may write to sfm_dir directly or to a subdirectory if (sfm_dir / "images.bin").exists(): ba_reconstruction = pycolmap.Reconstruction(str(sfm_dir)) elif (sfm_dir / "0" / "images.bin").exists(): ba_reconstruction = pycolmap.Reconstruction(str(sfm_dir / "0")) else: # Check for models subdirectory models_dir = sfm_dir / "models" if models_dir.exists(): model_dirs = [ d for d in models_dir.iterdir() if d.is_dir() and (d / "images.bin").exists() ] if model_dirs: ba_reconstruction = pycolmap.Reconstruction(str(model_dirs[0])) if ba_reconstruction is None or len(ba_reconstruction.images) == 0: logger.warning("COLMAP reconstruction failed or produced no images.") return { "success": False, "error_message": "Reconstruction produced no images", "poses": None, "reprojection_error": None, } # Extract poses from reconstruction ba_poses = [] reprojection_errors = [] # Map image names to indices image_name_to_idx = {Path(p).name: i for i, p in enumerate(image_paths)} for img_id in sorted(ba_reconstruction.images.keys()): img = ba_reconstruction.images[img_id] if not img.has_pose: logger.warning(f"Image {img.name} has no pose") continue # COLMAP stores camera-to-world pose # cam_from_world() returns a Rigid3d object try: pose = img.cam_from_world() # Rigid3d has rotation and translation R = pose.rotation.matrix() # 3x3 rotation matrix t = pose.translation # 3x1 translation vector # Construct 4x4 c2w matrix c2w = np.eye(4) c2w[:3, :3] = R c2w[:3, 3] = t w2c = np.linalg.inv(c2w) ba_poses.append(w2c[:3, :]) # Extract 3x4 w2c matrix # Get reprojection error for this image # mean_reprojection_error is a method try: reproj_error = ( img.mean_reprojection_error() if callable(img.mean_reprojection_error) else 0.0 ) except Exception: reproj_error = 0.0 reprojection_errors.append(reproj_error) except Exception as e: logger.warning(f"Failed to extract pose for image {img.name}: {e}") continue # Align BA poses to match the order of input images # COLMAP may not reconstruct all images, so we need to match by name ordered_ba_poses = [] for img_path in image_paths: img_name = Path(img_path).name found = False for img_id in sorted(ba_reconstruction.images.keys()): img = ba_reconstruction.images[img_id] if img.name == img_name: if not img.has_pose: logger.warning(f"Image {img_name} has no pose in BA reconstruction") break try: # Extract pose same way as above pose = img.cam_from_world() R = pose.rotation.matrix() t = pose.translation c2w = np.eye(4) c2w[:3, :3] = R c2w[:3, 3] = t w2c = np.linalg.inv(c2w) ordered_ba_poses.append(w2c[:3, :]) # Extract 3x4 w2c matrix found = True except Exception as e: logger.warning(f"Failed to extract pose for {img_name}: {e}") break if not found: logger.warning( f"Image {img_name} not found in BA reconstruction. Using initial pose." ) # Use initial pose if BA didn't reconstruct this image idx = image_name_to_idx[img_name] ordered_ba_poses.append(initial_poses[idx]) if not ordered_ba_poses: return { "success": False, "error_message": "No poses extracted from reconstruction", "poses": None, "reprojection_error": None, } # Ensure all poses are 3x4 ordered_ba_poses_3x4 = [] for pose in ordered_ba_poses: if pose.shape == (3, 4): ordered_ba_poses_3x4.append(pose) elif pose.shape == (4, 4): ordered_ba_poses_3x4.append(pose[:3, :]) else: logger.warning(f"Unexpected pose shape: {pose.shape}, skipping") # Use identity as fallback pose_3x4 = np.eye(3, 4) ordered_ba_poses_3x4.append(pose_3x4) return { "success": True, "poses": np.array(ordered_ba_poses_3x4), "reprojection_error": ( np.mean(reprojection_errors) if reprojection_errors else None ), } except Exception as e: logger.error(f"COLMAP reconstruction failed: {e}") import traceback logger.debug(traceback.format_exc()) return { "success": False, "error_message": str(e), "poses": None, "reprojection_error": None, } def _pose_3x4_to_4x4(self, pose: np.ndarray) -> np.ndarray: """Convert 3x4 pose to 4x4 homogeneous matrix.""" if pose.shape == (4, 4): return pose pose_4x4 = np.eye(4, dtype=pose.dtype) pose_4x4[:3, :] = pose return pose_4x4 def _run_simplified_ba( self, image_paths: List[str], initial_poses: np.ndarray, intrinsics: Optional[np.ndarray] = None, ) -> Dict: """Simplified BA that just returns initial poses (for testing).""" logger.warning( "Using simplified BA (no actual optimization). Full BA requires triangulation." ) return { "success": True, "poses": initial_poses, "reprojection_error": 0.0, } def _compute_pose_error( self, poses1: np.ndarray, poses2: np.ndarray, ) -> Dict: """ Compute pose error between two sets of poses. Returns: Dictionary with error metrics """ # Align poses (Procrustes alignment) poses1_aligned = self._align_trajectories(poses1, poses2) rotation_errors = [] translation_errors = [] for i in range(len(poses1)): R1 = poses1_aligned[i][:3, :3] R2 = poses2[i][:3, :3] t1 = poses1_aligned[i][:3, 3] t2 = poses2[i][:3, 3] # Rotation error: geodesic distance R_diff = R1 @ R2.T trace = np.trace(R_diff) angle_rad = np.arccos(np.clip((trace - 1) / 2, -1, 1)) angle_deg = np.degrees(angle_rad) rotation_errors.append(angle_deg) # Translation error trans_error = np.linalg.norm(t1 - t2) translation_errors.append(trans_error) # Compute scene scale for relative translation error scene_scale = np.percentile(translation_errors, 75) if translation_errors else 1.0 return { "rotation_errors_deg": rotation_errors, "translation_errors": translation_errors, "max_rotation_error_deg": np.max(rotation_errors), "mean_rotation_error_deg": np.mean(rotation_errors), "max_translation_error": np.max(translation_errors), "mean_translation_error": np.mean(translation_errors), "scene_scale": scene_scale, } def _align_trajectories( self, poses1: np.ndarray, poses2: np.ndarray, ) -> np.ndarray: """ Align trajectory 1 to trajectory 2 using Procrustes alignment. """ # Extract centers centers1 = poses1[:, :3, 3] centers2 = poses2[:, :3, 3] # Center both trajectories center1_mean = centers1.mean(axis=0) center2_mean = centers2.mean(axis=0) centers1_centered = centers1 - center1_mean centers2_centered = centers2 - center2_mean # Compute scale scale1 = np.linalg.norm(centers1_centered, axis=1).mean() scale2 = np.linalg.norm(centers2_centered, axis=1).mean() scale = scale2 / (scale1 + 1e-8) # Compute rotation (SVD) H = centers1_centered.T @ centers2_centered U, _, Vt = np.linalg.svd(H) R_align = Vt.T @ U.T # Apply alignment poses1_aligned = poses1.copy() for i in range(len(poses1)): # Align rotation R_orig = poses1[i][:3, :3] R_aligned = R_align @ R_orig poses1_aligned[i][:3, :3] = R_aligned # Align translation t_orig = poses1[i][:3, 3] t_aligned = scale * (R_align @ (t_orig - center1_mean)) + center2_mean poses1_aligned[i][:3, 3] = t_aligned return poses1_aligned