CrossKEY / visualization.py
morozovdd's picture
Initial deploy: CrossKEY interactive 3D matching demo
ffbfad7
"""Plotly 3D visualization for CrossKEY matching results.
Builds side-by-side volume isosurfaces with keypoints and match lines.
MR volume on the left, US volume on the right, offset along the X axis.
"""
import numpy as np
import plotly.graph_objects as go
from scipy.ndimage import zoom
from skimage.measure import marching_cubes
def downsample_volume(volume: np.ndarray, target_size: int = 64) -> np.ndarray:
    """Resample ``volume`` to ``target_size`` samples per axis for lightweight rendering.

    Each axis gets its own zoom factor (``target_size / axis_length``), so a
    non-cubic input is stretched or squeezed into a cube.  Linear
    interpolation (``order=1``) is used and the result is cast to ``float32``.
    """
    per_axis_zoom = [target_size / axis_len for axis_len in volume.shape]
    resampled = zoom(volume, per_axis_zoom, order=1)
    return resampled.astype(np.float32)
def scale_points(
    points: np.ndarray,
    padded_shape: tuple,
    volume_shape: tuple,
) -> np.ndarray:
    """Map point coordinates from the padded volume grid onto the downsampled grid.

    Every coordinate axis is multiplied by
    ``volume_shape[axis] / padded_shape[axis]``; the input array is not
    modified.
    """
    target_extent = np.array(volume_shape, dtype=float)
    source_extent = np.array(padded_shape, dtype=float)
    return points * (target_extent / source_extent)
def create_isosurface_trace(
    volume: np.ndarray,
    level: float,
    colorscale: str = "Gray",
    opacity: float = 0.15,
    name: str = "",
    offset_x: float = 0.0,
) -> go.Mesh3d:
    """Extract the ``level`` isosurface of ``volume`` and wrap it in a Mesh3d trace.

    The surface is triangulated with marching cubes.  Each vertex is colored
    by the volume value at the voxel obtained by truncating the vertex
    coordinates to integers (clamped to the volume bounds), which gives the
    mesh an intensity-based shading through ``colorscale``.

    ``offset_x`` is added to the Plotly x coordinate so several surfaces can
    be placed side by side.
    """
    vertices, triangles, _normals, _values = marching_cubes(volume, level=level)

    # Look up one voxel per vertex: truncate to integer indices and clamp so
    # vertices lying on the far faces stay inside the array.
    last_valid_index = np.array(volume.shape) - 1
    voxel_index = np.clip(vertices.astype(int), 0, last_valid_index)
    vertex_intensity = volume[voxel_index[:, 0], voxel_index[:, 1], voxel_index[:, 2]]

    # Axis remap: data (0,1,2) -> Plotly (z,x,y); negate axis 0 so cone points up
    plot_x = vertices[:, 1] + offset_x
    plot_y = vertices[:, 2]
    plot_z = -vertices[:, 0]

    return go.Mesh3d(
        x=plot_x,
        y=plot_y,
        z=plot_z,
        i=triangles[:, 0],
        j=triangles[:, 1],
        k=triangles[:, 2],
        intensity=vertex_intensity,
        colorscale=colorscale,
        opacity=opacity,
        name=name,
        showlegend=True,
        showscale=False,
    )
def create_keypoint_trace(
    points: np.ndarray,
    color: str,
    size: float = 3.0,
    opacity: float = 1.0,
    name: str = "",
    offset_x: float = 0.0,
) -> go.Scatter3d:
    """Build a marker-only Scatter3d trace for a set of keypoints.

    ``points`` is indexed as a 2-D array whose columns 0, 1, 2 are the data
    axes.  ``offset_x`` shifts the markers along the Plotly x axis.
    """
    marker_style = dict(size=size, color=color, opacity=opacity)

    # Axis remap: data (0,1,2) -> Plotly (z,x,y); negate axis 0
    plot_x = points[:, 1] + offset_x
    plot_y = points[:, 2]
    plot_z = -points[:, 0]

    return go.Scatter3d(
        x=plot_x,
        y=plot_y,
        z=plot_z,
        mode="markers",
        marker=marker_style,
        name=name,
        showlegend=True,
    )
def create_match_lines(
    src_pts: np.ndarray,
    tgt_pts: np.ndarray,
    color: str,
    width: float = 2.0,
    name: str = "",
    offset_x: float = 0.0,
) -> go.Scatter3d:
    """Build one Scatter3d trace holding a line segment per source/target pair.

    Source and target points are paired in order.  Only the target end of
    each segment is shifted by ``offset_x`` along the Plotly x axis.  A
    ``None`` entry follows every segment so Plotly draws the segments as
    separate lines instead of one connected polyline.
    """
    xs: list = []
    ys: list = []
    zs: list = []

    # Axis remap: data (0,1,2) -> Plotly (z,x,y); negate axis 0
    for src, tgt in zip(src_pts, tgt_pts):
        xs += (float(src[1]), float(tgt[1]) + offset_x, None)
        ys += (float(src[2]), float(tgt[2]), None)
        zs += (-float(src[0]), -float(tgt[0]), None)

    line_style = dict(color=color, width=width)
    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode="lines",
        line=line_style,
        name=name,
        showlegend=True,
    )
def _as_point_array(points) -> np.ndarray:
    """Return ``points`` as an array, giving empty input an explicit (0, 3) shape."""
    arr = np.asarray(points)
    return arr.reshape(0, 3) if arr.size == 0 else arr


def _add_isosurface(fig: go.Figure, volume: np.ndarray, **trace_kwargs) -> None:
    """Add an isosurface trace to ``fig``; warn and skip when no surface can be built."""
    import warnings

    try:
        trace = create_isosurface_trace(volume, **trace_kwargs)
    except (ValueError, RuntimeError) as exc:
        # The surface is optional decoration: a level outside the data range
        # or a level with no surface should not prevent the keypoints and
        # matches from being shown.  Surface the reason instead of hiding it.
        warnings.warn(
            f"Skipping {trace_kwargs.get('name', 'isosurface')!r}: {exc}",
            RuntimeWarning,
            stacklevel=3,
        )
        return
    fig.add_trace(trace)


def build_matching_figure(
    volume_mr: np.ndarray,
    volume_us: np.ndarray,
    points_mr: np.ndarray,
    points_us: np.ndarray,
    padded_shape_mr: tuple,
    padded_shape_us: tuple,
    match_pairs: list,
    metrics: dict,
    evaluation_threshold: float = 5.0,
    mr_level: float = 0.3,
    us_level: float = 0.1,
) -> go.Figure:
    """Build the full 3D matching visualization.

    The MR volume is drawn at the origin and the US volume is shifted along
    the Plotly x axis by 1.3x the MR extent along data axis 1.  Keypoints are
    rescaled from their padded-volume coordinates to the displayed volumes.

    Args:
        volume_mr: MR volume to render (already at display resolution).
        volume_us: US volume to render (already at display resolution).
        points_mr: MR keypoints, one row per point, in padded-volume coordinates.
        points_us: US keypoints, one row per point, in padded-volume coordinates.
        padded_shape_mr: Shape of the volume the MR keypoints refer to.
        padded_shape_us: Shape of the volume the US keypoints refer to.
        match_pairs: Sequence of pairs whose first two entries are
            ``(mr_index, us_index)``.  May be empty.
        metrics: Accepted for interface compatibility; not used here.
        evaluation_threshold: A match is drawn as correct when the Euclidean
            distance between its two keypoints, measured in the unscaled
            input coordinates, is below this value.
        mr_level: Isosurface level for the MR volume.
        us_level: Isosurface level for the US volume.

    Returns:
        A Plotly figure containing the surfaces (when they can be extracted),
        match lines, matched keypoints and faded unmatched keypoints.
    """
    fig = go.Figure()

    # Normalize empty keypoint input so scaling and column indexing still work.
    points_mr = _as_point_array(points_mr)
    points_us = _as_point_array(points_us)

    # Scale keypoints to match downsampled volume coordinates
    pts_mr_viz = scale_points(points_mr, padded_shape_mr, volume_mr.shape)
    pts_us_viz = scale_points(points_us, padded_shape_us, volume_us.shape)

    # Side-by-side offset along Plotly x (= data axis 1)
    gap = volume_mr.shape[1] * 0.3
    offset_x = volume_mr.shape[1] + gap

    # Volume isosurfaces with natural intensity coloring
    _add_isosurface(
        fig, volume_mr,
        level=mr_level, colorscale="Gray", opacity=0.15, name="MR Surface",
    )
    _add_isosurface(
        fig, volume_us,
        level=us_level, colorscale="Hot", opacity=0.15, name="US Surface",
        offset_x=offset_x,
    )

    # Index arrays work for lists of tuples as well as for ndarray input, and
    # avoid the ambiguous truth value of ``if match_pairs`` on an array.
    src_idx = np.array([pair[0] for pair in match_pairs], dtype=np.intp)
    tgt_idx = np.array([pair[1] for pair in match_pairs], dtype=np.intp)

    if src_idx.size:
        # Correctness is judged in the original coordinates, not the scaled ones.
        # NOTE(review): this assumes MR and US keypoints share one coordinate
        # frame -- confirm against the caller.
        spatial_dist = np.linalg.norm(points_mr[src_idx] - points_us[tgt_idx], axis=1)
        correct = spatial_dist < evaluation_threshold
        incorrect = ~correct

        mr_matched_viz = pts_mr_viz[src_idx]
        us_matched_viz = pts_us_viz[tgt_idx]

        if correct.any():
            fig.add_trace(create_match_lines(
                mr_matched_viz[correct], us_matched_viz[correct],
                color="rgba(0,200,0,0.6)", width=2,
                name=f"Correct ({correct.sum()})", offset_x=offset_x,
            ))
        if incorrect.any():
            fig.add_trace(create_match_lines(
                mr_matched_viz[incorrect], us_matched_viz[incorrect],
                color="rgba(255,0,0,0.3)", width=1,
                name=f"Incorrect ({incorrect.sum()})", offset_x=offset_x,
            ))

        fig.add_trace(create_keypoint_trace(
            mr_matched_viz, color="royalblue", size=4,
            name=f"MR Matched ({len(mr_matched_viz)})",
        ))
        fig.add_trace(create_keypoint_trace(
            us_matched_viz, color="crimson", size=4,
            name=f"US Matched ({len(us_matched_viz)})", offset_x=offset_x,
        ))

    # Unmatched keypoints (faded): start with everything unmatched, then clear
    # the entries referenced by a match.
    unmatched_mr = np.ones(len(pts_mr_viz), dtype=bool)
    unmatched_mr[src_idx] = False
    unmatched_us = np.ones(len(pts_us_viz), dtype=bool)
    unmatched_us[tgt_idx] = False

    if unmatched_mr.any():
        fig.add_trace(create_keypoint_trace(
            pts_mr_viz[unmatched_mr], color="royalblue",
            size=1.5, opacity=0.2, name="MR Unmatched",
        ))
    if unmatched_us.any():
        fig.add_trace(create_keypoint_trace(
            pts_us_viz[unmatched_us], color="crimson",
            size=1.5, opacity=0.2, name="US Unmatched", offset_x=offset_x,
        ))

    # Layout -- no fixed width so Plotly fills the Gradio container
    fig.update_layout(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            aspectmode="data",
            camera=dict(
                up=dict(x=0, y=0, z=1),
                eye=dict(x=0, y=-1.8, z=0.3),
            ),
        ),
        height=700,
        margin=dict(l=0, r=0, t=40, b=0),
        legend=dict(
            yanchor="top", y=0.99,
            xanchor="left", x=0.01,
            bgcolor="rgba(0,0,0,0.5)",
            font=dict(color="white"),
        ),
    )
    return fig