"""Plotly 3D visualization for CrossKEY matching results. Builds side-by-side volume isosurfaces with keypoints and match lines. MR volume on the left, US volume on the right, offset along the X axis. """ import numpy as np import plotly.graph_objects as go from scipy.ndimage import zoom from skimage.measure import marching_cubes def downsample_volume(volume: np.ndarray, target_size: int = 64) -> np.ndarray: """Downsample volume to target_size^3 for browser-friendly rendering.""" factors = [target_size / s for s in volume.shape] return zoom(volume, factors, order=1).astype(np.float32) def scale_points( points: np.ndarray, padded_shape: tuple, volume_shape: tuple, ) -> np.ndarray: """Scale point coordinates from padded volume space to downsampled volume space.""" scale = np.array(volume_shape, dtype=float) / np.array(padded_shape, dtype=float) return points * scale def create_isosurface_trace( volume: np.ndarray, level: float, colorscale: str = "Gray", opacity: float = 0.15, name: str = "", offset_x: float = 0.0, ) -> go.Mesh3d: """Create a Mesh3d trace from a volume isosurface via marching cubes. Uses vertex intensity from the original volume for natural coloring. """ verts, faces, _, _ = marching_cubes(volume, level=level) # Sample volume intensity at each vertex for natural coloring vi = np.clip(verts.astype(int), 0, np.array(volume.shape) - 1) intensities = volume[vi[:, 0], vi[:, 1], vi[:, 2]] # Axis remap: data (0,1,2) -> Plotly (z,x,y); negate axis 0 so cone points up return go.Mesh3d( x=verts[:, 1] + offset_x, y=verts[:, 2], z=-verts[:, 0], i=faces[:, 0], j=faces[:, 1], k=faces[:, 2], intensity=intensities, colorscale=colorscale, opacity=opacity, name=name, showlegend=True, showscale=False, ) def create_keypoint_trace( points: np.ndarray, color: str, size: float = 3.0, opacity: float = 1.0, name: str = "", offset_x: float = 0.0, ) -> go.Scatter3d: """Create Scatter3d markers for keypoints.""" # Axis remap: data (0,1,2) -> Plotly (z,x,y); negate axis 0 return go.Scatter3d( x=points[:, 1] + offset_x, y=points[:, 2], z=-points[:, 0], mode="markers", marker=dict(size=size, color=color, opacity=opacity), name=name, showlegend=True, ) def create_match_lines( src_pts: np.ndarray, tgt_pts: np.ndarray, color: str, width: float = 2.0, name: str = "", offset_x: float = 0.0, ) -> go.Scatter3d: """Create lines connecting matched source points to offset target points.""" # Axis remap: data (0,1,2) -> Plotly (z,x,y); negate axis 0 lx, ly, lz = [], [], [] for s, t in zip(src_pts, tgt_pts): lx.extend([float(s[1]), float(t[1]) + offset_x, None]) ly.extend([float(s[2]), float(t[2]), None]) lz.extend([-float(s[0]), -float(t[0]), None]) return go.Scatter3d( x=lx, y=ly, z=lz, mode="lines", line=dict(color=color, width=width), name=name, showlegend=True, ) def build_matching_figure( volume_mr: np.ndarray, volume_us: np.ndarray, points_mr: np.ndarray, points_us: np.ndarray, padded_shape_mr: tuple, padded_shape_us: tuple, match_pairs: list, metrics: dict, evaluation_threshold: float = 5.0, mr_level: float = 0.3, us_level: float = 0.1, ) -> go.Figure: """Build the full 3D matching visualization.""" fig = go.Figure() # Scale keypoints to match downsampled volume coordinates pts_mr_viz = scale_points(points_mr, padded_shape_mr, volume_mr.shape) pts_us_viz = scale_points(points_us, padded_shape_us, volume_us.shape) # Side-by-side offset along Plotly x (= data axis 1) gap = volume_mr.shape[1] * 0.3 offset_x = volume_mr.shape[1] + gap # Volume isosurfaces with natural intensity coloring try: fig.add_trace(create_isosurface_trace( volume_mr, level=mr_level, colorscale="Gray", opacity=0.15, name="MR Surface", )) except ValueError: pass try: fig.add_trace(create_isosurface_trace( volume_us, level=us_level, colorscale="Hot", opacity=0.15, name="US Surface", offset_x=offset_x, )) except ValueError: pass # Process matches src_indices = [p[0] for p in match_pairs] tgt_indices = [p[1] for p in match_pairs] if match_pairs: mr_matched = points_mr[src_indices] us_matched = points_us[tgt_indices] spatial_dist = np.linalg.norm(mr_matched - us_matched, axis=1) correct = spatial_dist < evaluation_threshold mr_matched_viz = pts_mr_viz[src_indices] us_matched_viz = pts_us_viz[tgt_indices] if correct.any(): fig.add_trace(create_match_lines( mr_matched_viz[correct], us_matched_viz[correct], color="rgba(0,200,0,0.6)", width=2, name=f"Correct ({correct.sum()})", offset_x=offset_x, )) if (~correct).any(): fig.add_trace(create_match_lines( mr_matched_viz[~correct], us_matched_viz[~correct], color="rgba(255,0,0,0.3)", width=1, name=f"Incorrect ({(~correct).sum()})", offset_x=offset_x, )) fig.add_trace(create_keypoint_trace( mr_matched_viz, color="royalblue", size=4, name=f"MR Matched ({len(mr_matched_viz)})", )) fig.add_trace(create_keypoint_trace( us_matched_viz, color="crimson", size=4, name=f"US Matched ({len(us_matched_viz)})", offset_x=offset_x, )) # Unmatched keypoints (faded) matched_mr_set = set(src_indices) matched_us_set = set(tgt_indices) unmatched_mr = np.array([i not in matched_mr_set for i in range(len(pts_mr_viz))]) unmatched_us = np.array([i not in matched_us_set for i in range(len(pts_us_viz))]) if unmatched_mr.any(): fig.add_trace(create_keypoint_trace( pts_mr_viz[unmatched_mr], color="royalblue", size=1.5, opacity=0.2, name="MR Unmatched", )) if unmatched_us.any(): fig.add_trace(create_keypoint_trace( pts_us_viz[unmatched_us], color="crimson", size=1.5, opacity=0.2, name="US Unmatched", offset_x=offset_x, )) # Layout -- no fixed width so Plotly fills the Gradio container fig.update_layout( scene=dict( xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False), aspectmode="data", camera=dict( up=dict(x=0, y=0, z=1), eye=dict(x=0, y=-1.8, z=0.3), ), ), height=700, margin=dict(l=0, r=0, t=40, b=0), legend=dict( yanchor="top", y=0.99, xanchor="left", x=0.01, bgcolor="rgba(0,0,0,0.5)", font=dict(color="white"), ), ) return fig