|
|
|
|
|
""" |
|
|
Run BA validation with real-time GUI visualization. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import sys |
|
|
import threading |
|
|
import time |
|
|
from pathlib import Path |
|
|
from typing import Dict |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
# Make the package importable when this script is run directly: the repo root
# is three levels above this file and must be on sys.path before the ylff
# imports below are executed.
project_root = Path(__file__).parent.parent.parent

sys.path.insert(0, str(project_root))
|
|
|
|
|
from ylff.services.arkit_processor import ARKitProcessor |
|
|
from ylff.services.ba_validator import BAValidator |
|
|
from ylff.utils.model_loader import load_da3_model |
|
|
from ylff.utils.visualization_gui import create_gui |
|
|
|
|
|
# Module-wide logging: timestamped INFO-level messages.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def compute_pose_error(poses1: np.ndarray, poses2: np.ndarray) -> Dict:
    """Compute per-frame pose errors between two trajectories after alignment.

    ``poses1`` is aligned to ``poses2`` with a similarity transform
    (scale + rotation + translation) estimated Kabsch/Umeyama-style from the
    camera centers, then per-frame rotation and translation errors are
    reported.

    Args:
        poses1: (N, 4, 4) or (N, 3, 4) array of camera poses to align.
        poses2: (N, 4, 4) or (N, 3, 4) array of reference poses (same N).

    Returns:
        Dict with per-frame ``rotation_errors_deg`` (geodesic angle, degrees)
        and ``translation_errors`` lists, plus mean/max summary statistics.
    """
    # The camera center is the translation column; slicing [:3, 3] works for
    # both (3, 4) and (4, 4) poses, so no shape branching is needed.
    centers1 = poses1[:, :3, 3]
    centers2 = poses2[:, :3, 3]

    center1_mean = centers1.mean(axis=0)
    center2_mean = centers2.mean(axis=0)

    centers1_centered = centers1 - center1_mean
    centers2_centered = centers2 - center2_mean

    # Relative trajectory scale; epsilon guards a degenerate trajectory whose
    # centers all coincide.
    scale1 = np.linalg.norm(centers1_centered, axis=1).mean()
    scale2 = np.linalg.norm(centers2_centered, axis=1).mean()
    scale = scale2 / (scale1 + 1e-8)

    # Kabsch: best-fit rotation between the centered center clouds.
    H = centers1_centered.T @ centers2_centered
    U, _, Vt = np.linalg.svd(H)
    R_align = Vt.T @ U.T
    # If the SVD solution is a reflection (det == -1) it is not a valid
    # rotation; flip the sign of the last singular direction (standard
    # Kabsch/Umeyama correction).
    if np.linalg.det(R_align) < 0:
        Vt_fixed = Vt.copy()
        Vt_fixed[-1, :] *= -1
        R_align = Vt_fixed.T @ U.T

    # Apply the similarity transform to every pose in poses1.
    poses1_aligned = poses1.copy()
    for i in range(len(poses1)):
        R_orig = poses1[i][:3, :3]
        t_orig = poses1[i][:3, 3]
        poses1_aligned[i][:3, :3] = R_align @ R_orig
        poses1_aligned[i][:3, 3] = (
            scale * (R_align @ (t_orig - center1_mean)) + center2_mean
        )

    # Per-frame errors between the aligned poses1 and poses2.
    rotation_errors = []
    translation_errors = []
    for i in range(len(poses1)):
        R1 = poses1_aligned[i][:3, :3]
        R2 = poses2[i][:3, :3]
        t1 = poses1_aligned[i][:3, 3]
        t2 = poses2[i][:3, 3]

        # Geodesic distance: rotation angle of R1 @ R2^T, clipped for
        # numerical safety before arccos.
        R_diff = R1 @ R2.T
        trace = np.trace(R_diff)
        angle_rad = np.arccos(np.clip((trace - 1) / 2, -1, 1))
        rotation_errors.append(np.degrees(angle_rad))

        translation_errors.append(np.linalg.norm(t1 - t2))

    return {
        "rotation_errors_deg": rotation_errors,
        "translation_errors": translation_errors,
        "mean_rotation_error_deg": np.mean(rotation_errors),
        "max_rotation_error_deg": np.max(rotation_errors),
        "mean_translation_error": np.mean(translation_errors),
    }
|
|
|
|
|
|
|
|
def run_validation_with_gui(
    gui,
    arkit_dir: Path,
    output_dir: Path,
    max_frames: int = None,
    frame_interval: int = 1,
    device: str = "cpu",
):
    """Run the validation pipeline on a background thread, updating *gui*.

    Pipeline: locate the ARKit video/metadata under ``arkit_dir``, extract
    frames and poses, run DA3 pose inference, validate the predicted poses
    with bundle adjustment, and push per-frame poses/errors to the GUI.

    Args:
        gui: GUI object exposing ``add_status_message``, ``add_progress_update``,
            ``add_frame_data`` and ``update_status`` (as used below).
        arkit_dir: Directory containing ``videos/*.MOV`` and
            ``json-metadata/*.json``.
        output_dir: Where extracted frames and the BA work directory are written.
        max_frames: Optional cap on frames to process (``None`` = no cap).
        frame_interval: Extract every Nth frame.
        device: Device string passed to the DA3 model loader (e.g. ``"cpu"``).

    Returns:
        None. The worker runs as a daemon thread; this function returns
        immediately after starting it.
    """

    def validation_thread() -> None:
        # Runs off the main thread: all user feedback goes through gui.*
        # methods, and any failure is surfaced as a status message instead of
        # crashing the process (see the except block at the bottom).
        try:
            video_path = None
            metadata_path = None

            # Take the first .MOV and the first .json found; presumably each
            # capture directory holds exactly one of each — TODO confirm.
            for video_file in (arkit_dir / "videos").glob("*.MOV"):
                video_path = video_file
                break

            for json_file in (arkit_dir / "json-metadata").glob("*.json"):
                metadata_path = json_file
                break

            if not video_path or not metadata_path:
                gui.add_status_message("ERROR: ARKit files not found")
                return

            gui.add_status_message(f"Processing ARKit data: {video_path.name}")

            # Extract frames + ARKit poses to output_dir.
            processor = ARKitProcessor(video_path, metadata_path)
            arkit_data = processor.process_for_ba_validation(
                output_dir=output_dir,
                max_frames=max_frames,
                frame_interval=frame_interval,
                use_good_tracking_only=False,
            )

            image_paths = arkit_data["image_paths"]
            arkit_poses_c2w = arkit_data["arkit_poses_c2w"]
            arkit_poses_w2c = arkit_data[
                "arkit_poses_w2c"
            ]

            # Local import: ylff must already be on sys.path (module top level).
            from ylff.coordinate_utils import convert_arkit_to_opencv

            # Convert camera-to-world poses from ARKit to OpenCV convention
            # for display.
            arkit_poses_c2w_opencv = np.array(
                [convert_arkit_to_opencv(p) for p in arkit_poses_c2w]
            )

            total_frames = len(image_paths)
            gui.add_progress_update(0, total_frames)
            gui.add_status_message(f"Extracted {total_frames} frames. Running DA3 inference...")

            # Load frames as RGB; unreadable files are silently skipped.
            # NOTE(review): if any imread fails, `images` becomes shorter than
            # `image_paths`/`arkit_poses_*`, which would misalign the later
            # per-frame error indexing — confirm this cannot happen upstream.
            images = []
            for img_path in image_paths:
                img = cv2.imread(str(img_path))
                if img is not None:
                    images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

            model = load_da3_model("depth-anything/DA3-LARGE", device=device)
            gui.add_status_message("Running DA3 inference...")

            da3_intrinsics = None

            # Inference only — no gradients needed.
            with torch.no_grad():
                da3_output = model.inference(images)
                da3_poses_all = da3_output.extrinsics
                da3_intrinsics = (
                    da3_output.intrinsics if hasattr(da3_output, "intrinsics") else None
                )

            # Stream ARKit + DA3 poses to the GUI one frame at a time; the
            # short sleep paces the updates for the display.
            for i, (arkit_pose_c2w_opencv, da3_pose) in enumerate(
                zip(arkit_poses_c2w_opencv, da3_poses_all)
            ):
                gui.add_frame_data(
                    frame_idx=i,
                    arkit_pose=arkit_pose_c2w_opencv,
                    da3_pose=da3_pose,
                )
                gui.add_progress_update(i + 1, total_frames)
                time.sleep(0.1)

            gui.add_status_message("DA3 inference complete. Running BA validation...")

            # Bundle-adjustment validation of the DA3 poses. Thresholds are
            # accept/reject bounds used by the validator — see BAValidator.
            validator = BAValidator(
                accept_threshold=2.0,
                reject_threshold=30.0,
                work_dir=output_dir / "ba_work",
            )

            ba_result = validator.validate(
                images=images,
                poses_model=da3_poses_all,
                intrinsics=da3_intrinsics,
            )

            if ba_result["status"] != "ba_failed" and ba_result.get("poses_ba") is not None:
                ba_poses = ba_result["poses_ba"]

                # Index -> BA pose map so .get(i) returns None past the end.
                ba_pose_dict = {i: ba_poses[i] for i in range(len(ba_poses))}

                # Pairwise error comparisons; ARKit w2c poses are used as the
                # reference frame convention here.
                da3_vs_arkit = compute_pose_error(da3_poses_all, arkit_poses_w2c)
                ba_vs_arkit = compute_pose_error(ba_poses, arkit_poses_w2c)
                da3_vs_ba = compute_pose_error(da3_poses_all, ba_poses)

                # Stream per-frame errors (and the BA pose, when available).
                for i in range(len(images)):
                    errors = {}
                    if i < len(da3_vs_arkit["rotation_errors_deg"]):
                        errors["da3_vs_arkit_rot"] = da3_vs_arkit["rotation_errors_deg"][i]
                        errors["da3_vs_arkit_trans"] = da3_vs_arkit["translation_errors"][i]
                    if i < len(ba_vs_arkit["rotation_errors_deg"]):
                        errors["ba_vs_arkit_rot"] = ba_vs_arkit["rotation_errors_deg"][i]
                        errors["ba_vs_arkit_trans"] = ba_vs_arkit["translation_errors"][i]
                    if i < len(da3_vs_ba["rotation_errors_deg"]):
                        errors["da3_vs_ba_rot"] = da3_vs_ba["rotation_errors_deg"][i]
                        errors["da3_vs_ba_trans"] = da3_vs_ba["translation_errors"][i]

                    ba_pose = ba_pose_dict.get(i)

                    gui.add_frame_data(
                        frame_idx=i,
                        ba_pose=ba_pose,
                        errors=errors,
                    )
                    time.sleep(0.05)

                gui.add_status_message("BA validation complete!")
            else:
                gui.add_status_message("BA validation failed")

                # BA failed: still report DA3-vs-ARKit errors so the GUI shows
                # something useful.
                da3_vs_arkit = compute_pose_error(da3_poses_all, arkit_poses_w2c)
                for i in range(len(images)):
                    errors = {}
                    if i < len(da3_vs_arkit["rotation_errors_deg"]):
                        errors["da3_vs_arkit_rot"] = da3_vs_arkit["rotation_errors_deg"][i]
                        errors["da3_vs_arkit_trans"] = da3_vs_arkit["translation_errors"][i]
                    gui.add_frame_data(frame_idx=i, errors=errors)
                    time.sleep(0.05)

            gui.update_status("Complete", is_processing=False)

        except Exception as e:
            # Thread boundary: log with traceback and report to the GUI
            # instead of letting the worker thread die silently.
            logger.error(f"Validation error: {e}", exc_info=True)
            gui.add_status_message(f"ERROR: {str(e)}")
            gui.update_status("Error occurred", is_processing=False)

    # Daemon thread so the process can exit when the GUI main loop ends.
    thread = threading.Thread(target=validation_thread, daemon=True)
    thread.start()
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, start the validation worker, run the GUI.

    The validation pipeline runs on a background daemon thread (started by
    ``run_validation_with_gui``) while ``gui.run()`` blocks the main thread.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Run BA validation with real-time GUI")
    parser.add_argument(
        "--arkit-dir",
        type=Path,
        default=None,
        help="Directory containing ARKit video and metadata",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=project_root / "data" / "arkit_ba_validation_gui",
        help="Output directory for results",
    )
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Maximum number of frames to process"
    )
    parser.add_argument("--frame-interval", type=int, default=1, help="Extract every Nth frame")
    parser.add_argument("--device", type=str, default="cpu", help="Device for DA3 inference")

    args = parser.parse_args()

    # Fall back to the bundled example capture when no directory is given.
    # (--output-dir always has a non-None argparse default, so no equivalent
    # fallback is needed for it.)
    if args.arkit_dir is None:
        args.arkit_dir = project_root / "assets" / "examples" / "ARKit"

    args.output_dir.mkdir(parents=True, exist_ok=True)

    gui = create_gui()

    # Starts the worker thread and returns immediately.
    run_validation_with_gui(
        gui,
        args.arkit_dir,
        args.output_dir,
        max_frames=args.max_frames,
        frame_interval=args.frame_interval,
        device=args.device,
    )

    # Blocks until the GUI is closed.
    gui.run()
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|