|
|
|
|
|
""" |
|
|
Run BA validation on ARKit data. |
|
|
Compares DA3 poses vs ARKit poses (ground truth) and vs COLMAP BA. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from typing import Dict |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
# Make the repository root importable so the `ylff` package resolves when this
# script is executed directly (the script lives three levels below the root).
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from ylff.services.arkit_processor import ARKitProcessor
from ylff.services.ba_validator import BAValidator
from ylff.utils.model_loader import load_da3_model

# Script-level logging: timestamped INFO-level lines for progress reporting.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def compute_pose_error(poses1: np.ndarray, poses2: np.ndarray, verbose: bool = False) -> Dict:
    """Compute per-frame pose errors between two pose sets after Sim(3) alignment.

    Estimates a similarity transform (isotropic scale + rotation + translation,
    Kabsch/Umeyama-style, fitted on the translation columns) that maps
    ``poses1`` onto ``poses2``, applies it to ``poses1``, then reports the
    per-frame geodesic rotation angle and Euclidean translation distance.

    Args:
        poses1: (N, 3, 4) or (N, 4, 4) array of poses to be aligned.
        poses2: (N, 3, 4) or (N, 4, 4) reference poses, same frame order.
        verbose: If True, log alignment diagnostics via the module logger.

    Returns:
        Dict with per-frame error lists, summary statistics, and an
        ``alignment_info`` sub-dict describing the fitted transform.
    """
    # NOTE(review): the [:3, 3] column is used directly as a "position".
    # For w2c matrices this is not the camera center (-R^T t) — confirm
    # callers intend alignment in this space.
    centers1 = poses1[:, :3, 3]
    centers2 = poses2[:, :3, 3]

    center1_mean = centers1.mean(axis=0)
    center2_mean = centers2.mean(axis=0)

    centers1_centered = centers1 - center1_mean
    centers2_centered = centers2 - center2_mean

    # Isotropic scale: ratio of mean centroid distances (epsilon guards /0).
    scale1 = np.linalg.norm(centers1_centered, axis=1).mean()
    scale2 = np.linalg.norm(centers2_centered, axis=1).mean()
    scale = scale2 / (scale1 + 1e-8)

    if verbose:
        logger.info(f" Alignment scale factor: {scale:.6f}")
        logger.info(f" Scale1 (poses1): {scale1:.6f}")
        logger.info(f" Scale2 (poses2): {scale2:.6f}")

    # Kabsch: rotation best mapping centered centers1 onto centered centers2.
    H = centers1_centered.T @ centers2_centered
    U, _, Vt = np.linalg.svd(H)
    R_align = Vt.T @ U.T
    # Fix: reject an improper (reflection) solution. Without this correction
    # det(R_align) can be -1, which mirrors the trajectory and inflates every
    # downstream rotation/translation error.
    if np.linalg.det(R_align) < 0:
        Vt[-1, :] *= -1
        R_align = Vt.T @ U.T

    if verbose:
        det = np.linalg.det(R_align)
        logger.info(f" Alignment rotation det: {det:.6f} (should be ~1.0)")
        logger.info(f" Alignment rotation trace: {np.trace(R_align):.3f}")

    # Apply the similarity transform to every pose in poses1. The R/t slices
    # below are valid for both (3, 4) and (4, 4) matrices, so no shape branch
    # is needed (the original duplicated identical code in both branches).
    poses1_aligned = poses1.copy()
    for i in range(len(poses1)):
        R_orig = poses1[i][:3, :3]
        t_orig = poses1[i][:3, 3]

        R_aligned = R_align @ R_orig
        t_aligned = scale * (R_align @ (t_orig - center1_mean)) + center2_mean

        poses1_aligned[i][:3, :3] = R_aligned
        poses1_aligned[i][:3, 3] = t_aligned

    # Per-frame errors between the aligned poses1 and the reference poses2.
    rotation_errors = []
    translation_errors = []

    for i in range(len(poses1)):
        R1 = poses1_aligned[i][:3, :3]
        R2 = poses2[i][:3, :3]
        t1 = poses1_aligned[i][:3, 3]
        t2 = poses2[i][:3, 3]

        # Geodesic distance: rotation angle of the relative rotation R1 R2^T.
        # clip() guards arccos against tiny numerical overshoot beyond [-1, 1].
        R_diff = R1 @ R2.T
        trace = np.trace(R_diff)
        angle_rad = np.arccos(np.clip((trace - 1) / 2, -1, 1))
        angle_deg = np.degrees(angle_rad)
        rotation_errors.append(angle_deg)

        trans_error = np.linalg.norm(t1 - t2)
        translation_errors.append(trans_error)

    result = {
        "rotation_errors_deg": rotation_errors,
        "translation_errors": translation_errors,
        "mean_rotation_error_deg": np.mean(rotation_errors),
        "max_rotation_error_deg": np.max(rotation_errors),
        "mean_translation_error": np.mean(translation_errors),
        "alignment_info": {
            "scale_factor": float(scale),
            "center1_mean": center1_mean.tolist(),
            "center2_mean": center2_mean.tolist(),
            "rotation_det": float(np.linalg.det(R_align)),
        },
    }

    if verbose:
        logger.info(" Alignment info saved to results")

    return result
|
|
|
|
|
|
|
|
def main():
    """Run the ARKit BA validation pipeline end to end.

    Steps: parse CLI args, locate the ARKit video + JSON metadata, extract
    frames and ARKit poses, run DA3 inference to predict poses, compare DA3
    vs ARKit (ground truth), run COLMAP-style BA validation, compare BA vs
    ARKit and DA3 vs BA, then write poses (.npy) and a JSON results report.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Run BA validation on ARKit data")
    parser.add_argument(
        "--arkit-dir",
        type=Path,
        default=project_root / "assets" / "examples" / "ARKit",
        help="Directory containing ARKit video and metadata",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=project_root / "data" / "arkit_ba_validation",
        help="Output directory for results",
    )
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Maximum number of frames to process"
    )
    parser.add_argument("--frame-interval", type=int, default=1, help="Extract every Nth frame")
    parser.add_argument("--device", type=str, default="cpu", help="Device for DA3 inference")

    args = parser.parse_args()

    # NOTE(review): both arguments already carry non-None defaults above, so
    # these two None-checks are dead code — harmless, kept as-is.
    if args.arkit_dir is None:
        args.arkit_dir = project_root / "assets" / "examples" / "ARKit"
    if args.output_dir is None:
        args.output_dir = project_root / "data" / "arkit_ba_validation"

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Pick the first .MOV and first .json found (glob order; presumably each
    # directory holds exactly one capture — verify for multi-capture dirs).
    video_path = None
    metadata_path = None

    for video_file in (args.arkit_dir / "videos").glob("*.MOV"):
        video_path = video_file
        break

    for json_file in (args.arkit_dir / "json-metadata").glob("*.json"):
        metadata_path = json_file
        break

    if not video_path or not metadata_path:
        logger.error(f"ARKit files not found in {args.arkit_dir}")
        logger.error("Expected: videos/*.MOV and json-metadata/*.json")
        return

    logger.info(f"ARKit video: {video_path}")
    logger.info(f"ARKit metadata: {metadata_path}")

    # --- Stage 1: extract frames + ARKit (ground-truth) poses ---
    logger.info("\n=== Processing ARKit Data ===")
    processor = ARKitProcessor(video_path, metadata_path)

    arkit_data = processor.process_for_ba_validation(
        output_dir=args.output_dir,
        max_frames=args.max_frames,
        frame_interval=args.frame_interval,
        use_good_tracking_only=True,
    )

    image_paths = arkit_data["image_paths"]
    arkit_poses_c2w = arkit_data["arkit_poses_c2w"]

    # Convert ARKit's camera convention to OpenCV convention so the poses are
    # directly comparable with DA3/COLMAP outputs.
    from ylff.coordinate_utils import convert_arkit_to_opencv

    arkit_poses_c2w_opencv = np.array([convert_arkit_to_opencv(p) for p in arkit_poses_c2w])

    logger.info(f"Processed {len(image_paths)} frames")

    # --- Stage 2: DA3 pose prediction on the extracted frames ---
    logger.info("\n=== Running DA3 Inference ===")
    model = load_da3_model("depth-anything/DA3-LARGE", device=args.device)

    import cv2

    # Load frames as RGB (cv2.imread yields BGR); unreadable files are skipped.
    images = []
    for img_path in image_paths:
        img = cv2.imread(str(img_path))
        if img is not None:
            images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    logger.info(f"Running DA3 on {len(images)} images...")
    with torch.no_grad():
        da3_output = model.inference(images)

    da3_poses = da3_output.extrinsics
    # Some model versions may not expose intrinsics — tolerate their absence.
    da3_intrinsics = da3_output.intrinsics if hasattr(da3_output, "intrinsics") else None

    logger.info(f"DA3 poses: {da3_poses.shape}")

    # Invert c2w -> w2c (and drop the homogeneous row) to match the DA3
    # extrinsics convention before comparing.
    arkit_poses_w2c_opencv = np.array([np.linalg.inv(p)[:3, :] for p in arkit_poses_c2w_opencv])

    # --- Stage 3: DA3 vs ARKit comparison ---
    logger.info("\n=== Comparing DA3 vs ARKit (Ground Truth) ===")

    # Raw translation statistics (pre-alignment) for scale sanity checks.
    logger.info("\nPose Statistics (before alignment):")
    da3_centers = da3_poses[:, :3, 3]
    arkit_centers = arkit_poses_w2c_opencv[:, :3, 3]
    logger.info(f" DA3 translation range: [{da3_centers.min(axis=0)}, {da3_centers.max(axis=0)}]")
    da3_norms = np.linalg.norm(da3_centers, axis=1)
    logger.info(
        f" DA3 translation magnitude: mean={da3_norms.mean():.3f}, " f"std={da3_norms.std():.3f}"
    )
    logger.info(
        f" ARKit translation range: [{arkit_centers.min(axis=0)}, {arkit_centers.max(axis=0)}]"
    )
    arkit_norms = np.linalg.norm(arkit_centers, axis=1)
    logger.info(
        f" ARKit translation magnitude: mean={arkit_norms.mean():.3f}, "
        f"std={arkit_norms.std():.3f}"
    )

    da3_vs_arkit = compute_pose_error(da3_poses, arkit_poses_w2c_opencv, verbose=True)

    logger.info("\nDA3 vs ARKit Error Summary:")
    logger.info(f" Mean rotation error: {da3_vs_arkit['mean_rotation_error_deg']:.2f}°")
    logger.info(f" Median rotation error: {np.median(da3_vs_arkit['rotation_errors_deg']):.2f}°")
    logger.info(f" Max rotation error: {da3_vs_arkit['max_rotation_error_deg']:.2f}°")
    logger.info(f" Min rotation error: {np.min(da3_vs_arkit['rotation_errors_deg']):.2f}°")
    logger.info(f" Std rotation error: {np.std(da3_vs_arkit['rotation_errors_deg']):.2f}°")
    logger.info(f" Mean translation error: {da3_vs_arkit['mean_translation_error']:.3f} m")
    logger.info(f" Max translation error: {np.max(da3_vs_arkit['translation_errors']):.3f} m")

    if "alignment_info" in da3_vs_arkit:
        align_info = da3_vs_arkit["alignment_info"]
        logger.info("\nAlignment Diagnostics:")
        logger.info(
            f" Scale factor: {align_info['scale_factor']:.6f} (should be ~1.0 if scales match)"
        )
        logger.info(f" Rotation matrix det: {align_info['rotation_det']:.6f} (should be ~1.0)")
        logger.info(f" Center1 (DA3) mean: {align_info['center1_mean']}")
        logger.info(f" Center2 (ARKit) mean: {align_info['center2_mean']}")

    # Per-frame table with the same accept/reject thresholds (2°/30°) used by
    # the BAValidator below.
    logger.info("\nPer-Frame Error Breakdown:")
    logger.info(f" {'Frame':<8} {'Rot Error (°)':<15} {'Trans Error (m)':<15} {'Category':<20}")
    logger.info(f" {'-' * 8} {'-' * 15} {'-' * 15} {'-' * 20}")
    for i, (rot_err, trans_err) in enumerate(
        zip(da3_vs_arkit["rotation_errors_deg"], da3_vs_arkit["translation_errors"])
    ):
        if rot_err < 2.0:
            category = "Accepted"
        elif rot_err < 30.0:
            category = "Rejected-Learnable"
        else:
            category = "Rejected-Outlier"
        logger.info(f" {i:<8} {rot_err:<15.2f} {trans_err:<15.3f} {category:<20}")

    rot_errors = da3_vs_arkit["rotation_errors_deg"]
    logger.info("\nError Distribution (Rotation):")
    logger.info(f" Q1 (25th percentile): {np.percentile(rot_errors, 25):.2f}°")
    logger.info(f" Q2 (50th percentile/median): {np.percentile(rot_errors, 50):.2f}°")
    logger.info(f" Q3 (75th percentile): {np.percentile(rot_errors, 75):.2f}°")
    logger.info(f" 90th percentile: {np.percentile(rot_errors, 90):.2f}°")
    logger.info(f" 95th percentile: {np.percentile(rot_errors, 95):.2f}°")
    logger.info(f" 99th percentile: {np.percentile(rot_errors, 99):.2f}°")

    # --- Stage 4: COLMAP-style BA refinement of the DA3 poses ---
    logger.info("\n=== Running BA Validation ===")
    validator = BAValidator(
        accept_threshold=2.0,
        reject_threshold=30.0,
        work_dir=args.output_dir / "ba_work",
    )

    ba_result = validator.validate(
        images=images,
        poses_model=da3_poses,
        intrinsics=da3_intrinsics,
    )

    # Results are only written when BA produced refined poses; otherwise the
    # whole report (including DA3-vs-ARKit numbers already logged) is skipped.
    if ba_result["status"] != "ba_failed" and ba_result.get("poses_ba") is not None:
        ba_poses = ba_result["poses_ba"]

        logger.info("\n=== Comparing BA vs ARKit (Ground Truth) ===")
        ba_vs_arkit = compute_pose_error(ba_poses, arkit_poses_w2c_opencv)

        logger.info("BA vs ARKit:")
        logger.info(f" Mean rotation error: {ba_vs_arkit['mean_rotation_error_deg']:.2f}°")
        logger.info(f" Max rotation error: {ba_vs_arkit['max_rotation_error_deg']:.2f}°")
        logger.info(f" Mean translation error: {ba_vs_arkit['mean_translation_error']:.2f}")

        logger.info("\n=== Comparing DA3 vs BA ===")
        da3_vs_ba = compute_pose_error(da3_poses, ba_poses)

        logger.info("DA3 vs BA:")
        logger.info(f" Mean rotation error: {da3_vs_ba['mean_rotation_error_deg']:.2f}°")
        logger.info(f" Max rotation error: {da3_vs_ba['max_rotation_error_deg']:.2f}°")

        # Persist pose arrays for downstream visualization.
        np.save(args.output_dir / "da3_poses_w2c.npy", da3_poses)
        # NOTE(review): this inner check is redundant — it is the same
        # condition as the enclosing if; kept as-is.
        if ba_result["status"] != "ba_failed" and ba_result.get("poses_ba") is not None:
            np.save(args.output_dir / "ba_poses_w2c.npy", ba_poses)

        # Categorize frames by DA3-vs-ARKit rotation error using the same
        # thresholds as the per-frame table above.
        rot_errors = da3_vs_arkit["rotation_errors_deg"]
        accepted_frames = []
        rejected_learnable_frames = []
        rejected_outlier_frames = []

        accept_threshold = 2.0
        reject_threshold = 30.0

        for i, err in enumerate(rot_errors):
            if err < accept_threshold:
                accepted_frames.append(i)
            elif err < reject_threshold:
                rejected_learnable_frames.append(i)
            else:
                rejected_outlier_frames.append(i)

        frame_categorization = {
            "accepted": {
                "count": len(accepted_frames),
                "percentage": (
                    100.0 * len(accepted_frames) / len(rot_errors) if rot_errors else 0.0
                ),
                "frame_indices": accepted_frames,
            },
            "rejected_learnable": {
                "count": len(rejected_learnable_frames),
                "percentage": (
                    100.0 * len(rejected_learnable_frames) / len(rot_errors) if rot_errors else 0.0
                ),
                "frame_indices": rejected_learnable_frames,
            },
            "rejected_outlier": {
                "count": len(rejected_outlier_frames),
                "percentage": (
                    100.0 * len(rejected_outlier_frames) / len(rot_errors) if rot_errors else 0.0
                ),
                "frame_indices": rejected_outlier_frames,
            },
            "total_frames": len(rot_errors),
        }

        logger.info("\n=== Frame Categorization (DA3 vs ARKit) ===")
        accepted_info = frame_categorization["accepted"]
        learnable_info = frame_categorization["rejected_learnable"]
        outlier_info = frame_categorization["rejected_outlier"]
        total_frames = frame_categorization["total_frames"]
        logger.info(
            f" Accepted (< {accept_threshold}°): "
            f"{accepted_info['count']}/{total_frames} "
            f"({accepted_info['percentage']:.1f}%)"
        )
        logger.info(
            f" Rejected-Learnable ({accept_threshold}-{reject_threshold}°): "
            f"{learnable_info['count']}/{total_frames} "
            f"({learnable_info['percentage']:.1f}%)"
        )
        logger.info(
            f" Rejected-Outlier (> {reject_threshold}°): "
            f"{outlier_info['count']}/{total_frames} "
            f"({outlier_info['percentage']:.1f}%)"
        )

        # Diagnostics sub-report: pre-alignment statistics, error percentiles,
        # and a per-frame error list mirroring the console table.
        diagnostics = {
            "pose_statistics": {
                "da3": {
                    "translation_range": {
                        "min": da3_centers.min(axis=0).tolist(),
                        "max": da3_centers.max(axis=0).tolist(),
                        "mean_magnitude": float(np.linalg.norm(da3_centers, axis=1).mean()),
                        "std_magnitude": float(np.linalg.norm(da3_centers, axis=1).std()),
                    }
                },
                "arkit": {
                    "translation_range": {
                        "min": arkit_centers.min(axis=0).tolist(),
                        "max": arkit_centers.max(axis=0).tolist(),
                        "mean_magnitude": float(np.linalg.norm(arkit_centers, axis=1).mean()),
                        "std_magnitude": float(np.linalg.norm(arkit_centers, axis=1).std()),
                    }
                },
            },
            "error_distribution": {
                "rotation_errors_deg": {
                    "q1": float(np.percentile(rot_errors, 25)),
                    "median": float(np.percentile(rot_errors, 50)),
                    "q3": float(np.percentile(rot_errors, 75)),
                    "p90": float(np.percentile(rot_errors, 90)),
                    "p95": float(np.percentile(rot_errors, 95)),
                    "p99": float(np.percentile(rot_errors, 99)),
                },
                "translation_errors": {
                    "mean": float(np.mean(da3_vs_arkit["translation_errors"])),
                    "median": float(np.median(da3_vs_arkit["translation_errors"])),
                    "max": float(np.max(da3_vs_arkit["translation_errors"])),
                    "std": float(np.std(da3_vs_arkit["translation_errors"])),
                },
            },
            "per_frame_errors": [
                {
                    "frame_idx": i,
                    "rotation_error_deg": float(rot_err),
                    "translation_error_m": float(trans_err),
                    "category": (
                        "accepted"
                        if rot_err < 2.0
                        else ("rejected_learnable" if rot_err < 30.0 else "rejected_outlier")
                    ),
                }
                for i, (rot_err, trans_err) in enumerate(
                    zip(da3_vs_arkit["rotation_errors_deg"], da3_vs_arkit["translation_errors"])
                )
            ],
            "da3_vs_arkit": {"alignment_info": da3_vs_arkit.get("alignment_info", {})},
        }

        # Full report; default=str lets numpy scalars/arrays serialize.
        results = {
            "da3_vs_arkit": da3_vs_arkit,
            "ba_vs_arkit": ba_vs_arkit,
            "da3_vs_ba": da3_vs_ba,
            "ba_result": {
                "status": ba_result["status"],
                "error": ba_result.get("error"),
                "reprojection_error": ba_result.get("reprojection_error"),
            },
            "frame_categorization": frame_categorization,
            "diagnostics": diagnostics,
            "num_frames": len(images),
        }

        results_path = args.output_dir / "validation_results.json"
        with open(results_path, "w") as f:
            json.dump(results, f, indent=2, default=str)

        logger.info(f"\n✓ Results saved to {results_path}")
        logger.info("✓ Poses saved for visualization")
    else:
        logger.warning("BA validation failed, skipping BA comparisons")

    logger.info("\n=== Complete ===")
|
|
|
|
|
|
|
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|
|