#!/usr/bin/env python3
"""
Run BA validation on ARKit data.

Compares DA3 poses vs ARKit poses (ground truth) and vs COLMAP BA.
Comparisons are done on camera centers and camera-to-world rotations after a
similarity (scale + rotation + translation) alignment, so the world-to-camera
extrinsics produced by DA3/BA can be compared directly with ARKit trajectories.
"""

import json
import logging
import sys
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch

# Add project root to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from ylff.services.arkit_processor import ARKitProcessor  # noqa: E402
from ylff.services.ba_validator import BAValidator  # noqa: E402
from ylff.utils.model_loader import load_da3_model  # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Rotation-error thresholds (degrees). Shared by the frame categorization and the
# BAValidator so the two can never drift apart.
ACCEPT_THRESHOLD_DEG = 2.0
REJECT_THRESHOLD_DEG = 30.0

# Human-readable labels for the per-frame log table.
_CATEGORY_LABELS = {
    "accepted": "Accepted",
    "rejected_learnable": "Rejected-Learnable",
    "rejected_outlier": "Rejected-Outlier",
}


def _to_c2w_matrices(poses, convention: str) -> np.ndarray:
    """Convert a stack of poses to (N, 4, 4) float64 camera-to-world matrices.

    Args:
        poses: Array-like of shape (N, 3, 4) or (N, 4, 4).
        convention: "c2w" if ``poses`` already map camera to world, "w2c" if they
            are world-to-camera extrinsics (they are inverted here).

    Raises:
        ValueError: On an unexpected shape or an unknown convention.
    """
    if convention not in ("c2w", "w2c"):
        raise ValueError(f"Unknown pose convention {convention!r}; expected 'c2w' or 'w2c'")
    arr = np.asarray(poses, dtype=np.float64)
    if arr.ndim != 3 or arr.shape[1] not in (3, 4) or arr.shape[2] != 4:
        raise ValueError(f"Expected poses of shape (N, 3, 4) or (N, 4, 4), got {arr.shape}")
    mats = np.tile(np.eye(4), (arr.shape[0], 1, 1))
    mats[:, :3, :] = arr[:, :3, :]
    return np.linalg.inv(mats) if convention == "w2c" else mats


def compute_pose_error(
    poses1: np.ndarray,
    poses2: np.ndarray,
    verbose: bool = False,
    convention: str = "c2w",
) -> Dict:
    """Compute per-frame pose errors after aligning ``poses1`` onto ``poses2``.

    The trajectories are aligned with a similarity transform estimated from the
    camera centers. Scale is the ratio of mean distances to the centroid. Rotation
    comes from Kabsch/SVD with reflection correction. Translation comes from the
    centroids. The aligned ``poses1`` are then compared to ``poses2`` frame by frame.

    Args:
        poses1: (N, 3, 4) or (N, 4, 4) poses to align (e.g. DA3 or BA).
        poses2: (N, 3, 4) or (N, 4, 4) reference poses (e.g. ARKit).
        verbose: Log alignment details.
        convention: "c2w" or "w2c", applied to both pose sets. World-to-camera
            extrinsics are inverted first so the alignment acts on true camera
            centers and camera-to-world rotations.

    Returns:
        Dict with per-frame ``rotation_errors_deg`` (geodesic angle) and
        ``translation_errors`` (camera-center distance, in ``poses2`` units), their
        summary statistics, and ``alignment_info``.

    Raises:
        ValueError: If the pose sets are empty, differ in length, have an
            unsupported shape, or the convention is unknown.
    """
    c2w1 = _to_c2w_matrices(poses1, convention)
    c2w2 = _to_c2w_matrices(poses2, convention)
    if len(c2w1) != len(c2w2):
        raise ValueError(f"Pose count mismatch: {len(c2w1)} vs {len(c2w2)}")
    if len(c2w1) == 0:
        raise ValueError("Cannot compute pose error on empty pose sets")
    if len(c2w1) < 3:
        logger.warning("Fewer than 3 poses: trajectory alignment is under-determined")

    centers1 = c2w1[:, :3, 3]
    centers2 = c2w2[:, :3, 3]

    # Center both trajectories
    center1_mean = centers1.mean(axis=0)
    center2_mean = centers2.mean(axis=0)
    centers1_centered = centers1 - center1_mean
    centers2_centered = centers2 - center2_mean

    # Scale: ratio of mean distances to the centroid (epsilon guards a static trajectory)
    scale1 = np.linalg.norm(centers1_centered, axis=1).mean()
    scale2 = np.linalg.norm(centers2_centered, axis=1).mean()
    scale = scale2 / (scale1 + 1e-8)

    if verbose:
        logger.info(f"  Alignment scale factor: {scale:.6f}")
        logger.info(f"  Scale1 (poses1): {scale1:.6f}")
        logger.info(f"  Scale2 (poses2): {scale2:.6f}")

    # Rotation (Kabsch): R_align maps centered poses1 centers onto poses2 centers.
    H = centers1_centered.T @ centers2_centered
    U, _, Vt = np.linalg.svd(H)
    # If the plain SVD solution is a reflection (det = -1), flip the weakest
    # singular direction so R_align is a proper rotation.
    reflection = np.linalg.det(Vt.T @ U.T) < 0
    D = np.diag([1.0, 1.0, -1.0 if reflection else 1.0])
    R_align = Vt.T @ D @ U.T

    if verbose:
        det = np.linalg.det(R_align)
        logger.info(f"  Alignment rotation det: {det:.6f} (should be ~1.0)")
        logger.info(f"  Alignment rotation trace: {np.trace(R_align):.3f}")
        if reflection:
            logger.info("  Corrected a reflection in the SVD alignment")

    # Apply the similarity transform to poses1 (camera-to-world: left-multiply rotations).
    rotations1_aligned = R_align @ c2w1[:, :3, :3]
    centers1_aligned = scale * (centers1_centered @ R_align.T) + center2_mean

    # Geodesic rotation error. trace(R1 R2^T) == trace(R1^T R2), so the angle is
    # the same whichever convention the rotations were expressed in.
    relative = rotations1_aligned @ np.transpose(c2w2[:, :3, :3], (0, 2, 1))
    cos_angle = np.clip((np.trace(relative, axis1=1, axis2=2) - 1.0) / 2.0, -1.0, 1.0)
    rotation_errors = np.degrees(np.arccos(cos_angle))

    # Translation error: distance between aligned and reference camera centers.
    translation_errors = np.linalg.norm(centers1_aligned - centers2, axis=1)

    result = {
        "rotation_errors_deg": rotation_errors.tolist(),
        "translation_errors": translation_errors.tolist(),
        "mean_rotation_error_deg": float(rotation_errors.mean()),
        "max_rotation_error_deg": float(rotation_errors.max()),
        "mean_translation_error": float(translation_errors.mean()),
        "alignment_info": {
            "scale_factor": float(scale),
            "center1_mean": center1_mean.tolist(),
            "center2_mean": center2_mean.tolist(),
            "rotation_det": float(np.linalg.det(R_align)),
            "reflection_corrected": bool(reflection),
            "convention": convention,
        },
    }

    if verbose:
        logger.info("  Alignment info saved to results")

    return result


def _categorize_rotation_error(err: float) -> str:
    """Map a rotation error (degrees) to accepted / rejected_learnable / rejected_outlier."""
    if err < ACCEPT_THRESHOLD_DEG:
        return "accepted"
    if err < REJECT_THRESHOLD_DEG:
        return "rejected_learnable"
    return "rejected_outlier"


def _parse_args():
    """Parse command-line arguments."""
    import argparse

    parser = argparse.ArgumentParser(description="Run BA validation on ARKit data")
    parser.add_argument(
        "--arkit-dir",
        type=Path,
        default=project_root / "assets" / "examples" / "ARKit",
        help="Directory containing ARKit video and metadata",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=project_root / "data" / "arkit_ba_validation",
        help="Output directory for results",
    )
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Maximum number of frames to process"
    )
    parser.add_argument("--frame-interval", type=int, default=1, help="Extract every Nth frame")
    parser.add_argument("--device", type=str, default="cpu", help="Device for DA3 inference")
    return parser.parse_args()


def _find_arkit_files(arkit_dir: Path) -> Tuple[Optional[Path], Optional[Path]]:
    """Return the first ARKit video (videos/*.MOV) and metadata (json-metadata/*.json).

    Matches are sorted so repeated runs pick the same pair. Either entry is None
    when nothing matches.
    """
    video_path = next(iter(sorted((arkit_dir / "videos").glob("*.MOV"))), None)
    metadata_path = next(iter(sorted((arkit_dir / "json-metadata").glob("*.json"))), None)
    return video_path, metadata_path


def _load_rgb_images(image_paths: Sequence) -> Tuple[List[np.ndarray], List[int]]:
    """Read images with OpenCV and convert BGR -> RGB.

    Returns:
        The loaded images and the indices (into ``image_paths``) that loaded, so
        callers can drop the poses of unreadable frames and keep frames paired.
    """
    import cv2

    images: List[np.ndarray] = []
    loaded_indices: List[int] = []
    for idx, img_path in enumerate(image_paths):
        img = cv2.imread(str(img_path))
        if img is None:
            logger.warning(f"Failed to read image, skipping: {img_path}")
            continue
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        loaded_indices.append(idx)
    return images, loaded_indices


def _center_stats(centers: np.ndarray) -> Dict:
    """Summarize camera centers: per-axis min/max and mean/std distance from the origin."""
    norms = np.linalg.norm(centers, axis=1)
    return {
        "min": centers.min(axis=0).tolist(),
        "max": centers.max(axis=0).tolist(),
        "mean_magnitude": float(norms.mean()),
        "std_magnitude": float(norms.std()),
    }


def _log_da3_vs_arkit_report(errors: Dict) -> None:
    """Log summary statistics, alignment diagnostics, per-frame table and percentiles."""
    rot_errors = errors["rotation_errors_deg"]
    trans_errors = errors["translation_errors"]

    logger.info("\nDA3 vs ARKit Error Summary:")
    logger.info(f"  Mean rotation error: {errors['mean_rotation_error_deg']:.2f}°")
    logger.info(f"  Median rotation error: {np.median(rot_errors):.2f}°")
    logger.info(f"  Max rotation error: {errors['max_rotation_error_deg']:.2f}°")
    logger.info(f"  Min rotation error: {np.min(rot_errors):.2f}°")
    logger.info(f"  Std rotation error: {np.std(rot_errors):.2f}°")
    logger.info(f"  Mean translation error: {errors['mean_translation_error']:.3f} m")
    logger.info(f"  Max translation error: {np.max(trans_errors):.3f} m")

    align_info = errors.get("alignment_info")
    if align_info:
        logger.info("\nAlignment Diagnostics:")
        logger.info(
            f"  Scale factor: {align_info['scale_factor']:.6f} (should be ~1.0 if scales match)"
        )
        logger.info(f"  Rotation matrix det: {align_info['rotation_det']:.6f} (should be ~1.0)")
        logger.info(f"  Center1 (DA3) mean: {align_info['center1_mean']}")
        logger.info(f"  Center2 (ARKit) mean: {align_info['center2_mean']}")

    logger.info("\nPer-Frame Error Breakdown:")
    logger.info(f"  {'Frame':<8} {'Rot Error (°)':<15} {'Trans Error (m)':<15} {'Category':<20}")
    logger.info(f"  {'-' * 8} {'-' * 15} {'-' * 15} {'-' * 20}")
    for i, (rot_err, trans_err) in enumerate(zip(rot_errors, trans_errors)):
        category = _CATEGORY_LABELS[_categorize_rotation_error(rot_err)]
        logger.info(f"  {i:<8} {rot_err:<15.2f} {trans_err:<15.3f} {category:<20}")

    logger.info("\nError Distribution (Rotation):")
    for label, q in (
        ("Q1 (25th percentile)", 25),
        ("Q2 (50th percentile/median)", 50),
        ("Q3 (75th percentile)", 75),
        ("90th percentile", 90),
        ("95th percentile", 95),
        ("99th percentile", 99),
    ):
        logger.info(f"  {label}: {np.percentile(rot_errors, q):.2f}°")


def _build_frame_categorization(rot_errors: Sequence[float]) -> Dict:
    """Bucket frame indices by rotation error into the three validation categories."""
    buckets: Dict[str, List[int]] = {name: [] for name in _CATEGORY_LABELS}
    for i, err in enumerate(rot_errors):
        buckets[_categorize_rotation_error(err)].append(i)

    total = len(rot_errors)
    categorization: Dict = {
        name: {
            "count": len(indices),
            "percentage": 100.0 * len(indices) / total if total else 0.0,
            "frame_indices": indices,
        }
        for name, indices in buckets.items()
    }
    categorization["total_frames"] = total
    return categorization


def _log_frame_categorization(categorization: Dict) -> None:
    """Log the count and percentage of frames in each category."""
    total = categorization["total_frames"]
    logger.info("\n=== Frame Categorization (DA3 vs ARKit) ===")
    rows = (
        ("accepted", f"Accepted (< {ACCEPT_THRESHOLD_DEG}°)"),
        (
            "rejected_learnable",
            f"Rejected-Learnable ({ACCEPT_THRESHOLD_DEG}-{REJECT_THRESHOLD_DEG}°)",
        ),
        ("rejected_outlier", f"Rejected-Outlier (>= {REJECT_THRESHOLD_DEG}°)"),
    )
    for key, label in rows:
        info = categorization[key]
        logger.info(f"  {label}: {info['count']}/{total} ({info['percentage']:.1f}%)")


def _build_diagnostics(
    da3_centers: np.ndarray, arkit_centers: np.ndarray, da3_vs_arkit: Dict
) -> Dict:
    """Assemble the diagnostics section of the results JSON."""
    rot_errors = da3_vs_arkit["rotation_errors_deg"]
    trans_errors = da3_vs_arkit["translation_errors"]
    return {
        # Key name "translation_range" is kept for output compatibility;
        # the values are camera centers (world-frame camera positions).
        "pose_statistics": {
            "da3": {"translation_range": _center_stats(da3_centers)},
            "arkit": {"translation_range": _center_stats(arkit_centers)},
        },
        "error_distribution": {
            "rotation_errors_deg": {
                "q1": float(np.percentile(rot_errors, 25)),
                "median": float(np.percentile(rot_errors, 50)),
                "q3": float(np.percentile(rot_errors, 75)),
                "p90": float(np.percentile(rot_errors, 90)),
                "p95": float(np.percentile(rot_errors, 95)),
                "p99": float(np.percentile(rot_errors, 99)),
            },
            "translation_errors": {
                "mean": float(np.mean(trans_errors)),
                "median": float(np.median(trans_errors)),
                "max": float(np.max(trans_errors)),
                "std": float(np.std(trans_errors)),
            },
        },
        "per_frame_errors": [
            {
                "frame_idx": i,
                "rotation_error_deg": float(rot_err),
                "translation_error_m": float(trans_err),
                "category": _categorize_rotation_error(rot_err),
            }
            for i, (rot_err, trans_err) in enumerate(zip(rot_errors, trans_errors))
        ],
        "da3_vs_arkit": {"alignment_info": da3_vs_arkit.get("alignment_info", {})},
    }


def main() -> int:
    """Run the DA3 vs ARKit vs BA comparison and write validation_results.json.

    Results for DA3 vs ARKit are always written. The BA comparisons are included
    when BA succeeds and are None otherwise.

    Returns:
        Process exit code: 0 on success, 1 if inputs are missing or unreadable.
    """
    args = _parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Find ARKit files
    video_path, metadata_path = _find_arkit_files(args.arkit_dir)
    if not video_path or not metadata_path:
        logger.error(f"ARKit files not found in {args.arkit_dir}")
        logger.error("Expected: videos/*.MOV and json-metadata/*.json")
        return 1

    logger.info(f"ARKit video: {video_path}")
    logger.info(f"ARKit metadata: {metadata_path}")

    # Process ARKit data
    logger.info("\n=== Processing ARKit Data ===")
    processor = ARKitProcessor(video_path, metadata_path)
    arkit_data = processor.process_for_ba_validation(
        output_dir=args.output_dir,
        max_frames=args.max_frames,
        frame_interval=args.frame_interval,
        use_good_tracking_only=True,
    )

    image_paths = arkit_data["image_paths"]
    arkit_poses_c2w = arkit_data["arkit_poses_c2w"]

    # Convert ARKit c2w poses to OpenCV convention for proper comparison
    from ylff.coordinate_utils import convert_arkit_to_opencv

    arkit_poses_c2w_opencv = np.array([convert_arkit_to_opencv(p) for p in arkit_poses_c2w])
    logger.info(f"Processed {len(image_paths)} frames")
    if len(arkit_poses_c2w_opencv) != len(image_paths):
        logger.error(
            f"ARKit pose count ({len(arkit_poses_c2w_opencv)}) does not match "
            f"frame count ({len(image_paths)})"
        )
        return 1

    # Load frames; drop the poses of unreadable frames so frames and poses stay paired.
    images, loaded_indices = _load_rgb_images(image_paths)
    if not images:
        logger.error("None of the extracted frames could be read")
        return 1
    arkit_poses_c2w_opencv = arkit_poses_c2w_opencv[loaded_indices]

    # Run DA3 inference
    logger.info("\n=== Running DA3 Inference ===")
    model = load_da3_model("depth-anything/DA3-LARGE", device=args.device)
    logger.info(f"Running DA3 on {len(images)} images...")
    with torch.no_grad():
        da3_output = model.inference(images)

    da3_poses = da3_output.extrinsics  # (N, 3, 4) w2c
    da3_intrinsics = getattr(da3_output, "intrinsics", None)
    logger.info(f"DA3 poses: {da3_poses.shape}")

    # Convert ARKit c2w to w2c in OpenCV convention, matching DA3's extrinsics
    arkit_poses_w2c_opencv = np.array([np.linalg.inv(p)[:3, :] for p in arkit_poses_c2w_opencv])

    logger.info("\n=== Comparing DA3 vs ARKit (Ground Truth) ===")

    # Camera centers (not raw w2c translations) are what the alignment operates on.
    da3_centers = _to_c2w_matrices(da3_poses, "w2c")[:, :3, 3]
    arkit_centers = _to_c2w_matrices(arkit_poses_c2w_opencv, "c2w")[:, :3, 3]

    logger.info("\nPose Statistics (before alignment):")
    for name, centers in (("DA3", da3_centers), ("ARKit", arkit_centers)):
        stats = _center_stats(centers)
        logger.info(
            f"  {name} camera center range: [{centers.min(axis=0)}, {centers.max(axis=0)}]"
        )
        logger.info(
            f"  {name} camera center magnitude: mean={stats['mean_magnitude']:.3f}, "
            f"std={stats['std_magnitude']:.3f}"
        )

    da3_vs_arkit = compute_pose_error(
        da3_poses, arkit_poses_w2c_opencv, verbose=True, convention="w2c"
    )
    _log_da3_vs_arkit_report(da3_vs_arkit)

    # Run BA validation
    logger.info("\n=== Running BA Validation ===")
    validator = BAValidator(
        accept_threshold=ACCEPT_THRESHOLD_DEG,
        reject_threshold=REJECT_THRESHOLD_DEG,
        work_dir=args.output_dir / "ba_work",
    )
    ba_result = validator.validate(
        images=images,
        poses_model=da3_poses,
        intrinsics=da3_intrinsics,
    )

    # DA3 poses are saved for visualization regardless of the BA outcome.
    np.save(args.output_dir / "da3_poses_w2c.npy", da3_poses)

    ba_vs_arkit: Optional[Dict] = None
    da3_vs_ba: Optional[Dict] = None
    ba_poses = ba_result.get("poses_ba")  # (N, 3, 4) w2c when present
    if ba_result["status"] != "ba_failed" and ba_poses is not None:
        np.save(args.output_dir / "ba_poses_w2c.npy", ba_poses)
        try:
            ba_vs_arkit = compute_pose_error(ba_poses, arkit_poses_w2c_opencv, convention="w2c")
            da3_vs_ba = compute_pose_error(da3_poses, ba_poses, convention="w2c")
        except ValueError as err:
            # compute_pose_error rejects mismatched/empty/malformed pose sets.
            logger.warning(f"Skipping BA comparisons: {err}")
            ba_vs_arkit = da3_vs_ba = None
        else:
            logger.info("\n=== Comparing BA vs ARKit (Ground Truth) ===")
            logger.info("BA vs ARKit:")
            logger.info(f"  Mean rotation error: {ba_vs_arkit['mean_rotation_error_deg']:.2f}°")
            logger.info(f"  Max rotation error: {ba_vs_arkit['max_rotation_error_deg']:.2f}°")
            logger.info(f"  Mean translation error: {ba_vs_arkit['mean_translation_error']:.2f}")

            logger.info("\n=== Comparing DA3 vs BA ===")
            logger.info("DA3 vs BA:")
            logger.info(f"  Mean rotation error: {da3_vs_ba['mean_rotation_error_deg']:.2f}°")
            logger.info(f"  Max rotation error: {da3_vs_ba['max_rotation_error_deg']:.2f}°")
    else:
        logger.warning("BA validation failed, skipping BA comparisons")

    # Frame categorization from DA3 vs ARKit errors
    frame_categorization = _build_frame_categorization(da3_vs_arkit["rotation_errors_deg"])
    _log_frame_categorization(frame_categorization)

    diagnostics = _build_diagnostics(da3_centers, arkit_centers, da3_vs_arkit)

    # Save results
    results = {
        "da3_vs_arkit": da3_vs_arkit,
        "ba_vs_arkit": ba_vs_arkit,
        "da3_vs_ba": da3_vs_ba,
        "ba_result": {
            "status": ba_result["status"],
            "error": ba_result.get("error"),
            "reprojection_error": ba_result.get("reprojection_error"),
        },
        "frame_categorization": frame_categorization,
        "diagnostics": diagnostics,
        "num_frames": len(images),
    }

    results_path = args.output_dir / "validation_results.json"
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=str)

    logger.info(f"\n✓ Results saved to {results_path}")
    logger.info("✓ Poses saved for visualization")
    logger.info("\n=== Complete ===")
    return 0


if __name__ == "__main__":
    sys.exit(main())