#!/usr/bin/env python3
"""
Visualize BA validation results for diagnostics.

Reads ``validation_results.json`` (plus optional pose arrays) from a results
directory and writes error plots, a 3D trajectory plot and a text summary.
"""

import json
import sys
from pathlib import Path
from typing import Dict, Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np

# Add project root to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

try:
    import plotly.graph_objects as go

    HAS_PLOTLY = True
except ImportError:
    HAS_PLOTLY = False
    print("Plotly not available. Install with: pip install plotly")


def load_results(results_path: Path) -> Dict:
    """Load validation results JSON.

    Args:
        results_path: Path to the JSON file.

    Returns:
        The decoded JSON document.
    """
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(results_path, encoding="utf-8") as f:
        return json.load(f)


def _camera_centers(poses: np.ndarray, is_c2w: bool) -> np.ndarray:
    """Return the (N, 3) camera centers in world coordinates.

    Args:
        poses: (N, 4, 4) or (N, 3, 4) pose matrices.
        is_c2w: True if the poses are camera-to-world (the center is the
            translation column); False if they are world-to-camera (the
            center is ``-R^T @ t``).

    Raises:
        ValueError: If ``poses`` is not a stack of 3x4 or 4x4 matrices.
    """
    poses = np.asarray(poses)
    if poses.ndim != 3 or poses.shape[1] not in (3, 4) or poses.shape[2] != 4:
        raise ValueError(
            f"Expected poses of shape (N, 3, 4) or (N, 4, 4), got {poses.shape}"
        )

    rotations = poses[:, :3, :3]
    translations = poses[:, :3, 3]
    if is_c2w:
        return translations
    # Batched -R^T @ t: out[n, i] = -sum_j R[n, j, i] * t[n, j]
    return -np.einsum("nji,nj->ni", rotations, translations)


def plot_trajectories_3d(
    arkit_poses: np.ndarray,
    da3_poses: np.ndarray,
    ba_poses: Optional[np.ndarray] = None,
    output_path: Optional[Path] = None,
    use_plotly: bool = True,
):
    """
    Plot 3D camera trajectories.

    The pose convention is taken from the argument (ARKit is c2w, DA3 and BA
    are w2c), not guessed from the array shape, so both 3x4 and 4x4 inputs are
    handled correctly for every trajectory.

    Args:
        arkit_poses: (N, 4, 4) or (N, 3, 4) ARKit poses (c2w)
        da3_poses: (N, 3, 4) or (N, 4, 4) DA3 poses (w2c)
        ba_poses: (N, 3, 4) or (N, 4, 4) BA poses (w2c), optional
        output_path: Path to save figure; the figure is shown if omitted
        use_plotly: Use plotly for interactive 3D (if available)

    Raises:
        ValueError: If a pose array is not a stack of 3x4 or 4x4 matrices.
    """
    arkit_centers = _camera_centers(arkit_poses, is_c2w=True)
    da3_centers = _camera_centers(da3_poses, is_c2w=False)
    ba_centers = None
    if ba_poses is not None:
        ba_centers = _camera_centers(ba_poses, is_c2w=False)

    if use_plotly and HAS_PLOTLY:
        fig = go.Figure()

        # (legend name, centers, color, line width, marker size)
        plotly_traces = [
            ("ARKit (Ground Truth)", arkit_centers, "green", 4, 4),
            ("DA3", da3_centers, "red", 2, 3),
            ("BA", ba_centers, "blue", 2, 3),
        ]
        for name, centers, color, line_width, marker_size in plotly_traces:
            if centers is None:
                continue
            fig.add_trace(
                go.Scatter3d(
                    x=centers[:, 0],
                    y=centers[:, 1],
                    z=centers[:, 2],
                    mode="lines+markers",
                    name=name,
                    line=dict(color=color, width=line_width),
                    marker=dict(size=marker_size),
                )
            )

        fig.update_layout(
            title="Camera Trajectories (3D)",
            scene=dict(
                xaxis_title="X (m)",
                yaxis_title="Y (m)",
                zaxis_title="Z (m)",
                aspectmode="data",
            ),
            width=1000,
            height=800,
        )

        if output_path:
            fig.write_html(str(output_path))
            print(f"Saved interactive plot to {output_path}")
        else:
            fig.show()
    else:
        # Fallback to matplotlib
        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111, projection="3d")

        # (legend label, centers, line format, line width, marker, marker size)
        mpl_traces = [
            ("ARKit (GT)", arkit_centers, "g-", 2, "o", 4),
            ("DA3", da3_centers, "r-", 1, "s", 3),
            ("BA", ba_centers, "b-", 1, "^", 3),
        ]
        for label, centers, fmt, line_width, marker, marker_size in mpl_traces:
            if centers is None:
                continue
            ax.plot(
                centers[:, 0],
                centers[:, 1],
                centers[:, 2],
                fmt,
                linewidth=line_width,
                marker=marker,
                markersize=marker_size,
                label=label,
            )

        ax.set_xlabel("X (m)")
        ax.set_ylabel("Y (m)")
        ax.set_zlabel("Z (m)")
        ax.set_title("Camera Trajectories (3D)")
        ax.legend()
        ax.grid(True)

        if output_path:
            plt.savefig(output_path, dpi=150, bbox_inches="tight")
            print(f"Saved plot to {output_path}")
        else:
            plt.show()

        plt.close(fig)


def _add_rotation_thresholds(ax, accept_label: str, reject_label: str, **line_kwargs):
    """Draw the 2 degree accept and 30 degree reject guide lines on ``ax``."""
    ax.axhline(y=2.0, color="g", linestyle="--", label=accept_label, **line_kwargs)
    ax.axhline(
        y=30.0, color="orange", linestyle="--", label=reject_label, **line_kwargs
    )


def _finish_error_axis(ax, ylabel: str, title: str):
    """Apply the shared labels, legend and grid used by every error panel."""
    ax.set_xlabel("Frame Index")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # An empty panel (e.g. no BA results) has nothing to put in a legend;
    # calling legend() there only produces a matplotlib warning.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    ax.grid(True, alpha=0.3)


def plot_error_metrics(results: Dict, output_dir: Path):
    """Plot rotation and translation errors.

    Writes a 2x2 figure (rows: rotation / translation, columns: DA3 / BA, each
    against ARKit) to ``output_dir / "error_metrics.png"``. The BA panels stay
    empty when ``results`` has no ``"ba_vs_arkit"`` entry.

    Args:
        results: Validation results; must contain ``"da3_vs_arkit"``.
        output_dir: Existing directory to write the figure into.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    da3 = results["da3_vs_arkit"]
    ba = results.get("ba_vs_arkit")

    # Rotation errors: DA3 vs ARKit
    ax = axes[0, 0]
    ax.plot(
        da3["rotation_errors_deg"],
        "r-o",
        linewidth=2,
        markersize=6,
        label="DA3 vs ARKit",
    )
    _add_rotation_thresholds(ax, "Accept threshold (2°)", "Reject threshold (30°)")
    _finish_error_axis(ax, "Rotation Error (degrees)", "DA3 vs ARKit: Rotation Error")

    # Rotation errors: BA vs ARKit
    ax = axes[0, 1]
    if ba is not None:
        ax.plot(
            ba["rotation_errors_deg"],
            "b-o",
            linewidth=2,
            markersize=6,
            label="BA vs ARKit",
        )
    _add_rotation_thresholds(ax, "Accept threshold (2°)", "Reject threshold (30°)")
    _finish_error_axis(ax, "Rotation Error (degrees)", "BA vs ARKit: Rotation Error")

    # Translation errors: DA3 vs ARKit
    ax = axes[1, 0]
    ax.plot(
        da3["translation_errors"],
        "r-o",
        linewidth=2,
        markersize=6,
        label="DA3 vs ARKit",
    )
    _finish_error_axis(ax, "Translation Error (m)", "DA3 vs ARKit: Translation Error")

    # Translation errors: BA vs ARKit
    ax = axes[1, 1]
    if ba is not None:
        ax.plot(
            ba["translation_errors"],
            "b-o",
            linewidth=2,
            markersize=6,
            label="BA vs ARKit",
        )
    _finish_error_axis(ax, "Translation Error (m)", "BA vs ARKit: Translation Error")

    plt.tight_layout()
    output_path = output_dir / "error_metrics.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    print(f"Saved error metrics to {output_path}")
    plt.close(fig)


def plot_error_comparison(results: Dict, output_dir: Path):
    """Plot side-by-side comparison of errors.

    Overlays DA3-vs-ARKit and (if present) BA-vs-ARKit errors in one rotation
    panel and one translation panel, saved to
    ``output_dir / "error_comparison.png"``.

    Args:
        results: Validation results; must contain ``"da3_vs_arkit"``.
        output_dir: Existing directory to write the figure into.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    da3 = results["da3_vs_arkit"]
    ba = results.get("ba_vs_arkit")

    def plot_series(ax, values, fmt, label):
        # Each series gets its own frame index so that DA3 and BA series of
        # different lengths do not make ax.plot raise a shape mismatch.
        ax.plot(
            np.arange(len(values)),
            values,
            fmt,
            linewidth=2,
            markersize=6,
            label=label,
        )

    # Rotation errors comparison
    ax = axes[0]
    plot_series(ax, da3["rotation_errors_deg"], "r-o", "DA3 vs ARKit")
    if ba is not None:
        plot_series(ax, ba["rotation_errors_deg"], "b-o", "BA vs ARKit")
    _add_rotation_thresholds(ax, "Accept (2°)", "Reject (30°)", alpha=0.5)
    _finish_error_axis(ax, "Rotation Error (degrees)", "Rotation Error Comparison")

    # Translation errors comparison
    ax = axes[1]
    plot_series(ax, da3["translation_errors"], "r-o", "DA3 vs ARKit")
    if ba is not None:
        plot_series(ax, ba["translation_errors"], "b-o", "BA vs ARKit")
    _finish_error_axis(ax, "Translation Error (m)", "Translation Error Comparison")

    plt.tight_layout()
    output_path = output_dir / "error_comparison.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    print(f"Saved error comparison to {output_path}")
    plt.close(fig)


def visualize_matches(
    image_path1: Path,
    image_path2: Path,
    matches_path: Path,
    features_path: Path,
    output_path: Path,
):
    """Visualize feature matches between two images.

    Draws the two images side by side with a randomly coloured line per match
    and writes the result to ``output_path``. Prints a message and returns
    without writing if an image cannot be read or no match entry exists for
    the pair.

    Args:
        image_path1: First image (drawn on the left).
        image_path2: Second image (drawn on the right).
        matches_path: HDF5 file with a ``"<name_a> <name_b>"`` group per pair,
            each holding a ``matches0`` dataset.
        features_path: HDF5 file with a ``keypoints`` dataset per image name.
        output_path: Destination image file.
    """
    import h5py

    name1 = Path(image_path1).name
    name2 = Path(image_path2).name

    # Load images
    img1 = cv2.imread(str(image_path1))
    img2 = cv2.imread(str(image_path2))

    if img1 is None or img2 is None:
        print(f"Could not load images: {image_path1}, {image_path2}")
        return

    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
    img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)

    # Load features and matches
    with h5py.File(features_path, "r") as f:
        kp1 = f[name1]["keypoints"][:]
        kp2 = f[name2]["keypoints"][:]

    with h5py.File(matches_path, "r") as f:
        forward_key = f"{name1} {name2}"
        reverse_key = f"{name2} {name1}"
        if forward_key in f:
            matches = f[forward_key]["matches0"][:]
            pair_is_reversed = False
        elif reverse_key in f:
            matches = f[reverse_key]["matches0"][:]
            pair_is_reversed = True
        else:
            print(f"No matches found for pair: {name1} " f"<-> {name2}")
            return

    # matches0[i] is the index of the keypoint in the pair's second image that
    # matches keypoint i of the pair's first image, or -1 for no match
    # (SuperGlue-style output; NOTE(review): confirm against the matcher used).
    valid = matches > -1
    first_indices = np.where(valid)[0]
    second_indices = matches[valid]
    if pair_is_reversed:
        # The stored pair is (image2, image1), so array positions refer to
        # image2's keypoints and the stored values to image1's.
        indices1, indices2 = second_indices, first_indices
    else:
        indices1, indices2 = first_indices, second_indices

    # Draw matches
    h1, w1 = img1.shape[:2]
    h2, w2 = img2.shape[:2]
    vis = np.zeros((max(h1, h2), w1 + w2, 3), dtype=np.uint8)
    vis[:h1, :w1] = img1
    vis[:h2, w1:] = img2

    # Draw keypoints and matches
    for m1, m2 in zip(indices1, indices2):
        # OpenCV drawing functions expect plain Python ints, not numpy scalars.
        pt1 = tuple(int(v) for v in kp1[m1])
        pt2 = tuple(int(v) for v in (kp2[m2] + [w1, 0]))
        color = np.random.randint(0, 255, 3).tolist()
        cv2.circle(vis, pt1, 5, color, -1)
        cv2.circle(vis, pt2, 5, color, -1)
        cv2.line(vis, pt1, pt2, color, 1)

    # Save
    vis_bgr = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(str(output_path), vis_bgr):
        print(f"Failed to write match visualization to {output_path}")
        return
    print(f"Saved match visualization to {output_path}")


def _write_error_section(f, heading: str, stats: Dict):
    """Write one mean/max rotation and mean translation error section."""
    f.write(f"{heading}\n")
    f.write("-" * 40 + "\n")
    f.write(f" Mean Rotation Error: {stats['mean_rotation_error_deg']:.2f}°\n")
    f.write(f" Max Rotation Error: {stats['max_rotation_error_deg']:.2f}°\n")
    f.write(f" Mean Translation Error: {stats['mean_translation_error']:.4f} m\n\n")


def create_summary_report(results: Dict, output_dir: Path):
    """Create a text summary report.

    Writes ``output_dir / "summary_report.txt"`` with the aggregate errors of
    each available comparison, the BA result status, and a categorization of
    frames by DA3-vs-ARKit rotation error (< 2, 2-30, >= 30 degrees).

    Args:
        results: Validation results; must contain ``"da3_vs_arkit"``.
        output_dir: Existing directory to write the report into.
    """
    report_path = output_dir / "summary_report.txt"

    # The report contains non-ASCII characters (degree sign), so the encoding
    # is fixed rather than left to the platform default.
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("=" * 60 + "\n")
        f.write("BA Validation Summary Report\n")
        f.write("=" * 60 + "\n\n")

        f.write(f"Total Frames: {results.get('num_frames', 'N/A')}\n\n")

        # DA3 vs ARKit
        da3_vs_arkit = results["da3_vs_arkit"]
        _write_error_section(f, "DA3 vs ARKit (Ground Truth):", da3_vs_arkit)

        # BA vs ARKit
        if "ba_vs_arkit" in results:
            _write_error_section(
                f, "BA vs ARKit (Ground Truth):", results["ba_vs_arkit"]
            )

        # DA3 vs BA
        if "da3_vs_ba" in results:
            _write_error_section(f, "DA3 vs BA:", results["da3_vs_ba"])

        # BA Result
        if "ba_result" in results:
            f.write("BA Validation Result:\n")
            f.write("-" * 40 + "\n")
            ba_result = results["ba_result"]
            f.write(f" Status: {ba_result.get('status', 'N/A')}\n")
            f.write(f" Error: {ba_result.get('error', 'N/A')}\n")
            f.write(
                f" Reprojection Error: {ba_result.get('reprojection_error', 'N/A')}\n\n"
            )

        # Categorization
        errors = da3_vs_arkit["rotation_errors_deg"]
        total = len(errors)
        accepted = sum(1 for e in errors if e < 2.0)
        learnable = sum(1 for e in errors if 2.0 <= e < 30.0)
        outlier = sum(1 for e in errors if e >= 30.0)

        def percent(count: int) -> float:
            # An empty error list would otherwise divide by zero.
            return 100 * count / total if total else 0.0

        f.write("Frame Categorization (DA3 vs ARKit):\n")
        f.write("-" * 40 + "\n")
        f.write(f" Accepted (< 2°): {accepted}/{total} " f"({percent(accepted):.1f}%)\n")
        f.write(
            f" Learnable (2-30°): {learnable}/{total} "
            f"({percent(learnable):.1f}%)\n"
        )
        f.write(f" Outlier (> 30°): {outlier}/{total} " f"({percent(outlier):.1f}%)\n")

    print(f"Saved summary report to {report_path}")


def main():
    """Command-line entry point: load results and write all visualizations."""
    import argparse

    parser = argparse.ArgumentParser(description="Visualize BA validation results")
    parser.add_argument(
        "--results-dir",
        type=Path,
        default=project_root / "data" / "arkit_ba_validation",
        help="Directory containing validation results",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Output directory for visualizations (default: results_dir/visualizations)",
    )
    parser.add_argument(
        "--use-plotly",
        action="store_true",
        help="Use plotly for interactive 3D plots (if available)",
    )

    args = parser.parse_args()

    results_path = args.results_dir / "validation_results.json"
    if not results_path.exists():
        print(f"Results file not found: {results_path}")
        return

    output_dir = args.output_dir or (args.results_dir / "visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading results from {results_path}")
    results = load_results(results_path)

    # Load poses (each file is optional)
    def load_optional_poses(filename: str, label: str) -> Optional[np.ndarray]:
        poses_path = args.results_dir / filename
        if not poses_path.exists():
            return None
        poses = np.load(poses_path)
        print(f"Loaded {label} poses: {poses.shape}")
        return poses

    arkit_poses = load_optional_poses("arkit_poses_c2w.npy", "ARKit")
    da3_poses = load_optional_poses("da3_poses_w2c.npy", "DA3")
    ba_poses = load_optional_poses("ba_poses_w2c.npy", "BA")

    # Create visualizations
    print("\nCreating visualizations...")

    # Error metrics
    plot_error_metrics(results, output_dir)
    plot_error_comparison(results, output_dir)

    # Summary report
    create_summary_report(results, output_dir)

    # Trajectory plot (if poses available)
    if arkit_poses is not None and da3_poses is not None:
        interactive = args.use_plotly and HAS_PLOTLY
        # The file type must match the backend: plotly writes HTML, matplotlib PNG.
        trajectory_name = "trajectories_3d.html" if interactive else "trajectories_3d.png"
        try:
            plot_trajectories_3d(
                arkit_poses,
                da3_poses,
                ba_poses=ba_poses,
                output_path=output_dir / trajectory_name,
                use_plotly=interactive,
            )
        except Exception as e:
            # Best effort: a failed trajectory plot must not discard the
            # plots and report that were already written.
            print(f"Error creating trajectory plot: {e}")
            import traceback

            traceback.print_exc()
    else:
        print("Skipping trajectory visualization (poses not available)")

    print(f"\n✓ Visualizations saved to {output_dir}")


if __name__ == "__main__":
    main()