# 3d_model/scripts/experiments/run_arkit_ba_validation_gui.py
# Author: Azan
# Clean deployment build (Squashed) — commit 7a87926
#!/usr/bin/env python3
"""
Run BA validation with real-time GUI visualization.
"""
import logging
import sys
import threading
import time
from pathlib import Path
from typing import Dict
import cv2
import numpy as np
import torch
# Add project root to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from ylff.services.arkit_processor import ARKitProcessor # noqa: E402
from ylff.services.ba_validator import BAValidator # noqa: E402
from ylff.utils.model_loader import load_da3_model # noqa: E402
from ylff.utils.visualization_gui import create_gui # noqa: E402
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def compute_pose_error(poses1: np.ndarray, poses2: np.ndarray) -> Dict:
"""Compute pose error between two sets of poses."""
# Align trajectories
centers1 = poses1[:, :3, 3] if poses1.shape[1] == 4 else poses1[:, :3, 3]
centers2 = poses2[:, :3, 3] if poses2.shape[1] == 4 else poses2[:, :3, 3]
# Center both
center1_mean = centers1.mean(axis=0)
center2_mean = centers2.mean(axis=0)
centers1_centered = centers1 - center1_mean
centers2_centered = centers2 - center2_mean
# Compute scale
scale1 = np.linalg.norm(centers1_centered, axis=1).mean()
scale2 = np.linalg.norm(centers2_centered, axis=1).mean()
scale = scale2 / (scale1 + 1e-8)
# Compute rotation (SVD)
H = centers1_centered.T @ centers2_centered
U, _, Vt = np.linalg.svd(H)
R_align = Vt.T @ U.T
# Align poses
poses1_aligned = poses1.copy()
for i in range(len(poses1)):
if poses1.shape[1] == 4:
R_orig = poses1[i][:3, :3]
t_orig = poses1[i][:3, 3]
else:
R_orig = poses1[i][:3, :3]
t_orig = poses1[i][:3, 3]
R_aligned = R_align @ R_orig
t_aligned = scale * (R_align @ (t_orig - center1_mean)) + center2_mean
if poses1_aligned.shape[1] == 4:
poses1_aligned[i][:3, :3] = R_aligned
poses1_aligned[i][:3, 3] = t_aligned
else:
poses1_aligned[i][:3, :3] = R_aligned
poses1_aligned[i][:3, 3] = t_aligned
# Compute rotation errors
rotation_errors = []
translation_errors = []
for i in range(len(poses1)):
if poses1_aligned.shape[1] == 4:
R1 = poses1_aligned[i][:3, :3]
R2 = poses2[i][:3, :3] if poses2.shape[1] == 4 else poses2[i][:3, :3]
t1 = poses1_aligned[i][:3, 3]
t2 = poses2[i][:3, 3] if poses2.shape[1] == 4 else poses2[i][:3, 3]
else:
R1 = poses1_aligned[i][:3, :3]
R2 = poses2[i][:3, :3]
t1 = poses1_aligned[i][:3, 3]
t2 = poses2[i][:3, 3]
# Rotation error
R_diff = R1 @ R2.T
trace = np.trace(R_diff)
angle_rad = np.arccos(np.clip((trace - 1) / 2, -1, 1))
angle_deg = np.degrees(angle_rad)
rotation_errors.append(angle_deg)
# Translation error
trans_error = np.linalg.norm(t1 - t2)
translation_errors.append(trans_error)
return {
"rotation_errors_deg": rotation_errors,
"translation_errors": translation_errors,
"mean_rotation_error_deg": np.mean(rotation_errors),
"max_rotation_error_deg": np.max(rotation_errors),
"mean_translation_error": np.mean(translation_errors),
}
def run_validation_with_gui(
    gui,
    arkit_dir: Path,
    output_dir: Path,
    max_frames: int = None,
    frame_interval: int = 1,
    device: str = "cpu",
):
    """Run validation and update GUI progressively.

    Spawns a daemon worker thread that processes the ARKit capture, runs DA3
    inference and BA validation, and streams per-frame results to the GUI as
    they become available. Returns immediately after starting the thread;
    the caller is expected to run the GUI main loop.

    Args:
        gui: Visualization GUI object; this function calls its
            add_status_message, add_progress_update, add_frame_data and
            update_status methods.
        arkit_dir: Directory containing "videos/*.MOV" and
            "json-metadata/*.json" ARKit capture files.
        output_dir: Directory for extracted frames and the BA work dir.
        max_frames: Maximum number of frames to process (None = all).
        frame_interval: Extract every Nth frame from the video.
        device: Torch device string for DA3 inference (e.g. "cpu", "cuda").
    """

    def validation_thread():
        # Worker body: all failures are caught at the bottom so the GUI stays
        # responsive even if processing raises.
        try:
            # Find ARKit files: take the first .MOV and first .json found.
            video_path = None
            metadata_path = None
            for video_file in (arkit_dir / "videos").glob("*.MOV"):
                video_path = video_file
                break
            for json_file in (arkit_dir / "json-metadata").glob("*.json"):
                metadata_path = json_file
                break
            if not video_path or not metadata_path:
                gui.add_status_message("ERROR: ARKit files not found")
                return
            gui.add_status_message(f"Processing ARKit data: {video_path.name}")
            # Process ARKit data: extract frames and per-frame poses.
            processor = ARKitProcessor(video_path, metadata_path)
            arkit_data = processor.process_for_ba_validation(
                output_dir=output_dir,
                max_frames=max_frames,
                frame_interval=frame_interval,
                use_good_tracking_only=False,
            )
            image_paths = arkit_data["image_paths"]
            arkit_poses_c2w = arkit_data["arkit_poses_c2w"]
            arkit_poses_w2c = arkit_data[
                "arkit_poses_w2c"
            ]  # Already converted to OpenCV convention
            # Convert ARKit c2w poses to OpenCV convention for visualization
            from ylff.coordinate_utils import convert_arkit_to_opencv

            arkit_poses_c2w_opencv = np.array(
                [convert_arkit_to_opencv(p) for p in arkit_poses_c2w]
            )
            total_frames = len(image_paths)
            gui.add_progress_update(0, total_frames)
            gui.add_status_message(f"Extracted {total_frames} frames. Running DA3 inference...")
            # Load images as RGB (cv2.imread returns BGR).
            # NOTE(review): frames that fail to load are silently skipped, so
            # `images` may end up shorter than the pose arrays and the
            # index-based pairing below would drift — confirm upstream frame
            # extraction guarantees readable files.
            images = []
            for img_path in image_paths:
                img = cv2.imread(str(img_path))
                if img is not None:
                    images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            # Run DA3 inference (progressive updates)
            model = load_da3_model("depth-anything/DA3-LARGE", device=device)
            gui.add_status_message("Running DA3 inference...")
            da3_intrinsics = None
            with torch.no_grad():
                da3_output = model.inference(images)
                da3_poses_all = da3_output.extrinsics
                # Intrinsics are optional on the DA3 output object.
                da3_intrinsics = (
                    da3_output.intrinsics if hasattr(da3_output, "intrinsics") else None
                )
            # Update GUI with DA3 results
            # Use OpenCV-converted ARKit poses for visualization
            for i, (arkit_pose_c2w_opencv, da3_pose) in enumerate(
                zip(arkit_poses_c2w_opencv, da3_poses_all)
            ):
                gui.add_frame_data(
                    frame_idx=i,
                    arkit_pose=arkit_pose_c2w_opencv,  # Already in OpenCV convention
                    da3_pose=da3_pose,
                )
                gui.add_progress_update(i + 1, total_frames)
                time.sleep(0.1)  # Small delay for visualization
            gui.add_status_message("DA3 inference complete. Running BA validation...")
            # Run BA validation on the DA3 poses.
            validator = BAValidator(
                accept_threshold=2.0,
                reject_threshold=30.0,
                work_dir=output_dir / "ba_work",
            )
            ba_result = validator.validate(
                images=images,
                poses_model=da3_poses_all,
                intrinsics=da3_intrinsics,
            )
            if ba_result["status"] != "ba_failed" and ba_result.get("poses_ba") is not None:
                ba_poses = ba_result["poses_ba"]
                # Create a dictionary mapping frame indices to BA poses
                ba_pose_dict = {i: ba_poses[i] for i in range(len(ba_poses))}
                # Compute pairwise trajectory errors and update GUI
                da3_vs_arkit = compute_pose_error(da3_poses_all, arkit_poses_w2c)
                ba_vs_arkit = compute_pose_error(ba_poses, arkit_poses_w2c)
                da3_vs_ba = compute_pose_error(da3_poses_all, ba_poses)
                # Update GUI with BA results and errors
                # Note: BA may not have poses for all frames - use indices directly
                # BA poses are already aligned to input order in ba_result
                for i in range(len(images)):
                    errors = {}
                    # Each error dict entry is guarded so shorter error lists
                    # (fewer BA poses) don't raise IndexError.
                    if i < len(da3_vs_arkit["rotation_errors_deg"]):
                        errors["da3_vs_arkit_rot"] = da3_vs_arkit["rotation_errors_deg"][i]
                        errors["da3_vs_arkit_trans"] = da3_vs_arkit["translation_errors"][i]
                    if i < len(ba_vs_arkit["rotation_errors_deg"]):
                        errors["ba_vs_arkit_rot"] = ba_vs_arkit["rotation_errors_deg"][i]
                        errors["ba_vs_arkit_trans"] = ba_vs_arkit["translation_errors"][i]
                    if i < len(da3_vs_ba["rotation_errors_deg"]):
                        errors["da3_vs_ba_rot"] = da3_vs_ba["rotation_errors_deg"][i]
                        errors["da3_vs_ba_trans"] = da3_vs_ba["translation_errors"][i]
                    ba_pose = ba_pose_dict.get(i)
                    gui.add_frame_data(
                        frame_idx=i,
                        ba_pose=ba_pose,
                        errors=errors,
                    )
                    time.sleep(0.05)
                gui.add_status_message("BA validation complete!")
            else:
                gui.add_status_message("BA validation failed")
                # Still update with DA3 vs ARKit errors
                da3_vs_arkit = compute_pose_error(da3_poses_all, arkit_poses_w2c)
                for i in range(len(images)):
                    errors = {}
                    if i < len(da3_vs_arkit["rotation_errors_deg"]):
                        errors["da3_vs_arkit_rot"] = da3_vs_arkit["rotation_errors_deg"][i]
                        errors["da3_vs_arkit_trans"] = da3_vs_arkit["translation_errors"][i]
                    gui.add_frame_data(frame_idx=i, errors=errors)
                    time.sleep(0.05)
            gui.update_status("Complete", is_processing=False)
        except Exception as e:
            # Top-level boundary for the worker thread: log with traceback and
            # surface the error in the GUI instead of dying silently.
            logger.error(f"Validation error: {e}", exc_info=True)
            gui.add_status_message(f"ERROR: {str(e)}")
            gui.update_status("Error occurred", is_processing=False)

    # Start validation in background thread (daemon: it won't block exit
    # when the GUI main loop ends).
    thread = threading.Thread(target=validation_thread, daemon=True)
    thread.start()
def main():
    """Parse CLI arguments, create the GUI, and launch the validation run."""
    import argparse

    parser = argparse.ArgumentParser(description="Run BA validation with real-time GUI")
    parser.add_argument(
        "--arkit-dir",
        type=Path,
        default=None,
        help="Directory containing ARKit video and metadata",
    )
    # Default is None so the project-root fallback below is actually
    # reachable; previously a pre-built default path made that check dead.
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Output directory for results",
    )
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Maximum number of frames to process"
    )
    parser.add_argument("--frame-interval", type=int, default=1, help="Extract every Nth frame")
    parser.add_argument("--device", type=str, default="cpu", help="Device for DA3 inference")
    args = parser.parse_args()

    # Resolve defaults relative to the project root when not provided.
    if args.arkit_dir is None:
        args.arkit_dir = project_root / "assets" / "examples" / "ARKit"
    if args.output_dir is None:
        args.output_dir = project_root / "data" / "arkit_ba_validation_gui"
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Create the GUI first so the worker thread can stream updates to it.
    gui = create_gui()
    # Start validation in a background daemon thread (returns immediately).
    run_validation_with_gui(
        gui,
        args.arkit_dir,
        args.output_dir,
        max_frames=args.max_frames,
        frame_interval=args.frame_interval,
        device=args.device,
    )
    # Block on the GUI main loop until the window is closed.
    gui.run()
# Entry-point guard: run the CLI only when executed as a script.
if __name__ == "__main__":
    main()