|
|
""" |
|
|
Real-time GUI visualization for BA validation. |
|
|
Updates progressively as data comes in. |
|
|
""" |
|
|
|
|
|
import queue |
|
|
import tkinter as tk |
|
|
from tkinter import ttk |
|
|
from typing import Dict, List, Optional |
|
|
import numpy as np |
|
|
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk |
|
|
from matplotlib.figure import Figure |
|
|
|
|
|
|
|
|
class BAValidationGUI:
    """Real-time Tkinter GUI for BA validation visualization.

    Data producers (possibly running on background threads) call the
    ``add_*`` methods, which only put messages on ``update_queue``. The Tk
    main loop polls that queue every ``_POLL_MS`` milliseconds in
    :meth:`_process_updates` and applies the messages to the widgets, so
    producers never have to touch Tk widgets directly.
    """

    # Interval (ms) between polls of ``update_queue``.
    _POLL_MS = 100

    # Error series accepted from ``add_frame_data(errors=...)``; other keys are ignored.
    _ERROR_KEYS = (
        "da3_vs_arkit_rot",
        "da3_vs_arkit_trans",
        "ba_vs_arkit_rot",
        "ba_vs_arkit_trans",
        "da3_vs_ba_rot",
        "da3_vs_ba_trans",
    )

    # (heading, error-key prefix) for each group shown in the statistics panel.
    _STAT_GROUPS = (
        ("DA3 vs ARKit:", "da3_vs_arkit"),
        ("BA vs ARKit:", "ba_vs_arkit"),
        ("DA3 vs BA:", "da3_vs_ba"),
    )

    # Rotation-error reference lines (degrees) drawn on the "Error Metrics" tab.
    _ACCEPT_ROT_DEG = 2.0
    _REJECT_ROT_DEG = 30.0

    # (error key, format string, linewidth, markersize, legend label) per plotted series.
    _ERROR_TAB_SERIES = (
        ("da3_vs_arkit_rot", "r-o", 2, 4, "DA3 vs ARKit"),
        ("ba_vs_arkit_rot", "b-o", 2, 4, "BA vs ARKit"),
    )
    _COMPARISON_SERIES = (
        ("da3_vs_arkit_rot", "r-o", 2, 4, "DA3 vs ARKit (Rot)"),
        ("da3_vs_arkit_trans", "r--s", 1, 3, "DA3 vs ARKit (Trans)"),
        ("ba_vs_arkit_rot", "b-o", 2, 4, "BA vs ARKit (Rot)"),
        ("ba_vs_arkit_trans", "b--s", 1, 3, "BA vs ARKit (Trans)"),
    )

    def __init__(self, root: tk.Tk):
        """Build the window layout inside *root* and start polling for updates.

        Args:
            root: Tk root window that hosts the GUI.
        """
        self.root = root
        self.root.title("BA Validation - Real-time Visualization")
        self.root.geometry("1400x900")

        self._reset_data()

        # Hand-off point between producer threads and the Tk main loop.
        self.update_queue = queue.Queue()

        self._setup_gui()

        self.root.after(self._POLL_MS, self._process_updates)

    # ------------------------------------------------------------------ data

    def _reset_data(self) -> None:
        """(Re)initialize every container that accumulates per-frame data."""
        self.arkit_poses: List[np.ndarray] = []
        self.da3_poses: List[np.ndarray] = []
        self.ba_poses: List[np.ndarray] = []
        self.error_data: Dict[str, List[float]] = {key: [] for key in self._ERROR_KEYS}
        # Frame index recorded alongside each error value. Frames may report
        # only some metrics, so every series needs its own x-values to stay
        # the same length as its y-values when plotted.
        self.error_frames: Dict[str, List[int]] = {key: [] for key in self._ERROR_KEYS}
        self.frame_indices: List[int] = []
        self.status_text: List[str] = []

    # ---------------------------------------------------------------- layout

    def _setup_gui(self):
        """Create the left status/statistics column and the right tabbed plots."""
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)

        # Left column: status + statistics. Only column 1 gets a weight below,
        # so this column keeps its natural width while the plots absorb resizing.
        left_panel = ttk.Frame(main_frame)
        left_panel.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S), padx=(0, 10))
        self._build_status_panel(left_panel)
        self._build_statistics_panel(left_panel)

        # Right column: a notebook with one Matplotlib figure per tab.
        right_panel = ttk.Frame(main_frame)
        right_panel.grid(row=0, column=1, sticky=(tk.W, tk.E, tk.N, tk.S))
        main_frame.columnconfigure(1, weight=1)
        main_frame.rowconfigure(0, weight=1)

        self.notebook = ttk.Notebook(right_panel)
        self.notebook.pack(fill=tk.BOTH, expand=True)

        self.fig_3d, self.ax_3d, self.canvas_3d = self._add_plot_tab(
            "3D Trajectories", projection="3d"
        )
        self.fig_errors, self.ax_errors, self.canvas_errors = self._add_plot_tab("Error Metrics")
        self.fig_comp, self.ax_comp, self.canvas_comp = self._add_plot_tab("Comparison")

        self._init_plots()

    def _build_status_panel(self, parent) -> None:
        """Create the status label, the "current/total frames" text and the busy bar."""
        status_frame = ttk.LabelFrame(parent, text="Status", padding="10")
        status_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 10))

        self.status_label = ttk.Label(status_frame, text="Waiting for data...", font=("Arial", 12))
        self.status_label.pack(anchor=tk.W)

        self.progress_var = tk.StringVar(value="0/0 frames")
        self.progress_label = ttk.Label(
            status_frame, textvariable=self.progress_var, font=("Arial", 10)
        )
        self.progress_label.pack(anchor=tk.W, pady=(5, 0))

        # Indeterminate mode: the bar only animates while busy; the fraction
        # complete is shown as text by ``progress_label`` instead.
        self.progress_bar = ttk.Progressbar(status_frame, mode="indeterminate")
        self.progress_bar.pack(fill=tk.X, pady=(5, 0))

    def _build_statistics_panel(self, parent) -> None:
        """Create the read-only, vertically scrollable statistics text box."""
        stats_frame = ttk.LabelFrame(parent, text="Statistics", padding="10")
        stats_frame.pack(fill=tk.BOTH, expand=True)

        # Pack the scrollbar first: pack allocates space in call order, so an
        # expanding Text packed earlier would leave the scrollbar no room.
        scrollbar = ttk.Scrollbar(stats_frame, orient=tk.VERTICAL)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

        self.stats_text = tk.Text(
            stats_frame,
            height=15,
            width=30,
            font=("Courier", 9),
            yscrollcommand=scrollbar.set,
            state=tk.DISABLED,  # display-only; _update_statistics unlocks it to write
        )
        self.stats_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        scrollbar.configure(command=self.stats_text.yview)

    def _add_plot_tab(self, title: str, projection: Optional[str] = None):
        """Add a notebook tab holding a Matplotlib figure and its navigation toolbar.

        Args:
            title: Tab caption.
            projection: Matplotlib projection name (e.g. ``"3d"``); ``None`` for 2D.

        Returns:
            ``(figure, axes, canvas)`` for the new tab.
        """
        frame = ttk.Frame(self.notebook)
        self.notebook.add(frame, text=title)

        fig = Figure(figsize=(10, 8), dpi=100)
        ax = fig.add_subplot(111, projection=projection)
        canvas = FigureCanvasTkAgg(fig, frame)
        canvas.draw()

        # The toolbar packs itself at the bottom when constructed; creating it
        # before packing the expanding canvas keeps it visible in small windows.
        toolbar = NavigationToolbar2Tk(canvas, frame)
        toolbar.update()
        canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
        return fig, ax, canvas

    # ------------------------------------------------------------ decoration

    @staticmethod
    def _show_legend(ax) -> None:
        """Draw a legend on *ax* only if it has labelled artists (avoids a warning)."""
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            ax.legend(handles, labels)

    def _decorate_trajectory_axes(self) -> None:
        """Set the axis labels and title of the 3D trajectory plot."""
        self.ax_3d.set_xlabel("X (m)")
        self.ax_3d.set_ylabel("Y (m)")
        self.ax_3d.set_zlabel("Z (m)")
        self.ax_3d.set_title("Camera Trajectories (3D)")

    def _decorate_error_axes(self) -> None:
        """Set labels/title of the rotation-error plot and draw its reference lines."""
        ax = self.ax_errors
        ax.set_xlabel("Frame Index")
        ax.set_ylabel("Rotation Error (degrees)")
        ax.set_title("Rotation Errors")
        ax.axhline(
            y=self._ACCEPT_ROT_DEG,
            color="g",
            linestyle="--",
            alpha=0.5,
            label=f"Accept ({self._ACCEPT_ROT_DEG:g}°)",
        )
        ax.axhline(
            y=self._REJECT_ROT_DEG,
            color="orange",
            linestyle="--",
            alpha=0.5,
            label=f"Reject ({self._REJECT_ROT_DEG:g}°)",
        )

    def _decorate_comparison_axes(self) -> None:
        """Set the axis labels and title of the combined error-comparison plot."""
        self.ax_comp.set_xlabel("Frame Index")
        # Rotation (degrees) and translation (metres, per the statistics panel)
        # series share this single y-axis.
        self.ax_comp.set_ylabel("Error")
        self.ax_comp.set_title("Error Comparison")

    def _init_plots(self):
        """Decorate the empty plots before any data has arrived."""
        self._decorate_trajectory_axes()
        self._show_legend(self.ax_3d)
        self.ax_3d.grid(True)
        self.canvas_3d.draw_idle()

        self._decorate_error_axes()
        self._show_legend(self.ax_errors)
        self.ax_errors.grid(True, alpha=0.3)
        self.canvas_errors.draw_idle()

        self._decorate_comparison_axes()
        self._show_legend(self.ax_comp)
        self.ax_comp.grid(True, alpha=0.3)
        self.canvas_comp.draw_idle()

    # ------------------------------------------------------ main-thread API

    def update_status(self, message: str, is_processing: bool = False):
        """Show *message* and start/stop the busy animation.

        Touches Tk widgets directly, so call it from the Tk main thread; other
        threads should use :meth:`add_status_message`.
        """
        self.status_label.config(text=message)
        if is_processing:
            self.progress_bar.start(10)  # animation step interval in ms
        else:
            self.progress_bar.stop()

    def update_progress(self, current: int, total: int):
        """Show "current/total frames" (Tk main thread only; see add_progress_update)."""
        self.progress_var.set(f"{current}/{total} frames")

    # -------------------------------------------------- thread-safe producer API

    def add_frame_data(
        self,
        frame_idx: int,
        arkit_pose: Optional[np.ndarray] = None,
        da3_pose: Optional[np.ndarray] = None,
        ba_pose: Optional[np.ndarray] = None,
        errors: Optional[Dict] = None,
    ):
        """Queue data for a new frame (thread-safe).

        Args:
            frame_idx: Frame number used as the x-value of the error plots.
            arkit_pose, da3_pose, ba_pose: Optional 4x4 or 3x4 pose matrices
                (see :meth:`_get_camera_centers` for how each shape is read).
            errors: Optional mapping from a key in ``_ERROR_KEYS`` to a scalar
                error for this frame; unknown keys and ``None`` values are ignored.
        """
        update_data = {
            "type": "frame_data",
            "frame_idx": frame_idx,
            "arkit_pose": arkit_pose,
            "da3_pose": da3_pose,
            "ba_pose": ba_pose,
            "errors": errors,
        }
        self.update_queue.put(update_data)

    def add_status_message(self, message: str, is_processing: bool = True):
        """Queue a status message (thread-safe).

        Args:
            message: Text shown in the status label.
            is_processing: Whether the busy animation should run; pass ``False``
                to stop it, e.g. when the producer has finished.
        """
        self.update_queue.put(
            {"type": "status", "message": message, "is_processing": is_processing}
        )

    def add_progress_update(self, current: int, total: int):
        """Queue a "current/total frames" progress update (thread-safe)."""
        self.update_queue.put({"type": "progress", "current": current, "total": total})

    # ------------------------------------------------------ queue processing

    def _process_updates(self):
        """Drain the update queue on the Tk main thread, then reschedule itself.

        All pending messages are applied first and the plots are redrawn at
        most once per poll, so a burst of frames costs one redraw, not one per
        frame. Rescheduling happens in ``finally``: if a message fails, the
        exception still reaches Tk's error reporting, but polling continues.
        """
        try:
            if self._drain_queue():
                self._update_plots()
                self._update_statistics()
        finally:
            self.root.after(self._POLL_MS, self._process_updates)

    def _drain_queue(self) -> bool:
        """Apply every pending queue message; return True if any frame data arrived."""
        got_frames = False
        while True:
            try:
                update = self.update_queue.get_nowait()
            except queue.Empty:
                return got_frames

            kind = update.get("type")
            if kind == "frame_data":
                self._record_frame(update)
                got_frames = True
            elif kind == "status":
                self.update_status(
                    update["message"], is_processing=update.get("is_processing", True)
                )
            elif kind == "progress":
                self.update_progress(update["current"], update["total"])
            # Messages of any other type are ignored.

    def _record_frame(self, update: Dict) -> None:
        """Append one ``frame_data`` message to the accumulated data (no redraw)."""
        frame_idx = update["frame_idx"]

        for key, store in (
            ("arkit_pose", self.arkit_poses),
            ("da3_pose", self.da3_poses),
            ("ba_pose", self.ba_poses),
        ):
            pose = update.get(key)
            if pose is not None:
                store.append(pose)

        for key, value in (update.get("errors") or {}).items():
            if key in self.error_data and value is not None:
                self.error_data[key].append(value)
                self.error_frames[key].append(frame_idx)

        self.frame_indices.append(frame_idx)

    def _process_frame_data(self, update: Dict):
        """Record one ``frame_data`` message and immediately redraw everything."""
        self._record_frame(update)
        self._update_plots()
        self._update_statistics()

    # -------------------------------------------------------------- plotting

    def _get_camera_centers(self, poses: List[np.ndarray]) -> np.ndarray:
        """Return an ``(N, 3)`` array of camera centres extracted from *poses*.

        Two matrix layouts are accepted:

        * 4x4: the translation column ``pose[:3, 3]`` is used as the centre,
          i.e. the matrix is treated as camera-to-world.
        * 3x4 ``[R | t]``: the centre is ``-R.T @ t``, i.e. the matrix is
          treated as world-to-camera.

        ``None`` entries and matrices of any other shape are skipped, so N may
        be smaller than ``len(poses)``.

        NOTE(review): the per-shape conventions presumably match the pose
        producers (ARKit vs DA3/BA); confirm against the callers.
        """
        centers = []
        for pose in poses:
            if pose is None:
                continue
            mat = np.asarray(pose)
            if mat.shape == (4, 4):
                centers.append(mat[:3, 3])
            elif mat.shape == (3, 4):
                rot, trans = mat[:, :3], mat[:, 3]
                centers.append(-rot.T @ trans)
        return np.array(centers) if centers else np.empty((0, 3))

    def _update_plots(self):
        """Redraw all three plot tabs from the accumulated data."""
        self._plot_trajectories()
        self._plot_rotation_errors()
        self._plot_error_comparison()

    def _plot_trajectories(self) -> None:
        """Redraw the 3D camera-centre trajectories for each pose source."""
        ax = self.ax_3d
        ax.clear()
        self._decorate_trajectory_axes()

        for poses, fmt, width, marker, size, label in (
            (self.arkit_poses, "g-", 2, "o", 4, "ARKit (GT)"),
            (self.da3_poses, "r-", 1, "s", 3, "DA3"),
            (self.ba_poses, "b-", 1, "^", 3, "BA"),
        ):
            centers = self._get_camera_centers(poses)
            if len(centers) > 0:
                ax.plot(
                    centers[:, 0],
                    centers[:, 1],
                    centers[:, 2],
                    fmt,
                    linewidth=width,
                    marker=marker,
                    markersize=size,
                    label=label,
                )

        self._show_legend(ax)
        ax.grid(True)
        self.canvas_3d.draw_idle()

    def _plot_error_series(self, ax, series) -> None:
        """Plot each non-empty ``(key, fmt, linewidth, markersize, label)`` series on *ax*."""
        for key, fmt, width, size, label in series:
            values = self.error_data[key]
            if values:
                # Per-series x-values keep x and y the same length even when
                # some frames did not report this metric.
                ax.plot(
                    self.error_frames[key],
                    values,
                    fmt,
                    linewidth=width,
                    markersize=size,
                    label=label,
                )

    def _plot_rotation_errors(self) -> None:
        """Redraw the "Error Metrics" tab (rotation errors vs. reference lines)."""
        ax = self.ax_errors
        ax.clear()
        self._decorate_error_axes()
        self._plot_error_series(ax, self._ERROR_TAB_SERIES)
        self._show_legend(ax)
        ax.grid(True, alpha=0.3)
        self.canvas_errors.draw_idle()

    def _plot_error_comparison(self) -> None:
        """Redraw the "Comparison" tab (rotation and translation errors together)."""
        ax = self.ax_comp
        ax.clear()
        self._decorate_comparison_axes()
        self._plot_error_series(ax, self._COMPARISON_SERIES)
        self._show_legend(ax)
        ax.grid(True, alpha=0.3)
        self.canvas_comp.draw_idle()

    # ------------------------------------------------------------ statistics

    def _format_statistics(self) -> str:
        """Build the statistics-panel text from the accumulated error series."""
        if not self.frame_indices:
            return "No data yet..."

        sections = [f"Frames Processed: {len(self.frame_indices)}"]
        for heading, prefix in self._STAT_GROUPS:
            rot = self.error_data[f"{prefix}_rot"]
            trans = self.error_data[f"{prefix}_trans"]
            if not rot and not trans:
                continue
            lines = [heading]
            if rot:
                lines.append(f" Mean Rot Error: {np.mean(rot):.2f}°")
                lines.append(f" Max Rot Error: {np.max(rot):.2f}°")
            if trans:
                lines.append(f" Mean Trans Error: {np.mean(trans):.4f} m")
            sections.append("\n".join(lines))
        # Blank line between the header and each group.
        return "\n\n".join(sections)

    def _update_statistics(self):
        """Rewrite the read-only statistics panel."""
        # Tk ignores inserts/deletes on a DISABLED Text, so unlock it while writing.
        self.stats_text.configure(state=tk.NORMAL)
        self.stats_text.delete("1.0", tk.END)
        self.stats_text.insert(tk.END, self._format_statistics())
        self.stats_text.see(tk.END)
        self.stats_text.configure(state=tk.DISABLED)

    def run(self):
        """Enter the Tk main loop (returns when the window is closed)."""
        self.root.mainloop()
|
|
|
|
|
|
|
|
def create_gui() -> BAValidationGUI:
    """Open a new Tk root window and return a BAValidationGUI bound to it.

    The event loop is not started here; call ``run()`` on the returned
    object to enter it.
    """
    return BAValidationGUI(tk.Tk())
|
|
|