# 3d_model/ylff/services/ba_validator.py
# Author: Azan
# Commit 6352295 — Fix: Use git URL for hloc installation
"""
BA Validator: Uses Bundle Adjustment as an oracle teacher to validate model predictions.
"""
import logging
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import h5py
import numpy as np
try:
    # Prefer the project's real profiler when the package layout provides it.
    from ..utils.profiler import profile, profile_context
    HAS_PROFILER = True
except ImportError:
    HAS_PROFILER = False

    def profile(*_args, **_kwargs):
        """No-op stand-in for the profiler decorator factory: returns func unchanged."""
        def decorator(func):
            return func
        return decorator

    class profile_context:
        """No-op stand-in for the profiler context manager."""

        def __init__(self, *_args, **_kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *_exc):
            # Never suppress exceptions.
            return False
try:
    import pycolmap
    from hloc import extract_features, match_features
    HAS_BA_DEPS = True
except ImportError:
    logging.warning("hloc or pycolmap not installed. BA validation will not work.")
    # Define all optional names (not just pycolmap) so later references fail
    # gracefully instead of raising NameError.
    pycolmap = None
    extract_features = None
    match_features = None
    HAS_BA_DEPS = False

logger = logging.getLogger(__name__)
class BAValidator:
"""
Validates model predictions using Bundle Adjustment.
Uses BA as an oracle teacher to identify model failures and generate
pseudo-labels for fine-tuning.
"""
def __init__(
self,
accept_threshold: float = 2.0, # degrees
reject_threshold: float = 30.0, # degrees
feature_conf: str = "superpoint_max",
matcher_conf: str = "superpoint+lightglue",
work_dir: Optional[Path] = None,
match_num_workers: int = 5, # Number of workers for parallel pair loading
):
"""
Args:
accept_threshold: Maximum rotation error (degrees) to accept model prediction
reject_threshold: Maximum rotation error (degrees) before considering outlier
feature_conf: Feature extraction config (superpoint_max, etc.)
matcher_conf: Matcher config (lightglue, superglue, etc.)
work_dir: Working directory for temporary files
match_num_workers: Number of workers for parallel pair loading (default: 5)
"""
self.accept_threshold = accept_threshold
self.reject_threshold = reject_threshold
self.feature_conf = feature_conf
self.matcher_conf = matcher_conf
self.work_dir = work_dir or Path("/tmp/ylff_ba")
self.work_dir.mkdir(parents=True, exist_ok=True)
self.match_num_workers = match_num_workers
# Feature cache directory
self.feature_cache_dir = self.work_dir / "feature_cache"
self.feature_cache_dir.mkdir(exist_ok=True)
if not HAS_BA_DEPS:
raise ImportError(
"pycolmap and hloc are required for BA validation. "
"Install with: pip install pycolmap git+https://github.com/cvg/Hierarchical-Localization.git"
)
def validate(
self,
images: List[np.ndarray],
poses_model: np.ndarray,
intrinsics: Optional[np.ndarray] = None,
) -> Dict:
"""
Validate model poses using Bundle Adjustment.
Args:
images: List of RGB images (H, W, 3) uint8
poses_model: Model-predicted poses (N, 3, 4) or (N, 4, 4)
intrinsics: Camera intrinsics (N, 3, 3), optional
Returns:
Dictionary with validation results:
- status: 'accepted', 'rejected_learnable', 'rejected_outlier', or 'ba_failed'
- error: Maximum rotation error in degrees
- poses_ba: BA-refined poses (if successful)
- reprojection_error: Average reprojection error
"""
N = len(images)
# Convert poses to 4x4 if needed
if poses_model.shape[1] == 3:
poses_4x4 = np.eye(4, dtype=poses_model.dtype)[None, :, :].repeat(N, axis=0)
poses_4x4[:, :3, :] = poses_model
poses_model = poses_4x4
# Save images temporarily
image_dir = self.work_dir / "images"
image_dir.mkdir(exist_ok=True)
image_paths = []
for i, img in enumerate(images):
path = image_dir / f"frame_{i:06d}.jpg"
import cv2
cv2.imwrite(str(path), cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
image_paths.append(str(path))
try:
# 1. Extract features
features = self._extract_features(image_paths)
# 2. Match features (with smart pairing if poses available)
matches = self._match_features(
image_paths,
features,
poses=poses_model,
smart_pairing=True, # Enable smart pairing by default
)
# 3. Run COLMAP BA, initialized from model poses
ba_result = self._run_colmap_ba(
image_paths=image_paths,
features=features,
matches=matches,
initial_poses=poses_model,
intrinsics=intrinsics,
)
if not ba_result["success"]:
return {
"status": "ba_failed",
"error": None,
"poses_ba": None,
"reprojection_error": None,
}
poses_ba = ba_result["poses"]
reproj_error = ba_result["reprojection_error"]
# 4. Compare poses
error_metrics = self._compute_pose_error(poses_model, poses_ba)
max_rot_error = error_metrics["max_rotation_error_deg"]
# 5. Categorize
if max_rot_error < self.accept_threshold:
return {
"status": "accepted",
"error": max_rot_error,
"poses_ba": poses_ba,
"reprojection_error": reproj_error,
"error_metrics": error_metrics,
}
elif max_rot_error < self.reject_threshold:
return {
"status": "rejected_learnable",
"error": max_rot_error,
"poses_ba": poses_ba, # Pseudo-label!
"reprojection_error": reproj_error,
"error_metrics": error_metrics,
}
else:
return {
"status": "rejected_outlier",
"error": max_rot_error,
"poses_ba": poses_ba,
"reprojection_error": reproj_error,
"error_metrics": error_metrics,
}
except Exception as e:
logger.error(f"BA validation failed: {e}")
return {
"status": "ba_failed",
"error": str(e),
"poses_ba": None,
"reprojection_error": None,
}
def _get_image_hash(self, image_path: str) -> str:
"""Generate hash from image file for caching."""
import hashlib
with open(image_path, "rb") as f:
img_hash = hashlib.md5(f.read()).hexdigest()
return img_hash
def _get_cache_key(self, image_path: str) -> str:
"""Generate cache key from image path and feature config."""
img_hash = self._get_image_hash(image_path)
return f"{self.feature_conf}_{img_hash}"
@profile(stage="gpu", operation="feature_extraction")
def _extract_features(self, image_paths: List[str], use_cache: bool = True) -> Path:
"""
Extract features using hloc with optional caching.
Args:
image_paths: List of image file paths
use_cache: If True, use cached features when available
Returns:
Path to features HDF5 file
"""
feature_path = self.work_dir / "features.h5"
if use_cache:
# Check cache for existing features
cached_features = {}
uncached_paths = []
logger.info(f"Checking feature cache for {len(image_paths)} images...")
# Load cached features
cache_hits = 0
for img_path in image_paths:
cache_key = self._get_cache_key(img_path)
cache_file = self.feature_cache_dir / f"{cache_key}.h5"
if cache_file.exists():
try:
# Copy cached features to main feature file
with h5py.File(cache_file, "r") as cache_f:
with h5py.File(feature_path, "a") as main_f:
img_name = Path(img_path).name
if img_name not in main_f:
# Copy the cached group
cache_f.copy(img_name, main_f)
cached_features[img_path] = cache_key
cache_hits += 1
except Exception as e:
logger.warning(f"Failed to load cached features for {img_path}: {e}")
uncached_paths.append(img_path)
else:
uncached_paths.append(img_path)
if cache_hits > 0:
logger.info(f" ✓ Cache hits: {cache_hits}/{len(image_paths)} images")
if len(uncached_paths) == 0:
logger.info(f"✓ All features loaded from cache: {feature_path}")
return feature_path
logger.info(f" Extracting features for {len(uncached_paths)} uncached images...")
# Extract features for uncached images
# Create temporary directory with only uncached images
temp_image_dir = self.work_dir / "temp_images"
temp_image_dir.mkdir(exist_ok=True)
for img_path in uncached_paths:
img_name = Path(img_path).name
temp_path = temp_image_dir / img_name
shutil.copy2(img_path, temp_path)
# Extract features for uncached images
temp_feature_path = self.work_dir / "temp_features.h5"
extract_features.main(
conf=extract_features.confs[self.feature_conf],
image_dir=temp_image_dir,
feature_path=temp_feature_path,
)
# Merge temp features into main feature file and cache
with h5py.File(temp_feature_path, "r") as temp_f:
with h5py.File(feature_path, "a") as main_f:
for img_path in uncached_paths:
img_name = Path(img_path).name
if img_name in temp_f:
# Copy to main file
if img_name in main_f:
del main_f[img_name]
temp_f.copy(img_name, main_f)
# Save to cache
cache_key = self._get_cache_key(img_path)
cache_file = self.feature_cache_dir / f"{cache_key}.h5"
with h5py.File(cache_file, "w") as cache_f:
temp_f.copy(img_name, cache_f)
# Cleanup temp files
temp_feature_path.unlink(missing_ok=True)
shutil.rmtree(temp_image_dir, ignore_errors=True)
logger.info(f"✓ Features extracted and cached: {feature_path}")
logger.info(f" - Cached: {cache_hits}, Extracted: {len(uncached_paths)}")
else:
# No caching - extract all features
logger.info(f"Extracting features from {len(image_paths)} images...")
logger.info(f" Using feature extractor: {self.feature_conf}")
extract_features.main(
conf=extract_features.confs[self.feature_conf],
image_dir=Path(image_paths[0]).parent,
feature_path=feature_path,
)
logger.info(f"✓ Features extracted: {feature_path}")
return feature_path
def _generate_smart_pairs(
self,
image_paths: List[str],
poses: Optional[np.ndarray] = None,
max_baseline: Optional[float] = None,
min_baseline: float = 0.05,
sequential_only: bool = False,
max_pairs_per_image: int = 10,
) -> List[Tuple[str, str]]:
"""
Generate smart pairs based on spatial proximity or sequential ordering.
Args:
image_paths: List of image paths
poses: Optional poses (N, 3, 4) to compute baselines
max_baseline: Maximum translation distance (if None, use sequential)
min_baseline: Minimum translation distance
sequential_only: If True, only match consecutive frames
max_pairs_per_image: Maximum number of pairs per image
Returns:
List of (image1, image2) pairs
"""
pairs = []
if sequential_only:
# Only match consecutive frames (N-1 pairs)
for i in range(len(image_paths) - 1):
pairs.append((Path(image_paths[i]).name, Path(image_paths[i + 1]).name))
logger.info(f"Generated {len(pairs)} sequential pairs")
return pairs
if poses is not None and max_baseline is not None:
# Spatial selection based on poses
for i in range(len(image_paths)):
image_pairs = []
t_i = poses[i][:3, 3]
for j in range(i + 1, len(image_paths)):
t_j = poses[j][:3, 3]
baseline = np.linalg.norm(t_i - t_j)
if min_baseline <= baseline <= max_baseline:
image_pairs.append((baseline, j))
# Sort by baseline and take closest max_pairs_per_image
image_pairs.sort(key=lambda x: x[0])
for _, j in image_pairs[:max_pairs_per_image]:
pairs.append((Path(image_paths[i]).name, Path(image_paths[j]).name))
logger.info(
f"Generated {len(pairs)} spatial pairs "
f"(baseline: {min_baseline:.2f}-{max_baseline:.2f})"
)
return pairs
# Fallback: exhaustive matching (original behavior)
for i in range(len(image_paths)):
for j in range(i + 1, len(image_paths)):
pairs.append((Path(image_paths[i]).name, Path(image_paths[j]).name))
logger.info(f"Generated {len(pairs)} exhaustive pairs")
return pairs
@profile(stage="gpu", operation="feature_matching")
def _match_features(
self,
image_paths: List[str],
features: Path,
poses: Optional[np.ndarray] = None,
smart_pairing: bool = True,
) -> Path:
"""
Match features using hloc.
Args:
image_paths: List of image paths
features: Path to features file
poses: Optional poses for smart pairing
smart_pairing: If True, use smart pair selection
"""
pairs_path = self.work_dir / "pairs.txt"
matches_path = self.work_dir / "matches.h5"
# Generate pairs
if smart_pairing and poses is not None:
# Use smart pairing with spatial selection
pairs = self._generate_smart_pairs(
image_paths,
poses=poses,
max_baseline=0.5, # Reasonable baseline for video
min_baseline=0.05,
max_pairs_per_image=10,
)
elif smart_pairing:
# Use sequential pairing (no poses needed)
pairs = self._generate_smart_pairs(
image_paths,
sequential_only=True,
)
else:
# Exhaustive matching
pairs = self._generate_smart_pairs(image_paths)
num_pairs = len(pairs)
logger.info(f"Generating {num_pairs} image pairs for matching...")
# Write pairs file
with open(pairs_path, "w") as f:
for img1, img2 in pairs:
f.write(f"{img1} {img2}\n")
logger.info(f"✓ Pairs file created: {pairs_path}")
logger.info(f"Matching features using {self.matcher_conf}...")
try:
match_conf = match_features.confs[self.matcher_conf]
except KeyError:
available = list(match_features.confs.keys())
logger.error(
f"Matcher config '{self.matcher_conf}' not found. " f"Available: {available}"
)
raise
match_features.main(
conf=match_conf,
pairs=pairs_path,
features=features,
matches=matches_path,
)
logger.info(f"✓ Features matched: {matches_path}")
return matches_path
@profile(stage="cpu", operation="colmap_ba")
def _run_colmap_ba(
self,
image_paths: List[str],
features: Path,
matches: Path,
initial_poses: np.ndarray,
intrinsics: Optional[np.ndarray] = None,
) -> Dict:
"""
Run COLMAP Bundle Adjustment using hloc's reconstruction pipeline.
Uses hloc.reconstruction.main to:
1. Create COLMAP database from features and matches
2. Run incremental SfM with bundle adjustment
3. Extract refined poses
Returns:
Dictionary with 'success', 'poses', 'reprojection_error'
"""
try:
from hloc import reconstruction
except ImportError:
logger.warning("hloc reconstruction module not available. Using simplified BA.")
return self._run_simplified_ba(image_paths, initial_poses, intrinsics)
sfm_dir = self.work_dir / "sfm"
sfm_dir.mkdir(exist_ok=True)
image_dir = Path(image_paths[0]).parent
# Create pairs file (all pairs for exhaustive matching)
pairs_path = self.work_dir / "pairs.txt"
if not pairs_path.exists():
with open(pairs_path, "w") as f:
for i in range(len(image_paths)):
for j in range(i + 1, len(image_paths)):
f.write(f"{Path(image_paths[i]).name} {Path(image_paths[j]).name}\n")
# Determine camera mode
if intrinsics is not None:
# Check if all intrinsics are the same
first_K = intrinsics[0]
all_same = all(np.allclose(K, first_K) for K in intrinsics)
camera_mode = (
pycolmap.CameraMode.SINGLE_CAMERA if all_same else pycolmap.CameraMode.PER_IMAGE
)
else:
camera_mode = pycolmap.CameraMode.SINGLE_CAMERA
logger.info(f"Running COLMAP reconstruction with camera_mode={camera_mode}...")
try:
# Run hloc's reconstruction pipeline
# This will create a database, import features/matches, and run incremental SfM with BA
ba_reconstruction = reconstruction.main(
sfm_dir=sfm_dir,
image_dir=image_dir,
pairs=pairs_path,
features=features,
matches=matches,
camera_mode=camera_mode,
verbose=False,
)
# reconstruction.main returns the Reconstruction object directly
# But it may also write to disk - check both
if ba_reconstruction is None:
# Try loading from disk
# hloc may write to sfm_dir directly or to a subdirectory
if (sfm_dir / "images.bin").exists():
ba_reconstruction = pycolmap.Reconstruction(str(sfm_dir))
elif (sfm_dir / "0" / "images.bin").exists():
ba_reconstruction = pycolmap.Reconstruction(str(sfm_dir / "0"))
else:
# Check for models subdirectory
models_dir = sfm_dir / "models"
if models_dir.exists():
model_dirs = [
d
for d in models_dir.iterdir()
if d.is_dir() and (d / "images.bin").exists()
]
if model_dirs:
ba_reconstruction = pycolmap.Reconstruction(str(model_dirs[0]))
if ba_reconstruction is None or len(ba_reconstruction.images) == 0:
logger.warning("COLMAP reconstruction failed or produced no images.")
return {
"success": False,
"error_message": "Reconstruction produced no images",
"poses": None,
"reprojection_error": None,
}
# Extract poses from reconstruction
ba_poses = []
reprojection_errors = []
# Map image names to indices
image_name_to_idx = {Path(p).name: i for i, p in enumerate(image_paths)}
for img_id in sorted(ba_reconstruction.images.keys()):
img = ba_reconstruction.images[img_id]
if not img.has_pose:
logger.warning(f"Image {img.name} has no pose")
continue
# COLMAP stores camera-to-world pose
# cam_from_world() returns a Rigid3d object
try:
pose = img.cam_from_world()
# Rigid3d has rotation and translation
R = pose.rotation.matrix() # 3x3 rotation matrix
t = pose.translation # 3x1 translation vector
# Construct 4x4 c2w matrix
c2w = np.eye(4)
c2w[:3, :3] = R
c2w[:3, 3] = t
w2c = np.linalg.inv(c2w)
ba_poses.append(w2c[:3, :]) # Extract 3x4 w2c matrix
# Get reprojection error for this image
# mean_reprojection_error is a method
try:
reproj_error = (
img.mean_reprojection_error()
if callable(img.mean_reprojection_error)
else 0.0
)
except Exception:
reproj_error = 0.0
reprojection_errors.append(reproj_error)
except Exception as e:
logger.warning(f"Failed to extract pose for image {img.name}: {e}")
continue
# Align BA poses to match the order of input images
# COLMAP may not reconstruct all images, so we need to match by name
ordered_ba_poses = []
for img_path in image_paths:
img_name = Path(img_path).name
found = False
for img_id in sorted(ba_reconstruction.images.keys()):
img = ba_reconstruction.images[img_id]
if img.name == img_name:
if not img.has_pose:
logger.warning(f"Image {img_name} has no pose in BA reconstruction")
break
try:
# Extract pose same way as above
pose = img.cam_from_world()
R = pose.rotation.matrix()
t = pose.translation
c2w = np.eye(4)
c2w[:3, :3] = R
c2w[:3, 3] = t
w2c = np.linalg.inv(c2w)
ordered_ba_poses.append(w2c[:3, :]) # Extract 3x4 w2c matrix
found = True
except Exception as e:
logger.warning(f"Failed to extract pose for {img_name}: {e}")
break
if not found:
logger.warning(
f"Image {img_name} not found in BA reconstruction. Using initial pose."
)
# Use initial pose if BA didn't reconstruct this image
idx = image_name_to_idx[img_name]
ordered_ba_poses.append(initial_poses[idx])
if not ordered_ba_poses:
return {
"success": False,
"error_message": "No poses extracted from reconstruction",
"poses": None,
"reprojection_error": None,
}
# Ensure all poses are 3x4
ordered_ba_poses_3x4 = []
for pose in ordered_ba_poses:
if pose.shape == (3, 4):
ordered_ba_poses_3x4.append(pose)
elif pose.shape == (4, 4):
ordered_ba_poses_3x4.append(pose[:3, :])
else:
logger.warning(f"Unexpected pose shape: {pose.shape}, skipping")
# Use identity as fallback
pose_3x4 = np.eye(3, 4)
ordered_ba_poses_3x4.append(pose_3x4)
return {
"success": True,
"poses": np.array(ordered_ba_poses_3x4),
"reprojection_error": (
np.mean(reprojection_errors) if reprojection_errors else None
),
}
except Exception as e:
logger.error(f"COLMAP reconstruction failed: {e}")
import traceback
logger.debug(traceback.format_exc())
return {
"success": False,
"error_message": str(e),
"poses": None,
"reprojection_error": None,
}
def _pose_3x4_to_4x4(self, pose: np.ndarray) -> np.ndarray:
"""Convert 3x4 pose to 4x4 homogeneous matrix."""
if pose.shape == (4, 4):
return pose
pose_4x4 = np.eye(4, dtype=pose.dtype)
pose_4x4[:3, :] = pose
return pose_4x4
def _run_simplified_ba(
self,
image_paths: List[str],
initial_poses: np.ndarray,
intrinsics: Optional[np.ndarray] = None,
) -> Dict:
"""Simplified BA that just returns initial poses (for testing)."""
logger.warning(
"Using simplified BA (no actual optimization). Full BA requires triangulation."
)
return {
"success": True,
"poses": initial_poses,
"reprojection_error": 0.0,
}
def _compute_pose_error(
self,
poses1: np.ndarray,
poses2: np.ndarray,
) -> Dict:
"""
Compute pose error between two sets of poses.
Returns:
Dictionary with error metrics
"""
# Align poses (Procrustes alignment)
poses1_aligned = self._align_trajectories(poses1, poses2)
rotation_errors = []
translation_errors = []
for i in range(len(poses1)):
R1 = poses1_aligned[i][:3, :3]
R2 = poses2[i][:3, :3]
t1 = poses1_aligned[i][:3, 3]
t2 = poses2[i][:3, 3]
# Rotation error: geodesic distance
R_diff = R1 @ R2.T
trace = np.trace(R_diff)
angle_rad = np.arccos(np.clip((trace - 1) / 2, -1, 1))
angle_deg = np.degrees(angle_rad)
rotation_errors.append(angle_deg)
# Translation error
trans_error = np.linalg.norm(t1 - t2)
translation_errors.append(trans_error)
# Compute scene scale for relative translation error
scene_scale = np.percentile(translation_errors, 75) if translation_errors else 1.0
return {
"rotation_errors_deg": rotation_errors,
"translation_errors": translation_errors,
"max_rotation_error_deg": np.max(rotation_errors),
"mean_rotation_error_deg": np.mean(rotation_errors),
"max_translation_error": np.max(translation_errors),
"mean_translation_error": np.mean(translation_errors),
"scene_scale": scene_scale,
}
def _align_trajectories(
self,
poses1: np.ndarray,
poses2: np.ndarray,
) -> np.ndarray:
"""
Align trajectory 1 to trajectory 2 using Procrustes alignment.
"""
# Extract centers
centers1 = poses1[:, :3, 3]
centers2 = poses2[:, :3, 3]
# Center both trajectories
center1_mean = centers1.mean(axis=0)
center2_mean = centers2.mean(axis=0)
centers1_centered = centers1 - center1_mean
centers2_centered = centers2 - center2_mean
# Compute scale
scale1 = np.linalg.norm(centers1_centered, axis=1).mean()
scale2 = np.linalg.norm(centers2_centered, axis=1).mean()
scale = scale2 / (scale1 + 1e-8)
# Compute rotation (SVD)
H = centers1_centered.T @ centers2_centered
U, _, Vt = np.linalg.svd(H)
R_align = Vt.T @ U.T
# Apply alignment
poses1_aligned = poses1.copy()
for i in range(len(poses1)):
# Align rotation
R_orig = poses1[i][:3, :3]
R_aligned = R_align @ R_orig
poses1_aligned[i][:3, :3] = R_aligned
# Align translation
t_orig = poses1[i][:3, 3]
t_aligned = scale * (R_align @ (t_orig - center1_mean)) + center2_mean
poses1_aligned[i][:3, 3] = t_aligned
return poses1_aligned