3d_model / scripts /tools /visualize_ba_results.py
Azan
Clean deployment build (Squashed)
7a87926
#!/usr/bin/env python3
"""
Visualize BA validation results for diagnostics.
"""
import json
import sys
from pathlib import Path
from typing import Dict, Optional
import cv2
import matplotlib.pyplot as plt
import numpy as np
# Make project-local modules importable when this file is run as a script:
# the project root is taken to be three `.parent` steps up from this file.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

# Plotly is optional. HAS_PLOTLY records whether it could be imported so that
# the plotting code can fall back to matplotlib when it is missing.
try:
    import plotly.graph_objects as go
except ImportError:
    HAS_PLOTLY = False
    print("Plotly not available. Install with: pip install plotly")
else:
    HAS_PLOTLY = True
def load_results(results_path: Path) -> Dict:
    """Load the validation results JSON file.

    Args:
        results_path: Path to the JSON file (a plain ``str`` path also works).

    Returns:
        The parsed JSON content (a dict for the expected results file).
    """
    # JSON text is UTF-8 by specification; do not depend on the platform's
    # locale-dependent default encoding (which differs e.g. on Windows).
    with open(results_path, encoding="utf-8") as f:
        return json.load(f)
def _camera_centers(poses: np.ndarray, c2w: bool) -> np.ndarray:
    """Return the (N, 3) world-space camera centers of a stack of poses.

    Both (N, 3, 4) and (N, 4, 4) inputs are accepted; only the top three rows
    are read. The pose convention is given explicitly by ``c2w`` instead of
    being guessed from the array shape.

    Args:
        poses: Stack of poses ``[R | t]`` (3x4 or 4x4 each).
        c2w: True for camera-to-world poses (center is the translation column),
            False for world-to-camera poses (center is ``-R^T @ t``).

    Raises:
        ValueError: If ``poses`` is not of shape (N, 3, 4) or (N, 4, 4).
    """
    poses = np.asarray(poses, dtype=float)
    if poses.ndim != 3 or poses.shape[1] not in (3, 4) or poses.shape[2] != 4:
        raise ValueError(
            f"Expected poses of shape (N, 3, 4) or (N, 4, 4), got {poses.shape}"
        )
    rotations = poses[:, :3, :3]
    translations = poses[:, :3, 3]
    if c2w:
        return translations
    # Batched c = -R^T @ t: sum over j of R[n, j, i] * t[n, j].
    return -np.einsum("nji,nj->ni", rotations, translations)


def _plot_trajectories_plotly(arkit_centers, da3_centers, ba_centers, output_path):
    """Draw the trajectories with plotly; write HTML to ``output_path`` or show it."""
    fig = go.Figure()
    # (legend name, centers, color, line width, marker size)
    traces = [
        ("ARKit (Ground Truth)", arkit_centers, "green", 4, 4),
        ("DA3", da3_centers, "red", 2, 3),
    ]
    if ba_centers is not None:
        traces.append(("BA", ba_centers, "blue", 2, 3))
    for name, centers, color, width, marker_size in traces:
        fig.add_trace(
            go.Scatter3d(
                x=centers[:, 0],
                y=centers[:, 1],
                z=centers[:, 2],
                mode="lines+markers",
                name=name,
                line=dict(color=color, width=width),
                marker=dict(size=marker_size),
            )
        )
    fig.update_layout(
        title="Camera Trajectories (3D)",
        scene=dict(
            xaxis_title="X (m)",
            yaxis_title="Y (m)",
            zaxis_title="Z (m)",
            aspectmode="data",
        ),
        width=1000,
        height=800,
    )
    if output_path:
        fig.write_html(str(output_path))
        print(f"Saved interactive plot to {output_path}")
    else:
        fig.show()


def _plot_trajectories_matplotlib(arkit_centers, da3_centers, ba_centers, output_path):
    """Draw the trajectories with matplotlib; save to ``output_path`` or show it."""
    fig = plt.figure(figsize=(12, 10))
    try:
        ax = fig.add_subplot(111, projection="3d")
        # (legend label, centers, format string, line width, marker, marker size)
        lines = [
            ("ARKit (GT)", arkit_centers, "g-", 2, "o", 4),
            ("DA3", da3_centers, "r-", 1, "s", 3),
        ]
        if ba_centers is not None:
            lines.append(("BA", ba_centers, "b-", 1, "^", 3))
        for label, centers, fmt, width, marker, marker_size in lines:
            ax.plot(
                centers[:, 0],
                centers[:, 1],
                centers[:, 2],
                fmt,
                linewidth=width,
                marker=marker,
                markersize=marker_size,
                label=label,
            )
        ax.set_xlabel("X (m)")
        ax.set_ylabel("Y (m)")
        ax.set_zlabel("Z (m)")
        ax.set_title("Camera Trajectories (3D)")
        ax.legend()
        ax.grid(True)
        if output_path:
            fig.savefig(output_path, dpi=150, bbox_inches="tight")
            print(f"Saved plot to {output_path}")
        else:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


def plot_trajectories_3d(
    arkit_poses: np.ndarray,
    da3_poses: np.ndarray,
    ba_poses: Optional[np.ndarray] = None,
    output_path: Optional[Path] = None,
    use_plotly: bool = True,
):
    """
    Plot 3D camera trajectories (camera centers per frame).

    ARKit poses are treated as camera-to-world; DA3 and BA poses as
    world-to-camera. Each may be given as 3x4 or 4x4 matrices.

    Args:
        arkit_poses: (N, 4, 4) or (N, 3, 4) ARKit poses (c2w)
        da3_poses: (N, 3, 4) or (N, 4, 4) DA3 poses (w2c)
        ba_poses: (N, 3, 4) or (N, 4, 4) BA poses (w2c), optional
        output_path: Where to save the figure (HTML with plotly, image with
            matplotlib). If None, the figure is shown interactively.
        use_plotly: Use plotly for interactive 3D (only if plotly is installed;
            otherwise matplotlib is used)

    Raises:
        ValueError: If any pose array does not have shape (N, 3, 4) or (N, 4, 4).
    """
    # Convert with the convention of each source rather than guessing from the
    # shape: a (N, 3, 4) c2w array must not be inverted, and a (N, 4, 4) w2c
    # array must be.
    arkit_centers = _camera_centers(arkit_poses, c2w=True)
    da3_centers = _camera_centers(da3_poses, c2w=False)
    ba_centers = _camera_centers(ba_poses, c2w=False) if ba_poses is not None else None

    if use_plotly and HAS_PLOTLY:
        _plot_trajectories_plotly(arkit_centers, da3_centers, ba_centers, output_path)
    else:
        _plot_trajectories_matplotlib(arkit_centers, da3_centers, ba_centers, output_path)
def _draw_error_panel(
    ax,
    errors,
    style: str,
    label: str,
    ylabel: str,
    title: str,
    thresholds: Optional[tuple] = None,
):
    """Plot one per-frame error series on ``ax`` with the shared panel styling.

    If ``thresholds`` is given as ``(accept, reject)`` (degrees), dashed
    horizontal reference lines are drawn at both values.
    """
    ax.plot(errors, style, linewidth=2, markersize=6, label=label)
    if thresholds is not None:
        accept, reject = thresholds
        ax.axhline(
            y=accept, color="g", linestyle="--", label=f"Accept threshold ({accept:g}°)"
        )
        ax.axhline(
            y=reject,
            color="orange",
            linestyle="--",
            label=f"Reject threshold ({reject:g}°)",
        )
    ax.set_xlabel("Frame Index")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_error_metrics(
    results: Dict,
    output_dir: Path,
    accept_threshold_deg: float = 2.0,
    reject_threshold_deg: float = 30.0,
):
    """Plot per-frame rotation and translation errors in a 2x2 grid.

    Left column: DA3 vs ARKit; right column: BA vs ARKit. Top row: rotation
    error with accept/reject threshold lines; bottom row: translation error.
    A comparison missing from ``results`` gets a labelled placeholder instead
    of blank axes. The figure is saved to ``output_dir / "error_metrics.png"``.

    Args:
        results: Validation results; reads ``results[key]["rotation_errors_deg"]``
            and ``results[key]["translation_errors"]`` for key in
            ``"da3_vs_arkit"`` / ``"ba_vs_arkit"``.
        output_dir: Existing directory to write the PNG into.
        accept_threshold_deg: Rotation error (deg) below which a frame is accepted.
        reject_threshold_deg: Rotation error (deg) from which a frame is rejected.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    thresholds = (accept_threshold_deg, reject_threshold_deg)
    try:
        # One column per comparison: (results key, display name, line style).
        columns = [
            ("da3_vs_arkit", "DA3 vs ARKit", "r-o"),
            ("ba_vs_arkit", "BA vs ARKit", "b-o"),
        ]
        for col, (key, name, style) in enumerate(columns):
            stats = results.get(key)
            if stats is None:
                # Keep the grid layout but state explicitly that the data is absent.
                for ax in axes[:, col]:
                    ax.text(
                        0.5,
                        0.5,
                        f"No {name} results",
                        ha="center",
                        va="center",
                        transform=ax.transAxes,
                    )
                    ax.set_axis_off()
                continue
            _draw_error_panel(
                axes[0, col],
                stats["rotation_errors_deg"],
                style,
                name,
                "Rotation Error (degrees)",
                f"{name}: Rotation Error",
                thresholds=thresholds,
            )
            _draw_error_panel(
                axes[1, col],
                stats["translation_errors"],
                style,
                name,
                "Translation Error (m)",
                f"{name}: Translation Error",
            )
        fig.tight_layout()
        output_path = output_dir / "error_metrics.png"
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        print(f"Saved error metrics to {output_path}")
    finally:
        # Release the figure even if a key is missing or saving fails.
        plt.close(fig)
def plot_error_comparison(
    results: Dict,
    output_dir: Path,
    accept_threshold_deg: float = 2.0,
    reject_threshold_deg: float = 30.0,
):
    """Overlay DA3 and BA errors on shared axes (rotation left, translation right).

    Each series is plotted against its own ``0..len-1`` frame index. DA3 and
    BA error lists of different lengths therefore no longer make matplotlib
    raise a shape-mismatch error. The figure is saved to
    ``output_dir / "error_comparison.png"``.

    Args:
        results: Validation results; ``results["da3_vs_arkit"]`` is required,
            ``results["ba_vs_arkit"]`` is optional. Each provides
            ``"rotation_errors_deg"`` and ``"translation_errors"`` lists.
        output_dir: Existing directory to write the PNG into.
        accept_threshold_deg: Accept threshold line on the rotation plot (deg).
        reject_threshold_deg: Reject threshold line on the rotation plot (deg).
    """
    fig, (rot_ax, trans_ax) = plt.subplots(1, 2, figsize=(15, 5))
    try:
        series = [("DA3 vs ARKit", "r-o", results["da3_vs_arkit"])]
        if "ba_vs_arkit" in results:
            series.append(("BA vs ARKit", "b-o", results["ba_vs_arkit"]))

        for label, style, stats in series:
            for ax, key in ((rot_ax, "rotation_errors_deg"), (trans_ax, "translation_errors")):
                errors = stats[key]
                # NOTE(review): series are aligned by position, assuming both
                # lists start at the same frame — TODO confirm for BA subsets.
                ax.plot(
                    np.arange(len(errors)),
                    errors,
                    style,
                    linewidth=2,
                    markersize=6,
                    label=label,
                )

        rot_ax.axhline(
            y=accept_threshold_deg,
            color="g",
            linestyle="--",
            alpha=0.5,
            label=f"Accept ({accept_threshold_deg:g}°)",
        )
        rot_ax.axhline(
            y=reject_threshold_deg,
            color="orange",
            linestyle="--",
            alpha=0.5,
            label=f"Reject ({reject_threshold_deg:g}°)",
        )

        panels = (
            (rot_ax, "Rotation Error (degrees)", "Rotation Error Comparison"),
            (trans_ax, "Translation Error (m)", "Translation Error Comparison"),
        )
        for ax, ylabel, title in panels:
            ax.set_xlabel("Frame Index")
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.legend()
            ax.grid(True, alpha=0.3)

        fig.tight_layout()
        output_path = output_dir / "error_comparison.png"
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        print(f"Saved error comparison to {output_path}")
    finally:
        # Release the figure even if a key is missing or saving fails.
        plt.close(fig)
def visualize_matches(
    image_path1: Path,
    image_path2: Path,
    matches_path: Path,
    features_path: Path,
    output_path: Path,
):
    """Draw feature matches between two images side by side and save the result.

    Keypoints are read from ``features_path`` (HDF5), looked up as
    ``f[<image file name>]["keypoints"]``. Matches are read from
    ``matches_path`` (HDF5), looked up as ``f["<name_a> <name_b>"]["matches0"]``.
    There, entry ``i`` is the index of the keypoint in image ``b`` matched to
    keypoint ``i`` of image ``a``, or ``-1`` when unmatched. The pair is tried
    in both orders; for the reversed order the index roles are swapped.
    Match colors come from a fixed seed, so repeated runs produce identical
    images.

    Prints a message and returns without writing if an image cannot be read
    or no match group exists for the pair.
    """
    import h5py

    name1 = Path(image_path1).name
    name2 = Path(image_path2).name

    # Images stay in OpenCV's native BGR order: they are only composited and
    # written back with cv2, so no colour conversion is needed.
    img1 = cv2.imread(str(image_path1))
    img2 = cv2.imread(str(image_path2))
    if img1 is None or img2 is None:
        print(f"Could not load images: {image_path1}, {image_path2}")
        return

    with h5py.File(features_path, "r") as f:
        kp1 = f[name1]["keypoints"][:]
        kp2 = f[name2]["keypoints"][:]

    with h5py.File(matches_path, "r") as f:
        forward_pair = f"{name1} {name2}"
        reverse_pair = f"{name2} {name1}"
        if forward_pair in f:
            matches = f[forward_pair]["matches0"][:]
            reversed_pair = False
        elif reverse_pair in f:
            matches = f[reverse_pair]["matches0"][:]
            reversed_pair = True
        else:
            print(f"No matches found for pair: {name1} <-> {name2}")
            return

    # matches0 is indexed by the keypoints of the pair's FIRST image. For the
    # reversed pair that is image 2, so the two index arrays trade places.
    valid = matches > -1
    source_idx = np.flatnonzero(valid)
    target_idx = matches[valid]
    if reversed_pair:
        idx1, idx2 = target_idx, source_idx
    else:
        idx1, idx2 = source_idx, target_idx

    h1, w1 = img1.shape[:2]
    h2, w2 = img2.shape[:2]
    canvas = np.zeros((max(h1, h2), w1 + w2, 3), dtype=np.uint8)
    canvas[:h1, :w1] = img1
    canvas[:h2, w1:] = img2

    # Round sub-pixel keypoints and convert to plain Python ints, which all
    # OpenCV versions accept as point coordinates. Only (x, y) columns are used.
    pts1 = np.rint(kp1[idx1, :2]).astype(int)
    pts2 = np.rint(kp2[idx2, :2]).astype(int) + [w1, 0]  # shift into right half
    rng = np.random.default_rng(0)
    colors = rng.integers(0, 256, size=(len(pts1), 3))
    for p1, p2, color in zip(pts1.tolist(), pts2.tolist(), colors.tolist()):
        p1, p2 = tuple(p1), tuple(p2)
        cv2.circle(canvas, p1, 5, color, -1)
        cv2.circle(canvas, p2, 5, color, -1)
        cv2.line(canvas, p1, p2, color, 1)

    if not cv2.imwrite(str(output_path), canvas):
        print(f"Failed to write match visualization to {output_path}")
        return
    print(f"Saved match visualization to {output_path}")
def _write_comparison_section(f, title: str, stats: Dict) -> None:
    """Write a "<title>:" section with mean/max rotation and mean translation error."""
    f.write(f"{title}:\n")
    f.write("-" * 40 + "\n")
    f.write(f" Mean Rotation Error: {stats['mean_rotation_error_deg']:.2f}°\n")
    f.write(f" Max Rotation Error: {stats['max_rotation_error_deg']:.2f}°\n")
    f.write(f" Mean Translation Error: {stats['mean_translation_error']:.4f} m\n\n")


def create_summary_report(
    results: Dict,
    output_dir: Path,
    accept_threshold_deg: float = 2.0,
    reject_threshold_deg: float = 30.0,
):
    """Write a plain-text summary report to ``output_dir / "summary_report.txt"``.

    The report contains:

    - the frame count;
    - mean/max rotation and mean translation errors for DA3 vs ARKit, plus
      BA vs ARKit and DA3 vs BA when present;
    - the BA result fields, when present;
    - a categorization of DA3-vs-ARKit rotation errors into accepted
      (``< accept``), learnable (``accept <= e < reject``) and outlier
      (``>= reject``).

    An empty error list yields 0.0% for every category instead of a
    ``ZeroDivisionError``.

    Args:
        results: Validation results; ``results["da3_vs_arkit"]`` is required.
        output_dir: Existing directory to write the report into.
        accept_threshold_deg: Upper bound (exclusive, degrees) for "accepted".
        reject_threshold_deg: Lower bound (inclusive, degrees) for "outlier".
    """
    report_path = output_dir / "summary_report.txt"
    # Explicit UTF-8: the report contains non-ASCII characters ("°").
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("=" * 60 + "\n")
        f.write("BA Validation Summary Report\n")
        f.write("=" * 60 + "\n\n")
        f.write(f"Total Frames: {results.get('num_frames', 'N/A')}\n\n")

        _write_comparison_section(f, "DA3 vs ARKit (Ground Truth)", results["da3_vs_arkit"])
        if "ba_vs_arkit" in results:
            _write_comparison_section(f, "BA vs ARKit (Ground Truth)", results["ba_vs_arkit"])
        if "da3_vs_ba" in results:
            _write_comparison_section(f, "DA3 vs BA", results["da3_vs_ba"])

        if "ba_result" in results:
            ba_result = results["ba_result"]
            f.write("BA Validation Result:\n")
            f.write("-" * 40 + "\n")
            f.write(f" Status: {ba_result.get('status', 'N/A')}\n")
            f.write(f" Error: {ba_result.get('error', 'N/A')}\n")
            f.write(f" Reprojection Error: {ba_result.get('reprojection_error', 'N/A')}\n\n")

        # Categorization (DA3 vs ARKit rotation errors).
        accept, reject = accept_threshold_deg, reject_threshold_deg
        errors = results["da3_vs_arkit"]["rotation_errors_deg"]
        total = len(errors)
        accepted = sum(1 for e in errors if e < accept)
        learnable = sum(1 for e in errors if accept <= e < reject)
        outlier = sum(1 for e in errors if e >= reject)

        def pct(count: int) -> float:
            # Guard against an empty error list (previously ZeroDivisionError).
            return 100 * count / total if total else 0.0

        f.write("Frame Categorization (DA3 vs ARKit):\n")
        f.write("-" * 40 + "\n")
        f.write(f" Accepted (< {accept:g}°): {accepted}/{total} ({pct(accepted):.1f}%)\n")
        f.write(
            f" Learnable ({accept:g}-{reject:g}°): {learnable}/{total} "
            f"({pct(learnable):.1f}%)\n"
        )
        # Label matches the inclusive `>=` condition used above.
        f.write(f" Outlier (>= {reject:g}°): {outlier}/{total} ({pct(outlier):.1f}%)\n")
    print(f"Saved summary report to {report_path}")
def _load_poses(path: Path, label: str) -> Optional[np.ndarray]:
    """Return the array stored at ``path`` (printing its shape), or None if absent."""
    if not path.exists():
        return None
    poses = np.load(path)
    print(f"Loaded {label} poses: {poses.shape}")
    return poses


def main():
    """Command-line entry point: render diagnostics for BA validation results.

    Steps:

    - Read ``<results-dir>/validation_results.json``, plus the optional pose
      arrays ``arkit_poses_c2w.npy``, ``da3_poses_w2c.npy`` and
      ``ba_poses_w2c.npy``.
    - Write error plots, a text summary and, when ARKit and DA3 poses exist, a
      3D trajectory plot to ``--output-dir`` (default:
      ``<results-dir>/visualizations``).

    Returns:
        Process exit status: 1 if the results JSON is missing, otherwise 0.
    """
    import argparse
    import traceback

    parser = argparse.ArgumentParser(description="Visualize BA validation results")
    parser.add_argument(
        "--results-dir",
        type=Path,
        default=project_root / "data" / "arkit_ba_validation",
        help="Directory containing validation results",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Output directory for visualizations (default: results_dir/visualizations)",
    )
    parser.add_argument(
        "--use-plotly",
        action="store_true",
        help="Use plotly for interactive 3D plots (if available)",
    )
    args = parser.parse_args()

    results_path = args.results_dir / "validation_results.json"
    if not results_path.exists():
        print(f"Results file not found: {results_path}")
        # Non-zero status so shell pipelines / CI notice the missing input.
        return 1

    output_dir = args.output_dir or (args.results_dir / "visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading results from {results_path}")
    results = load_results(results_path)

    arkit_poses = _load_poses(args.results_dir / "arkit_poses_c2w.npy", "ARKit")
    da3_poses = _load_poses(args.results_dir / "da3_poses_w2c.npy", "DA3")
    ba_poses = _load_poses(args.results_dir / "ba_poses_w2c.npy", "BA")

    print("\nCreating visualizations...")
    plot_error_metrics(results, output_dir)
    plot_error_comparison(results, output_dir)
    create_summary_report(results, output_dir)

    if arkit_poses is None or da3_poses is None:
        print("Skipping trajectory visualization (poses not available)")
    else:
        use_plotly = args.use_plotly and HAS_PLOTLY
        extension = "html" if use_plotly else "png"
        try:
            plot_trajectories_3d(
                arkit_poses,
                da3_poses,
                ba_poses=ba_poses,
                output_path=output_dir / f"trajectories_3d.{extension}",
                use_plotly=use_plotly,
            )
        except Exception as e:
            # Best-effort: a failing 3D plot must not hide the outputs that
            # were already written above; report it with a full traceback.
            print(f"Error creating trajectory plot: {e}")
            traceback.print_exc()

    print(f"\n✓ Visualizations saved to {output_dir}")
    return 0


if __name__ == "__main__":
    sys.exit(main())