|
|
|
|
|
""" |
|
|
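Visualize bundle adjustment (BA) validation results for diagnostics: per-frame error plots, a text summary report, and 3D camera trajectories comparing DA3 and BA poses against ARKit ground truth.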
|
|
""" |
|
|
|
|
|
import json |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional |
|
|
import cv2 |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Directory two levels above the one containing this file; presumably the
# repository root (TODO confirm against the actual file location). main() uses
# it to build the default --results-dir.
project_root = Path(__file__).parent.parent.parent

# Prepend the project root to the module search path, presumably so that
# project-local packages are importable when this file is run as a script.
sys.path.insert(0, str(project_root))
|
|
|
|
|
# Plotly is optional. When it cannot be imported, HAS_PLOTLY is False and the
# trajectory plot uses the matplotlib code path instead.
try:
    import plotly.graph_objects as go
except ImportError:
    HAS_PLOTLY = False
    print("Plotly not available. Install with: pip install plotly")
else:
    HAS_PLOTLY = True
|
|
|
|
|
|
|
|
def load_results(results_path: Path) -> Dict:
    """Load the validation results JSON file.

    Args:
        results_path: Path to the JSON file (main() passes
            ``<results_dir>/validation_results.json``).

    Returns:
        The decoded JSON document.
    """
    # JSON text is UTF-8 by specification (RFC 8259). Don't rely on the
    # platform's locale encoding, which is not UTF-8 everywhere (e.g. Windows).
    with open(results_path, encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
|
|
|
def _camera_centers(poses: np.ndarray, *, is_c2w: bool) -> np.ndarray:
    """Return the (N, 3) world-space camera centers of a stack of poses.

    Args:
        poses: (N, 3, 4) or (N, 4, 4) pose matrices; only the top 3x4 is read.
        is_c2w: True for camera-to-world poses, False for world-to-camera.

    Raises:
        ValueError: If ``poses`` is not shaped (N, 3, 4) or (N, 4, 4).
    """
    poses = np.asarray(poses, dtype=float)
    if poses.ndim != 3 or poses.shape[1:] not in ((3, 4), (4, 4)):
        raise ValueError(
            f"Expected poses of shape (N, 3, 4) or (N, 4, 4), got {poses.shape}"
        )
    rotations = poses[:, :3, :3]
    translations = poses[:, :3, 3]
    if is_c2w:
        # The translation column of a camera-to-world pose is the camera center.
        return translations
    # World-to-camera [R | t] maps the center C to the origin: R C + t = 0,
    # hence C = -R^T t, computed for all poses at once.
    return -np.einsum("nji,nj->ni", rotations, translations)


def _plot_trajectories_plotly(series, output_path: Optional[Path]) -> None:
    """Draw camera-center trajectories with plotly; save as HTML or show."""
    fig = go.Figure()
    for s in series:
        centers = s["centers"]
        fig.add_trace(
            go.Scatter3d(
                x=centers[:, 0],
                y=centers[:, 1],
                z=centers[:, 2],
                mode="lines+markers",
                name=s["label"],
                line=dict(color=s["color"], width=s["plotly_line_width"]),
                # Give markers the line's color so each trace is one color.
                marker=dict(size=s["marker_size"], color=s["color"]),
            )
        )

    fig.update_layout(
        title="Camera Trajectories (3D)",
        scene=dict(
            xaxis_title="X (m)",
            yaxis_title="Y (m)",
            zaxis_title="Z (m)",
            aspectmode="data",
        ),
        width=1000,
        height=800,
    )

    if output_path:
        fig.write_html(str(output_path))
        print(f"Saved interactive plot to {output_path}")
    else:
        fig.show()


def _plot_trajectories_matplotlib(series, output_path: Optional[Path]) -> None:
    """Draw camera-center trajectories with matplotlib; save as image or show."""
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection="3d")

    for s in series:
        centers = s["centers"]
        ax.plot(
            centers[:, 0],
            centers[:, 1],
            centers[:, 2],
            color=s["color"],
            linestyle="-",
            linewidth=s["mpl_line_width"],
            marker=s["mpl_marker"],
            markersize=s["marker_size"],
            label=s["short_label"],
        )

    ax.set_xlabel("X (m)")
    ax.set_ylabel("Y (m)")
    ax.set_zlabel("Z (m)")
    ax.set_title("Camera Trajectories (3D)")
    ax.legend()
    ax.grid(True)

    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        print(f"Saved plot to {output_path}")
    else:
        plt.show()

    plt.close(fig)


def plot_trajectories_3d(
    arkit_poses: np.ndarray,
    da3_poses: np.ndarray,
    ba_poses: Optional[np.ndarray] = None,
    output_path: Optional[Path] = None,
    use_plotly: bool = True,
):
    """Plot the 3D camera-center trajectories of ARKit, DA3 and (optionally) BA.

    Args:
        arkit_poses: (N, 4, 4) or (N, 3, 4) ARKit poses (camera-to-world).
        da3_poses: (N, 3, 4) or (N, 4, 4) DA3 poses (world-to-camera).
        ba_poses: (N, 3, 4) or (N, 4, 4) BA poses (world-to-camera), optional.
        output_path: Where to save the figure (HTML for plotly, an image for
            matplotlib). When falsy, the figure is shown interactively.
        use_plotly: Use plotly for an interactive 3D plot if it is installed;
            otherwise matplotlib is used.

    Raises:
        ValueError: If any pose array is not shaped (N, 3, 4) or (N, 4, 4).
    """
    # Each source's pose convention is fixed here, not guessed from the array
    # shape: guessing mis-reads (N, 3, 4) camera-to-world or (N, 4, 4)
    # world-to-camera inputs.
    series = [
        dict(
            centers=_camera_centers(arkit_poses, is_c2w=True),
            label="ARKit (Ground Truth)",
            short_label="ARKit (GT)",
            color="green",
            plotly_line_width=4,
            mpl_line_width=2,
            mpl_marker="o",
            marker_size=4,
        ),
        dict(
            centers=_camera_centers(da3_poses, is_c2w=False),
            label="DA3",
            short_label="DA3",
            color="red",
            plotly_line_width=2,
            mpl_line_width=1,
            mpl_marker="s",
            marker_size=3,
        ),
    ]
    if ba_poses is not None:
        series.append(
            dict(
                centers=_camera_centers(ba_poses, is_c2w=False),
                label="BA",
                short_label="BA",
                color="blue",
                plotly_line_width=2,
                mpl_line_width=1,
                mpl_marker="^",
                marker_size=3,
            )
        )

    if use_plotly and HAS_PLOTLY:
        _plot_trajectories_plotly(series, output_path)
    else:
        _plot_trajectories_matplotlib(series, output_path)
|
|
|
|
|
|
|
|
def _plot_error_series(ax, errors, *, fmt, label, title, ylabel, thresholds=()):
    """Draw one per-frame error curve on ``ax`` with optional threshold lines.

    ``thresholds`` is a sequence of ``(y, color, label)`` horizontal lines.
    """
    ax.plot(errors, fmt, linewidth=2, markersize=6, label=label)
    for y, color, threshold_label in thresholds:
        ax.axhline(y=y, color=color, linestyle="--", label=threshold_label)
    ax.set_xlabel("Frame Index")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


def _mark_missing_panel(ax, message):
    """Label an axis that has no data instead of leaving it silently blank."""
    ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)
    ax.set_xticks([])
    ax.set_yticks([])


def plot_error_metrics(results: Dict, output_dir: Path):
    """Plot per-frame rotation and translation errors against ARKit.

    Saves a 2x2 grid to ``output_dir / "error_metrics.png"``:
    - Rows: rotation error (degrees, with the 2° accept and 30° reject lines)
      on top, translation error (m) below.
    - Columns: DA3 vs ARKit on the left, BA vs ARKit on the right.

    A comparison missing from ``results`` is rendered as a labelled empty panel.

    Args:
        results: Validation results; ``"da3_vs_arkit"`` / ``"ba_vs_arkit"``
            entries, when present, must hold ``"rotation_errors_deg"`` and
            ``"translation_errors"`` sequences.
        output_dir: Existing directory to write the figure into.
    """
    rotation_thresholds = (
        (2.0, "g", "Accept threshold (2°)"),
        (30.0, "orange", "Reject threshold (30°)"),
    )
    columns = (
        ("da3_vs_arkit", "DA3", "r-o"),
        ("ba_vs_arkit", "BA", "b-o"),
    )

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    for col, (key, name, fmt) in enumerate(columns):
        rot_ax, trans_ax = axes[0, col], axes[1, col]
        label = f"{name} vs ARKit"
        comparison = results.get(key)
        if comparison is None:
            _mark_missing_panel(rot_ax, f"No {label} results")
            _mark_missing_panel(trans_ax, f"No {label} results")
            continue

        _plot_error_series(
            rot_ax,
            comparison["rotation_errors_deg"],
            fmt=fmt,
            label=label,
            title=f"{label}: Rotation Error",
            ylabel="Rotation Error (degrees)",
            thresholds=rotation_thresholds,
        )
        _plot_error_series(
            trans_ax,
            comparison["translation_errors"],
            fmt=fmt,
            label=label,
            title=f"{label}: Translation Error",
            ylabel="Translation Error (m)",
        )

    fig.tight_layout()
    output_path = output_dir / "error_metrics.png"
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    print(f"Saved error metrics to {output_path}")
    plt.close(fig)
|
|
|
|
|
|
|
|
def plot_error_comparison(results: Dict, output_dir: Path):
    """Overlay DA3 and BA per-frame errors (vs ARKit) on shared axes.

    Saves a 1x2 figure to ``output_dir / "error_comparison.png"``:
    - Left: rotation error (degrees), with the 2° accept and 30° reject lines.
    - Right: translation error (m).

    Each comparison present in ``results`` (``"da3_vs_arkit"``,
    ``"ba_vs_arkit"``) contributes one curve.

    Args:
        results: Validation results; each present comparison must hold
            ``"rotation_errors_deg"`` and ``"translation_errors"`` sequences.
        output_dir: Existing directory to write the figure into.
    """
    series = (
        ("da3_vs_arkit", "r-o", "DA3 vs ARKit"),
        ("ba_vs_arkit", "b-o", "BA vs ARKit"),
    )

    fig, (rot_ax, trans_ax) = plt.subplots(1, 2, figsize=(15, 5))

    for key, fmt, label in series:
        if key not in results:
            continue
        comparison = results[key]
        rotation = comparison["rotation_errors_deg"]
        translation = comparison["translation_errors"]
        # Index each curve by its own length: reusing one series' frame
        # indices for another of a different length makes matplotlib raise
        # "x and y must have same first dimension".
        rot_ax.plot(
            np.arange(len(rotation)), rotation, fmt, linewidth=2, markersize=6, label=label
        )
        trans_ax.plot(
            np.arange(len(translation)),
            translation,
            fmt,
            linewidth=2,
            markersize=6,
            label=label,
        )

    rot_ax.axhline(y=2.0, color="g", linestyle="--", alpha=0.5, label="Accept (2°)")
    rot_ax.axhline(y=30.0, color="orange", linestyle="--", alpha=0.5, label="Reject (30°)")
    rot_ax.set_xlabel("Frame Index")
    rot_ax.set_ylabel("Rotation Error (degrees)")
    rot_ax.set_title("Rotation Error Comparison")
    rot_ax.legend()
    rot_ax.grid(True, alpha=0.3)

    trans_ax.set_xlabel("Frame Index")
    trans_ax.set_ylabel("Translation Error (m)")
    trans_ax.set_title("Translation Error Comparison")
    trans_ax.legend()
    trans_ax.grid(True, alpha=0.3)

    fig.tight_layout()
    output_path = output_dir / "error_comparison.png"
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    print(f"Saved error comparison to {output_path}")
    plt.close(fig)
|
|
|
|
|
|
|
|
def visualize_matches(
    image_path1: Path,
    image_path2: Path,
    matches_path: Path,
    features_path: Path,
    output_path: Path,
):
    """Draw feature matches between two images side by side and save the result.

    Keypoints come from ``features_path``, an HDF5 file with one group per image
    file name holding a ``"keypoints"`` dataset.

    Matches come from ``matches_path``, an HDF5 file with one group per pair,
    named ``"<name_a> <name_b>"``, holding a ``"matches0"`` dataset. Entry ``i``
    is the index of the keypoint in image b matched to keypoint ``i`` of
    image a, or negative when unmatched. The pair may be stored in either
    order.

    Prints a message and returns early if an image cannot be read or no
    matches exist for the pair.

    Args:
        image_path1: First (left) image.
        image_path2: Second (right) image.
        matches_path: HDF5 matches file.
        features_path: HDF5 features file.
        output_path: Where to write the visualization image.
    """
    import h5py

    img1 = cv2.imread(str(image_path1))
    img2 = cv2.imread(str(image_path2))
    if img1 is None or img2 is None:
        print(f"Could not load images: {image_path1}, {image_path2}")
        return

    name1 = Path(image_path1).name
    name2 = Path(image_path2).name

    with h5py.File(features_path, "r") as f:
        kp1 = f[name1]["keypoints"][:]
        kp2 = f[name2]["keypoints"][:]

    forward_pair = f"{name1} {name2}"
    reverse_pair = f"{name2} {name1}"
    with h5py.File(matches_path, "r") as f:
        if forward_pair in f:
            matches = f[forward_pair]["matches0"][:]
            reversed_pair = False
        elif reverse_pair in f:
            matches = f[reverse_pair]["matches0"][:]
            reversed_pair = True
        else:
            print(f"No matches found for pair: {name1} <-> {name2}")
            return

    matches = np.asarray(matches).astype(np.int64)
    query_idx = np.flatnonzero(matches > -1)
    target_idx = matches[query_idx]
    if reversed_pair:
        # matches0 of the "name2 name1" pair is indexed by image-2 keypoints
        # and stores image-1 keypoint indices, so swap the roles.
        idx1, idx2 = target_idx, query_idx
    else:
        idx1, idx2 = query_idx, target_idx

    # Work in OpenCV's native BGR order throughout; both inputs and output are
    # BGR, so no color conversion is needed.
    h1, w1 = img1.shape[:2]
    h2, w2 = img2.shape[:2]
    canvas = np.zeros((max(h1, h2), w1 + w2, 3), dtype=np.uint8)
    canvas[:h1, :w1] = img1
    canvas[:h2, w1:] = img2

    # Round (rather than truncate) subpixel keypoints; shift image-2 points
    # right by the width of image 1 to land on the right half of the canvas.
    pts1 = np.rint(kp1[idx1, :2]).astype(np.int64)
    pts2 = np.rint(kp2[idx2, :2]).astype(np.int64) + [w1, 0]

    # Fixed seed so re-running produces identical images.
    rng = np.random.default_rng(0)
    for p1, p2 in zip(pts1.tolist(), pts2.tolist()):
        color = rng.integers(0, 256, size=3).tolist()
        p1, p2 = tuple(p1), tuple(p2)
        cv2.circle(canvas, p1, 5, color, -1)
        cv2.circle(canvas, p2, 5, color, -1)
        cv2.line(canvas, p1, p2, color, 1)

    if not cv2.imwrite(str(output_path), canvas):
        print(f"Failed to write match visualization to {output_path}")
        return
    print(f"Saved match visualization to {output_path}")
|
|
|
|
|
|
|
|
def _write_pose_comparison(f, title: str, stats: Dict) -> None:
    """Write one pose-error comparison section to the open report file ``f``."""
    f.write(f"{title}:\n")
    f.write("-" * 40 + "\n")
    f.write(f" Mean Rotation Error: {stats['mean_rotation_error_deg']:.2f}°\n")
    f.write(f" Max Rotation Error: {stats['max_rotation_error_deg']:.2f}°\n")
    f.write(f" Mean Translation Error: {stats['mean_translation_error']:.4f} m\n\n")


def create_summary_report(results: Dict, output_dir: Path):
    """Write a plain-text summary of the validation results.

    Writes ``output_dir / "summary_report.txt"`` (UTF-8). It contains:
    - The frame count.
    - Mean/max rotation and mean translation error for each comparison present
      in ``results`` (``"da3_vs_arkit"``, ``"ba_vs_arkit"``, ``"da3_vs_ba"``).
    - The ``"ba_result"`` status fields, when present.
    - A breakdown of DA3-vs-ARKit frames into accepted (< 2°),
      learnable (2–30°) and outlier (≥ 30°) rotation error.

    Args:
        results: Validation results dict as loaded from the results JSON.
        output_dir: Existing directory to write the report into.
    """
    accept_deg = 2.0
    reject_deg = 30.0
    report_path = output_dir / "summary_report.txt"

    # Explicit UTF-8: the report contains "°", which some locale encodings
    # cannot represent.
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("=" * 60 + "\n")
        f.write("BA Validation Summary Report\n")
        f.write("=" * 60 + "\n\n")
        f.write(f"Total Frames: {results.get('num_frames', 'N/A')}\n\n")

        for key, title in (
            ("da3_vs_arkit", "DA3 vs ARKit (Ground Truth)"),
            ("ba_vs_arkit", "BA vs ARKit (Ground Truth)"),
            ("da3_vs_ba", "DA3 vs BA"),
        ):
            if key in results:
                _write_pose_comparison(f, title, results[key])

        if "ba_result" in results:
            ba_result = results["ba_result"]
            f.write("BA Validation Result:\n")
            f.write("-" * 40 + "\n")
            f.write(f" Status: {ba_result.get('status', 'N/A')}\n")
            f.write(f" Error: {ba_result.get('error', 'N/A')}\n")
            f.write(f" Reprojection Error: {ba_result.get('reprojection_error', 'N/A')}\n\n")

        if "da3_vs_arkit" in results:
            errors = results["da3_vs_arkit"]["rotation_errors_deg"]
            total = len(errors)
            accepted = sum(1 for e in errors if e < accept_deg)
            learnable = sum(1 for e in errors if accept_deg <= e < reject_deg)
            outlier = sum(1 for e in errors if e >= reject_deg)

            def pct(count):
                # Guard against an empty error list (would divide by zero).
                return 100 * count / total if total else 0.0

            f.write("Frame Categorization (DA3 vs ARKit):\n")
            f.write("-" * 40 + "\n")
            f.write(
                f" Accepted (< {accept_deg:g}°): {accepted}/{total} ({pct(accepted):.1f}%)\n"
            )
            f.write(
                f" Learnable ({accept_deg:g}-{reject_deg:g}°): {learnable}/{total} "
                f"({pct(learnable):.1f}%)\n"
            )
            # The outlier condition is inclusive (e >= reject_deg), so say "≥".
            f.write(
                f" Outlier (≥ {reject_deg:g}°): {outlier}/{total} ({pct(outlier):.1f}%)\n"
            )

    print(f"Saved summary report to {report_path}")
|
|
|
|
|
|
|
|
def _load_optional_poses(path: Path, label: str) -> Optional[np.ndarray]:
    """Load a pose array from ``path`` if the file exists, else return None."""
    if not path.exists():
        return None
    poses = np.load(path)
    print(f"Loaded {label} poses: {poses.shape}")
    return poses


def main():
    """Command-line entry point: render all visualizations for one results dir.

    Inputs are read from ``--results-dir``:
    - ``validation_results.json`` (required).
    - The optional pose files ``arkit_poses_c2w.npy``, ``da3_poses_w2c.npy``
      and ``ba_poses_w2c.npy``.

    It writes error plots and a summary report into ``--output-dir`` (default:
    ``<results-dir>/visualizations``). It also writes a 3D trajectory plot
    when both ARKit and DA3 poses are available.

    Exits with a non-zero status if the results JSON is missing.
    """
    import argparse
    import traceback

    parser = argparse.ArgumentParser(description="Visualize BA validation results")
    parser.add_argument(
        "--results-dir",
        type=Path,
        default=project_root / "data" / "arkit_ba_validation",
        help="Directory containing validation results",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Output directory for visualizations (default: results_dir/visualizations)",
    )
    parser.add_argument(
        "--use-plotly",
        action="store_true",
        help="Use plotly for interactive 3D plots (if available)",
    )
    args = parser.parse_args()

    results_path = args.results_dir / "validation_results.json"
    if not results_path.exists():
        # Exit non-zero (message goes to stderr) so scripts/CI notice.
        sys.exit(f"Results file not found: {results_path}")

    output_dir = args.output_dir or (args.results_dir / "visualizations")
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading results from {results_path}")
    results = load_results(results_path)

    arkit_poses = _load_optional_poses(args.results_dir / "arkit_poses_c2w.npy", "ARKit")
    da3_poses = _load_optional_poses(args.results_dir / "da3_poses_w2c.npy", "DA3")
    ba_poses = _load_optional_poses(args.results_dir / "ba_poses_w2c.npy", "BA")

    use_plotly = args.use_plotly and HAS_PLOTLY
    if args.use_plotly and not HAS_PLOTLY:
        print("--use-plotly requested but plotly is unavailable; using matplotlib instead.")

    print("\nCreating visualizations...")

    plot_error_metrics(results, output_dir)
    plot_error_comparison(results, output_dir)
    create_summary_report(results, output_dir)

    if arkit_poses is not None and da3_poses is not None:
        suffix = ".html" if use_plotly else ".png"
        try:
            plot_trajectories_3d(
                arkit_poses,
                da3_poses,
                ba_poses=ba_poses,
                output_path=output_dir / f"trajectories_3d{suffix}",
                use_plotly=use_plotly,
            )
        except Exception as e:
            # Best effort: a trajectory failure should not abort the run after
            # the other outputs have already been written; report it fully.
            print(f"Error creating trajectory plot: {e}")
            traceback.print_exc()
    else:
        print("Skipping trajectory visualization (poses not available)")

    print(f"\n✓ Visualizations saved to {output_dir}")
|
|
|