#!/usr/bin/env python3
"""Run BA validation with real-time GUI visualization."""

import logging
import sys
import threading
import time
from pathlib import Path
from typing import Dict, Optional

import cv2
import numpy as np
import torch

# Add project root to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from ylff.services.arkit_processor import ARKitProcessor  # noqa: E402
from ylff.services.ba_validator import BAValidator  # noqa: E402
from ylff.utils.model_loader import load_da3_model  # noqa: E402
from ylff.utils.visualization_gui import create_gui  # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def compute_pose_error(poses1: np.ndarray, poses2: np.ndarray) -> Dict:
    """Compute per-frame pose errors between two trajectories.

    ``poses1`` is first aligned to ``poses2`` with a similarity transform
    (centroid shift + uniform scale + Kabsch rotation estimated from the
    camera centers), then per-frame rotation/translation errors are measured.

    Args:
        poses1: (N, 4, 4) or (N, 3, 4) pose array to be aligned.
        poses2: (N, 4, 4) or (N, 3, 4) reference pose array, same N.

    Returns:
        Dict with per-frame ``rotation_errors_deg`` and ``translation_errors``
        lists plus mean/max summary statistics.
    """
    # Camera centers live at [:, :3, 3] for both (N,4,4) and (N,3,4) inputs,
    # so no shape branching is needed.
    centers1 = poses1[:, :3, 3]
    centers2 = poses2[:, :3, 3]

    # Center both trajectories on their centroids.
    center1_mean = centers1.mean(axis=0)
    center2_mean = centers2.mean(axis=0)
    centers1_centered = centers1 - center1_mean
    centers2_centered = centers2 - center2_mean

    # Uniform scale: ratio of mean distances from the centroid.
    scale1 = np.linalg.norm(centers1_centered, axis=1).mean()
    scale2 = np.linalg.norm(centers2_centered, axis=1).mean()
    scale = scale2 / (scale1 + 1e-8)

    # Optimal rotation via Kabsch (SVD of the cross-covariance matrix).
    H = centers1_centered.T @ centers2_centered
    U, _, Vt = np.linalg.svd(H)
    R_align = Vt.T @ U.T
    # Guard against a reflection (det == -1): flip the axis associated with
    # the smallest singular value so R_align is a proper rotation. Without
    # this, degenerate/noisy trajectories can yield a mirrored "alignment".
    if np.linalg.det(R_align) < 0:
        Vt_fixed = Vt.copy()
        Vt_fixed[-1, :] *= -1
        R_align = Vt_fixed.T @ U.T

    # Apply the similarity transform to every pose in poses1.
    poses1_aligned = poses1.copy()
    for i in range(len(poses1)):
        R_orig = poses1[i][:3, :3]
        t_orig = poses1[i][:3, 3]
        poses1_aligned[i][:3, :3] = R_align @ R_orig
        poses1_aligned[i][:3, 3] = scale * (R_align @ (t_orig - center1_mean)) + center2_mean

    # Per-frame errors after alignment.
    rotation_errors = []
    translation_errors = []
    for i in range(len(poses1)):
        R1 = poses1_aligned[i][:3, :3]
        R2 = poses2[i][:3, :3]
        t1 = poses1_aligned[i][:3, 3]
        t2 = poses2[i][:3, 3]

        # Geodesic rotation distance: angle of the relative rotation R1·R2ᵀ.
        trace = np.trace(R1 @ R2.T)
        angle_rad = np.arccos(np.clip((trace - 1) / 2, -1, 1))
        rotation_errors.append(np.degrees(angle_rad))

        # Euclidean distance between aligned camera centers.
        translation_errors.append(np.linalg.norm(t1 - t2))

    return {
        "rotation_errors_deg": rotation_errors,
        "translation_errors": translation_errors,
        "mean_rotation_error_deg": np.mean(rotation_errors),
        "max_rotation_error_deg": np.max(rotation_errors),
        "mean_translation_error": np.mean(translation_errors),
    }


def run_validation_with_gui(
    gui,
    arkit_dir: Path,
    output_dir: Path,
    max_frames: Optional[int] = None,
    frame_interval: int = 1,
    device: str = "cpu",
):
    """Run the ARKit → DA3 → BA validation pipeline in a background thread.

    The GUI is updated progressively via ``gui.add_*`` callbacks while the
    worker thread extracts frames, runs DA3 inference, and (if possible)
    bundle-adjustment validation. Returns immediately after starting the
    daemon thread; call ``gui.run()`` afterwards to enter the main loop.

    Args:
        gui: Visualization GUI created by ``create_gui()``.
        arkit_dir: Directory containing ``videos/*.MOV`` and
            ``json-metadata/*.json`` ARKit captures.
        output_dir: Directory for extracted frames and BA working files.
        max_frames: Optional cap on the number of frames processed.
        frame_interval: Extract every Nth frame.
        device: Torch device string for DA3 inference.
    """

    def validation_thread():
        try:
            # Locate the first ARKit video and metadata file.
            video_path = None
            metadata_path = None
            for video_file in (arkit_dir / "videos").glob("*.MOV"):
                video_path = video_file
                break
            for json_file in (arkit_dir / "json-metadata").glob("*.json"):
                metadata_path = json_file
                break

            if not video_path or not metadata_path:
                gui.add_status_message("ERROR: ARKit files not found")
                return

            gui.add_status_message(f"Processing ARKit data: {video_path.name}")

            # Extract frames and poses from the ARKit capture.
            processor = ARKitProcessor(video_path, metadata_path)
            arkit_data = processor.process_for_ba_validation(
                output_dir=output_dir,
                max_frames=max_frames,
                frame_interval=frame_interval,
                use_good_tracking_only=False,
            )

            image_paths = arkit_data["image_paths"]
            arkit_poses_c2w = arkit_data["arkit_poses_c2w"]
            # w2c poses are already converted to OpenCV convention upstream.
            arkit_poses_w2c = arkit_data["arkit_poses_w2c"]

            # Convert ARKit c2w poses to OpenCV convention for visualization.
            from ylff.coordinate_utils import convert_arkit_to_opencv

            arkit_poses_c2w_opencv = np.array(
                [convert_arkit_to_opencv(p) for p in arkit_poses_c2w]
            )

            total_frames = len(image_paths)
            gui.add_progress_update(0, total_frames)
            gui.add_status_message(f"Extracted {total_frames} frames. Running DA3 inference...")

            # Load images (skip any that fail to read), converting BGR -> RGB.
            images = []
            for img_path in image_paths:
                img = cv2.imread(str(img_path))
                if img is not None:
                    images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

            # Run DA3 inference.
            model = load_da3_model("depth-anything/DA3-LARGE", device=device)
            gui.add_status_message("Running DA3 inference...")

            with torch.no_grad():
                da3_output = model.inference(images)
                da3_poses_all = da3_output.extrinsics
                # Older model versions may not expose intrinsics.
                da3_intrinsics = (
                    da3_output.intrinsics if hasattr(da3_output, "intrinsics") else None
                )

            # Push DA3 results to the GUI frame-by-frame, using the
            # OpenCV-converted ARKit poses for visualization.
            for i, (arkit_pose_c2w_opencv, da3_pose) in enumerate(
                zip(arkit_poses_c2w_opencv, da3_poses_all)
            ):
                gui.add_frame_data(
                    frame_idx=i,
                    arkit_pose=arkit_pose_c2w_opencv,  # Already in OpenCV convention
                    da3_pose=da3_pose,
                )
                gui.add_progress_update(i + 1, total_frames)
                time.sleep(0.1)  # Small delay for visualization

            gui.add_status_message("DA3 inference complete. Running BA validation...")

            # Run BA validation.
            validator = BAValidator(
                accept_threshold=2.0,
                reject_threshold=30.0,
                work_dir=output_dir / "ba_work",
            )
            ba_result = validator.validate(
                images=images,
                poses_model=da3_poses_all,
                intrinsics=da3_intrinsics,
            )

            if ba_result["status"] != "ba_failed" and ba_result.get("poses_ba") is not None:
                ba_poses = ba_result["poses_ba"]
                # Map frame indices to BA poses for sparse lookup below.
                ba_pose_dict = {i: ba_poses[i] for i in range(len(ba_poses))}

                # Compute pairwise trajectory errors.
                da3_vs_arkit = compute_pose_error(da3_poses_all, arkit_poses_w2c)
                ba_vs_arkit = compute_pose_error(ba_poses, arkit_poses_w2c)
                da3_vs_ba = compute_pose_error(da3_poses_all, ba_poses)

                # Update GUI with BA results and errors.
                # Note: BA may not have poses for all frames - use indices
                # directly; BA poses are already aligned to input order.
                for i in range(len(images)):
                    errors = {}
                    if i < len(da3_vs_arkit["rotation_errors_deg"]):
                        errors["da3_vs_arkit_rot"] = da3_vs_arkit["rotation_errors_deg"][i]
                        errors["da3_vs_arkit_trans"] = da3_vs_arkit["translation_errors"][i]
                    if i < len(ba_vs_arkit["rotation_errors_deg"]):
                        errors["ba_vs_arkit_rot"] = ba_vs_arkit["rotation_errors_deg"][i]
                        errors["ba_vs_arkit_trans"] = ba_vs_arkit["translation_errors"][i]
                    if i < len(da3_vs_ba["rotation_errors_deg"]):
                        errors["da3_vs_ba_rot"] = da3_vs_ba["rotation_errors_deg"][i]
                        errors["da3_vs_ba_trans"] = da3_vs_ba["translation_errors"][i]

                    gui.add_frame_data(
                        frame_idx=i,
                        ba_pose=ba_pose_dict.get(i),
                        errors=errors,
                    )
                    time.sleep(0.05)

                gui.add_status_message("BA validation complete!")
            else:
                gui.add_status_message("BA validation failed")
                # Still update with DA3 vs ARKit errors.
                da3_vs_arkit = compute_pose_error(da3_poses_all, arkit_poses_w2c)
                for i in range(len(images)):
                    errors = {}
                    if i < len(da3_vs_arkit["rotation_errors_deg"]):
                        errors["da3_vs_arkit_rot"] = da3_vs_arkit["rotation_errors_deg"][i]
                        errors["da3_vs_arkit_trans"] = da3_vs_arkit["translation_errors"][i]
                    gui.add_frame_data(frame_idx=i, errors=errors)
                    time.sleep(0.05)

            gui.update_status("Complete", is_processing=False)

        except Exception as e:
            # Surface any worker failure in both the log and the GUI rather
            # than letting the daemon thread die silently.
            logger.error(f"Validation error: {e}", exc_info=True)
            gui.add_status_message(f"ERROR: {str(e)}")
            gui.update_status("Error occurred", is_processing=False)

    # Start validation in background (daemon so GUI exit ends the worker).
    thread = threading.Thread(target=validation_thread, daemon=True)
    thread.start()


def main():
    """Parse CLI arguments, start the validation worker, and run the GUI."""
    import argparse

    parser = argparse.ArgumentParser(description="Run BA validation with real-time GUI")
    parser.add_argument(
        "--arkit-dir",
        type=Path,
        default=None,
        help="Directory containing ARKit video and metadata",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=project_root / "data" / "arkit_ba_validation_gui",
        help="Output directory for results",
    )
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Maximum number of frames to process"
    )
    parser.add_argument("--frame-interval", type=int, default=1, help="Extract every Nth frame")
    parser.add_argument("--device", type=str, default="cpu", help="Device for DA3 inference")

    args = parser.parse_args()

    # Default ARKit directory if not provided. (--output-dir already has an
    # argparse default, so no None-check is needed for it.)
    if args.arkit_dir is None:
        args.arkit_dir = project_root / "assets" / "examples" / "ARKit"

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Create GUI
    gui = create_gui()

    # Start validation in background
    run_validation_with_gui(
        gui,
        args.arkit_dir,
        args.output_dir,
        max_frames=args.max_frames,
        frame_interval=args.frame_interval,
        device=args.device,
    )

    # Run GUI main loop
    gui.run()


if __name__ == "__main__":
    main()