Initial deploy: CrossKEY interactive 3D matching demo
Browse filesCo-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- README.md +11 -6
- app.py +269 -0
- inference.py +144 -0
- precomputed/descriptors_mr.pt +3 -0
- precomputed/descriptors_us.pt +3 -0
- precomputed/metadata.json +16 -0
- precomputed/points_mr.pt +3 -0
- precomputed/points_us.pt +3 -0
- precomputed/volume_mr.npy +3 -0
- precomputed/volume_us.npy +3 -0
- requirements.txt +4 -0
- visualization.py +223 -0
README.md
CHANGED
|
@@ -1,12 +1,17 @@
|
|
| 1 |
---
|
| 2 |
title: CrossKEY
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
-
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: CrossKEY
|
| 3 |
+
emoji: "\U0001f9e0"
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: "5.29.0"
|
| 8 |
app_file: app.py
|
| 9 |
+
license: mit
|
| 10 |
+
short_description: 3D Cross-modal Keypoint Matching (MR-US)
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# CrossKEY Demo
|
| 14 |
+
|
| 15 |
+
Interactive 3D visualization of cross-modal keypoint matching between MRI and intraoperative ultrasound.
|
| 16 |
+
|
| 17 |
+
[Paper (IEEE TMI)](https://doi.org/10.1109/TMI.2026.3680352) | [Code (GitHub)](https://github.com/morozovdd/CrossKEY) | [arXiv](https://arxiv.org/abs/2507.18551)
|
app.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CrossKEY HuggingFace Space -- Interactive 3D Keypoint Matching Demo.
|
| 2 |
+
|
| 3 |
+
Two-tab Gradio app:
|
| 4 |
+
Tab 1 (Explore): Pre-computed results with adjustable matching parameters.
|
| 5 |
+
Tab 2 (Your Data): Upload volumes + checkpoint for live inference.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
# Add space/ to path so local imports work both locally and on HF
|
| 13 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 14 |
+
|
| 15 |
+
import gradio as gr
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
from inference import load_precomputed, run_inference, run_matching
|
| 19 |
+
from visualization import build_matching_figure
|
| 20 |
+
|
| 21 |
+
# ZeroGPU decorator -- no-op when running locally
|
| 22 |
+
try:
|
| 23 |
+
import spaces
|
| 24 |
+
gpu_decorator = spaces.GPU
|
| 25 |
+
except ImportError:
|
| 26 |
+
gpu_decorator = lambda fn: fn
|
| 27 |
+
|
| 28 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 29 |
+
logger = logging.getLogger("crosskey.app")
|
| 30 |
+
|
| 31 |
+
# -- Load pre-computed data on startup --
|
| 32 |
+
logger.info("Loading pre-computed demo data...")
|
| 33 |
+
DEMO_DATA = load_precomputed(
|
| 34 |
+
os.path.join(os.path.dirname(os.path.abspath(__file__)), "precomputed")
|
| 35 |
+
)
|
| 36 |
+
logger.info(
|
| 37 |
+
"Loaded: %d MR descriptors, %d US descriptors",
|
| 38 |
+
len(DEMO_DATA["descriptors_mr"]),
|
| 39 |
+
len(DEMO_DATA["descriptors_us"]),
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def update_demo(
|
| 44 |
+
ratio_threshold: float,
|
| 45 |
+
evaluation_threshold: float,
|
| 46 |
+
mutual: bool,
|
| 47 |
+
metric: str,
|
| 48 |
+
) -> tuple:
|
| 49 |
+
"""Re-run matching with new parameters and rebuild the figure."""
|
| 50 |
+
match_pairs, metrics = run_matching(
|
| 51 |
+
DEMO_DATA["descriptors_mr"],
|
| 52 |
+
DEMO_DATA["descriptors_us"],
|
| 53 |
+
DEMO_DATA["points_mr"],
|
| 54 |
+
DEMO_DATA["points_us"],
|
| 55 |
+
ratio_threshold=ratio_threshold,
|
| 56 |
+
mutual=mutual,
|
| 57 |
+
metric=metric,
|
| 58 |
+
evaluation_threshold=evaluation_threshold,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
fig = build_matching_figure(
|
| 62 |
+
volume_mr=DEMO_DATA["volume_mr"],
|
| 63 |
+
volume_us=DEMO_DATA["volume_us"],
|
| 64 |
+
points_mr=DEMO_DATA["points_mr"],
|
| 65 |
+
points_us=DEMO_DATA["points_us"],
|
| 66 |
+
padded_shape_mr=tuple(DEMO_DATA["metadata"]["padded_shape_mr"]),
|
| 67 |
+
padded_shape_us=tuple(DEMO_DATA["metadata"]["padded_shape_us"]),
|
| 68 |
+
match_pairs=match_pairs,
|
| 69 |
+
metrics=metrics,
|
| 70 |
+
evaluation_threshold=evaluation_threshold,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
return (
|
| 74 |
+
fig,
|
| 75 |
+
metrics['num_matches'],
|
| 76 |
+
metrics['num_correct'],
|
| 77 |
+
round(metrics['precision'], 1),
|
| 78 |
+
round(metrics['matching_score'], 4),
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@gpu_decorator
|
| 83 |
+
def run_custom_inference(mr_file, us_file, heatmap_file, ckpt_file):
|
| 84 |
+
"""Run inference on uploaded data. Uses ZeroGPU on HF Spaces."""
|
| 85 |
+
if any(f is None for f in [mr_file, us_file, heatmap_file, ckpt_file]):
|
| 86 |
+
raise gr.Error("Please upload all four files: MR volume, US volume, heatmap, and checkpoint.")
|
| 87 |
+
|
| 88 |
+
logger.info("Running inference on uploaded data...")
|
| 89 |
+
data = run_inference(
|
| 90 |
+
mr_path=mr_file.name,
|
| 91 |
+
us_path=us_file.name,
|
| 92 |
+
heatmap_path=heatmap_file.name,
|
| 93 |
+
checkpoint_path=ckpt_file.name,
|
| 94 |
+
)
|
| 95 |
+
return data
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def update_custom(
|
| 99 |
+
data: dict,
|
| 100 |
+
ratio_threshold: float,
|
| 101 |
+
evaluation_threshold: float,
|
| 102 |
+
mutual: bool,
|
| 103 |
+
metric: str,
|
| 104 |
+
) -> tuple:
|
| 105 |
+
"""Re-run matching on custom data with new parameters."""
|
| 106 |
+
if data is None:
|
| 107 |
+
raise gr.Error("Run inference first.")
|
| 108 |
+
|
| 109 |
+
match_pairs, metrics = run_matching(
|
| 110 |
+
data["descriptors_mr"],
|
| 111 |
+
data["descriptors_us"],
|
| 112 |
+
data["points_mr"],
|
| 113 |
+
data["points_us"],
|
| 114 |
+
ratio_threshold=ratio_threshold,
|
| 115 |
+
mutual=mutual,
|
| 116 |
+
metric=metric,
|
| 117 |
+
evaluation_threshold=evaluation_threshold,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
fig = build_matching_figure(
|
| 121 |
+
volume_mr=data["volume_mr"],
|
| 122 |
+
volume_us=data["volume_us"],
|
| 123 |
+
points_mr=data["points_mr"],
|
| 124 |
+
points_us=data["points_us"],
|
| 125 |
+
padded_shape_mr=tuple(data["metadata"]["padded_shape_mr"]),
|
| 126 |
+
padded_shape_us=tuple(data["metadata"]["padded_shape_us"]),
|
| 127 |
+
match_pairs=match_pairs,
|
| 128 |
+
metrics=metrics,
|
| 129 |
+
evaluation_threshold=evaluation_threshold,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
return (
|
| 133 |
+
fig,
|
| 134 |
+
metrics['num_matches'],
|
| 135 |
+
metrics['num_correct'],
|
| 136 |
+
round(metrics['precision'], 1),
|
| 137 |
+
round(metrics['matching_score'], 4),
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# -- Build Gradio UI --
|
| 142 |
+
|
| 143 |
+
with gr.Blocks(
|
| 144 |
+
title="CrossKEY -- 3D Cross-modal Keypoint Matching",
|
| 145 |
+
) as demo:
|
| 146 |
+
gr.Markdown(
|
| 147 |
+
"# CrossKEY\n"
|
| 148 |
+
"**3D Cross-modal Keypoint Descriptor for MR-US Matching and Registration**"
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
with gr.Tabs():
|
| 152 |
+
# ---- Tab 1: Explore ----
|
| 153 |
+
with gr.Tab("Explore"):
|
| 154 |
+
with gr.Row():
|
| 155 |
+
with gr.Column(scale=1, min_width=260):
|
| 156 |
+
gr.Markdown("### Matching Parameters")
|
| 157 |
+
demo_ratio = gr.Slider(
|
| 158 |
+
minimum=0.5, maximum=1.0, value=0.75, step=0.05,
|
| 159 |
+
label="Ratio Threshold",
|
| 160 |
+
)
|
| 161 |
+
demo_eval_thresh = gr.Slider(
|
| 162 |
+
minimum=1.0, maximum=10.0, value=5.0, step=0.5,
|
| 163 |
+
label="Evaluation Threshold (mm)",
|
| 164 |
+
)
|
| 165 |
+
demo_mutual = gr.Checkbox(value=True, label="Mutual Nearest Neighbor")
|
| 166 |
+
demo_metric = gr.Dropdown(
|
| 167 |
+
choices=["euclidean", "cosine"], value="euclidean",
|
| 168 |
+
label="Distance Metric",
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
gr.Markdown("### Results")
|
| 172 |
+
with gr.Group():
|
| 173 |
+
with gr.Row():
|
| 174 |
+
demo_n_matches = gr.Number(label="Matches", interactive=False)
|
| 175 |
+
demo_n_correct = gr.Number(label="Correct", interactive=False)
|
| 176 |
+
with gr.Row():
|
| 177 |
+
demo_precision = gr.Number(label="Precision (%)", interactive=False)
|
| 178 |
+
demo_match_score = gr.Number(label="Match Score", interactive=False)
|
| 179 |
+
|
| 180 |
+
with gr.Column(scale=3):
|
| 181 |
+
demo_plot = gr.Plot(label="3D Matching Visualization")
|
| 182 |
+
|
| 183 |
+
demo_inputs = [demo_ratio, demo_eval_thresh, demo_mutual, demo_metric]
|
| 184 |
+
demo_outputs = [demo_plot, demo_n_matches, demo_n_correct, demo_precision, demo_match_score]
|
| 185 |
+
|
| 186 |
+
# Update on any parameter change
|
| 187 |
+
for inp in demo_inputs:
|
| 188 |
+
inp.change(fn=update_demo, inputs=demo_inputs, outputs=demo_outputs)
|
| 189 |
+
|
| 190 |
+
# Load initial results
|
| 191 |
+
demo.load(fn=update_demo, inputs=demo_inputs, outputs=demo_outputs)
|
| 192 |
+
|
| 193 |
+
# ---- Tab 2: Your Data ----
|
| 194 |
+
with gr.Tab("Your Data"):
|
| 195 |
+
gr.Markdown(
|
| 196 |
+
"Upload your own MR volume, US volume, heatmap, and a trained CrossKEY checkpoint.\n\n"
|
| 197 |
+
"Inference runs on GPU and may take 30-60 seconds."
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
with gr.Row():
|
| 201 |
+
custom_mr = gr.File(label="MR Volume (.nii.gz)", file_types=[".nii.gz"])
|
| 202 |
+
custom_us = gr.File(label="US Volume (.nii.gz)", file_types=[".nii.gz"])
|
| 203 |
+
with gr.Row():
|
| 204 |
+
custom_heatmap = gr.File(label="Heatmap (.nii.gz)", file_types=[".nii.gz"])
|
| 205 |
+
custom_ckpt = gr.File(label="Checkpoint (.ckpt)", file_types=[".ckpt"])
|
| 206 |
+
|
| 207 |
+
custom_run_btn = gr.Button("Run Inference (GPU)", variant="primary")
|
| 208 |
+
|
| 209 |
+
with gr.Row():
|
| 210 |
+
with gr.Column(scale=1, min_width=260):
|
| 211 |
+
gr.Markdown("### Matching Parameters")
|
| 212 |
+
custom_ratio = gr.Slider(
|
| 213 |
+
minimum=0.5, maximum=1.0, value=0.75, step=0.05,
|
| 214 |
+
label="Ratio Threshold",
|
| 215 |
+
)
|
| 216 |
+
custom_eval_thresh = gr.Slider(
|
| 217 |
+
minimum=1.0, maximum=10.0, value=5.0, step=0.5,
|
| 218 |
+
label="Evaluation Threshold (mm)",
|
| 219 |
+
)
|
| 220 |
+
custom_mutual = gr.Checkbox(value=True, label="Mutual Nearest Neighbor")
|
| 221 |
+
custom_metric = gr.Dropdown(
|
| 222 |
+
choices=["euclidean", "cosine"], value="euclidean",
|
| 223 |
+
label="Distance Metric",
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
gr.Markdown("### Results")
|
| 227 |
+
with gr.Group():
|
| 228 |
+
with gr.Row():
|
| 229 |
+
custom_n_matches = gr.Number(label="Matches", interactive=False)
|
| 230 |
+
custom_n_correct = gr.Number(label="Correct", interactive=False)
|
| 231 |
+
with gr.Row():
|
| 232 |
+
custom_precision = gr.Number(label="Precision (%)", interactive=False)
|
| 233 |
+
custom_match_score = gr.Number(label="Match Score", interactive=False)
|
| 234 |
+
|
| 235 |
+
with gr.Column(scale=3):
|
| 236 |
+
custom_plot = gr.Plot(label="3D Matching Visualization")
|
| 237 |
+
|
| 238 |
+
# State to hold inference results
|
| 239 |
+
custom_data_state = gr.State(value=None)
|
| 240 |
+
|
| 241 |
+
custom_param_inputs = [custom_ratio, custom_eval_thresh, custom_mutual, custom_metric]
|
| 242 |
+
custom_outputs = [custom_plot, custom_n_matches, custom_n_correct, custom_precision, custom_match_score]
|
| 243 |
+
|
| 244 |
+
# Inference button: run model, then update visualization
|
| 245 |
+
def infer_and_display(mr_file, us_file, heatmap_file, ckpt_file, ratio, eval_thresh, mutual, metric):
|
| 246 |
+
data = run_custom_inference(mr_file, us_file, heatmap_file, ckpt_file)
|
| 247 |
+
fig, n_m, n_c, prec, ms = update_custom(data, ratio, eval_thresh, mutual, metric)
|
| 248 |
+
return data, fig, n_m, n_c, prec, ms
|
| 249 |
+
|
| 250 |
+
custom_run_btn.click(
|
| 251 |
+
fn=infer_and_display,
|
| 252 |
+
inputs=[custom_mr, custom_us, custom_heatmap, custom_ckpt] + custom_param_inputs,
|
| 253 |
+
outputs=[custom_data_state] + custom_outputs,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# Re-match on parameter change (no re-inference)
|
| 257 |
+
for inp in custom_param_inputs:
|
| 258 |
+
inp.change(
|
| 259 |
+
fn=update_custom,
|
| 260 |
+
inputs=[custom_data_state] + custom_param_inputs,
|
| 261 |
+
outputs=custom_outputs,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
if __name__ == "__main__":
|
| 266 |
+
demo.launch(
|
| 267 |
+
theme=gr.themes.Soft(),
|
| 268 |
+
css="footer {display: none !important;}",
|
| 269 |
+
)
|
inference.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Descriptor extraction and matching for CrossKEY HF Space.
|
| 2 |
+
|
| 3 |
+
Provides functions for:
|
| 4 |
+
1. Re-running KNN matching with new parameters (CPU, fast)
|
| 5 |
+
2. Full inference from uploaded volumes + checkpoint (GPU)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Dict, List, Tuple
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from src.data.datamodule import DescriptorDataModule
|
| 16 |
+
from src.model.descriptor import Descriptor
|
| 17 |
+
from src.model.matcher import KNNMatcher
|
| 18 |
+
from src.utils.utils import load_nifti
|
| 19 |
+
|
| 20 |
+
from visualization import downsample_volume
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_precomputed(precomputed_dir: str = "precomputed") -> dict:
|
| 24 |
+
"""Load all pre-computed data for the default demo tab.
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
Dict with keys: descriptors_mr, descriptors_us, points_mr, points_us,
|
| 28 |
+
volume_mr, volume_us, metadata
|
| 29 |
+
"""
|
| 30 |
+
d = Path(precomputed_dir)
|
| 31 |
+
with open(d / "metadata.json") as f:
|
| 32 |
+
metadata = json.load(f)
|
| 33 |
+
|
| 34 |
+
return {
|
| 35 |
+
"descriptors_mr": torch.load(d / "descriptors_mr.pt", weights_only=True),
|
| 36 |
+
"descriptors_us": torch.load(d / "descriptors_us.pt", weights_only=True),
|
| 37 |
+
"points_mr": torch.load(d / "points_mr.pt", weights_only=True).numpy(),
|
| 38 |
+
"points_us": torch.load(d / "points_us.pt", weights_only=True).numpy(),
|
| 39 |
+
"volume_mr": np.load(d / "volume_mr.npy"),
|
| 40 |
+
"volume_us": np.load(d / "volume_us.npy"),
|
| 41 |
+
"metadata": metadata,
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def run_matching(
|
| 46 |
+
descriptors_mr: torch.Tensor,
|
| 47 |
+
descriptors_us: torch.Tensor,
|
| 48 |
+
points_mr: np.ndarray,
|
| 49 |
+
points_us: np.ndarray,
|
| 50 |
+
ratio_threshold: float = 0.75,
|
| 51 |
+
mutual: bool = True,
|
| 52 |
+
metric: str = "euclidean",
|
| 53 |
+
evaluation_threshold: float = 5.0,
|
| 54 |
+
) -> Tuple[List[Tuple[int, int, float]], Dict[str, float]]:
|
| 55 |
+
"""Run KNN matching with given parameters. CPU-only, fast (<1s).
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
(match_pairs, metrics) -- same format as KNNMatcher.match_and_evaluate()
|
| 59 |
+
"""
|
| 60 |
+
matcher = KNNMatcher(
|
| 61 |
+
k=1,
|
| 62 |
+
distance_threshold=float("inf"),
|
| 63 |
+
ratio_threshold=ratio_threshold,
|
| 64 |
+
mutual=mutual,
|
| 65 |
+
metric=metric,
|
| 66 |
+
evaluation_threshold=evaluation_threshold,
|
| 67 |
+
)
|
| 68 |
+
return matcher.match_and_evaluate(
|
| 69 |
+
descriptors_mr, descriptors_us, points_mr, points_us,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def run_inference(
|
| 74 |
+
mr_path: str,
|
| 75 |
+
us_path: str,
|
| 76 |
+
heatmap_path: str,
|
| 77 |
+
checkpoint_path: str,
|
| 78 |
+
batch_size: int = 64,
|
| 79 |
+
grid_spacing: int = 8,
|
| 80 |
+
) -> dict:
|
| 81 |
+
"""Run full inference on uploaded volumes. Requires GPU.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
mr_path: Path to uploaded MR NIfTI file.
|
| 85 |
+
us_path: Path to uploaded US NIfTI file.
|
| 86 |
+
heatmap_path: Path to uploaded heatmap NIfTI file.
|
| 87 |
+
checkpoint_path: Path to uploaded checkpoint.
|
| 88 |
+
batch_size: Inference batch size.
|
| 89 |
+
grid_spacing: Grid spacing for US keypoint generation.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
Dict with same keys as load_precomputed().
|
| 93 |
+
"""
|
| 94 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 95 |
+
|
| 96 |
+
# Load model
|
| 97 |
+
model = Descriptor.load_from_checkpoint(checkpoint_path)
|
| 98 |
+
model.eval()
|
| 99 |
+
model.to(device)
|
| 100 |
+
|
| 101 |
+
# Create datamodule with custom paths
|
| 102 |
+
dm = DescriptorDataModule(
|
| 103 |
+
data_dir=".", # Not used when paths are specified
|
| 104 |
+
batch_size=batch_size,
|
| 105 |
+
num_workers=0,
|
| 106 |
+
patch_size=(32, 32, 32),
|
| 107 |
+
grid_spacing=grid_spacing,
|
| 108 |
+
mr_path=mr_path,
|
| 109 |
+
us_path=us_path,
|
| 110 |
+
heatmap_path=heatmap_path,
|
| 111 |
+
)
|
| 112 |
+
dm.setup(stage="test")
|
| 113 |
+
|
| 114 |
+
# Extract descriptors
|
| 115 |
+
all_desc, all_pts, all_mod = [], [], []
|
| 116 |
+
with torch.no_grad():
|
| 117 |
+
for batch in dm.test_dataloader():
|
| 118 |
+
desc = model(batch["patch"].to(device))
|
| 119 |
+
all_desc.append(desc.cpu())
|
| 120 |
+
all_pts.append(batch["point"].cpu())
|
| 121 |
+
all_mod.extend(batch["modality"])
|
| 122 |
+
|
| 123 |
+
all_desc = torch.cat(all_desc)
|
| 124 |
+
all_pts = torch.cat(all_pts)
|
| 125 |
+
mr_mask = torch.tensor([m == "mr" for m in all_mod])
|
| 126 |
+
|
| 127 |
+
# Downsample volumes for rendering
|
| 128 |
+
mr_vol = load_nifti(mr_path)
|
| 129 |
+
us_vol = load_nifti(us_path)
|
| 130 |
+
mr_norm = (mr_vol - mr_vol.min()) / (mr_vol.max() - mr_vol.min() + 1e-8)
|
| 131 |
+
us_norm = (us_vol - us_vol.min()) / (us_vol.max() - us_vol.min() + 1e-8)
|
| 132 |
+
|
| 133 |
+
return {
|
| 134 |
+
"descriptors_mr": all_desc[mr_mask],
|
| 135 |
+
"descriptors_us": all_desc[~mr_mask],
|
| 136 |
+
"points_mr": all_pts[mr_mask].numpy(),
|
| 137 |
+
"points_us": all_pts[~mr_mask].numpy(),
|
| 138 |
+
"volume_mr": downsample_volume(mr_norm),
|
| 139 |
+
"volume_us": downsample_volume(us_norm),
|
| 140 |
+
"metadata": {
|
| 141 |
+
"padded_shape_mr": list(dm._mr_volume.shape),
|
| 142 |
+
"padded_shape_us": list(dm._us_volume.shape),
|
| 143 |
+
},
|
| 144 |
+
}
|
precomputed/descriptors_mr.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9507f55c6c0666346b8ab93a954c89e8df192d937fbc18452a2eacc5b47e1fd2
|
| 3 |
+
size 2098778
|
precomputed/descriptors_us.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:393b7b33a75392322615d4523c1b4b6ff7b6664491b2d060aef9df95e317952e
|
| 3 |
+
size 4443738
|
precomputed/metadata.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"padded_shape_mr": [
|
| 3 |
+
224,
|
| 4 |
+
224,
|
| 5 |
+
198
|
| 6 |
+
],
|
| 7 |
+
"padded_shape_us": [
|
| 8 |
+
224,
|
| 9 |
+
224,
|
| 10 |
+
198
|
| 11 |
+
],
|
| 12 |
+
"num_mr_descriptors": 1024,
|
| 13 |
+
"num_us_descriptors": 2169,
|
| 14 |
+
"descriptor_dim": 512,
|
| 15 |
+
"grid_spacing": 8
|
| 16 |
+
}
|
precomputed/points_mr.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:86dee258740fe89673a551a7a1611f39b9566c5e6071521f7424ff0a71ecec54
|
| 3 |
+
size 13879
|
precomputed/points_us.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a6b8d0ece568a4552e6f7f88c37d688ffeb2daf8ab559baaab128e826d28e0f2
|
| 3 |
+
size 27575
|
precomputed/volume_mr.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:feb4d465b3e378693f731a27e0d53544db16ddac9e5e507b5ca073f6531e1bf0
|
| 3 |
+
size 1048704
|
precomputed/volume_us.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:42fb079ee925a7a71d8c76a41c5a69963a67c2b7e585e747d7414c4a17f0162c
|
| 3 |
+
size 1048704
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
crosskey @ git+https://github.com/morozovdd/CrossKEY.git
|
| 2 |
+
gradio>=5.0,<6.0
|
| 3 |
+
plotly>=6.0,<7.0
|
| 4 |
+
scikit-image>=0.24,<1.0
|
visualization.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Plotly 3D visualization for CrossKEY matching results.
|
| 2 |
+
|
| 3 |
+
Builds side-by-side volume isosurfaces with keypoints and match lines.
|
| 4 |
+
MR volume on the left, US volume on the right, offset along the X axis.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import plotly.graph_objects as go
|
| 9 |
+
from scipy.ndimage import zoom
|
| 10 |
+
from skimage.measure import marching_cubes
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def downsample_volume(volume: np.ndarray, target_size: int = 64) -> np.ndarray:
|
| 14 |
+
"""Downsample volume to target_size^3 for browser-friendly rendering."""
|
| 15 |
+
factors = [target_size / s for s in volume.shape]
|
| 16 |
+
return zoom(volume, factors, order=1).astype(np.float32)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def scale_points(
|
| 20 |
+
points: np.ndarray,
|
| 21 |
+
padded_shape: tuple,
|
| 22 |
+
volume_shape: tuple,
|
| 23 |
+
) -> np.ndarray:
|
| 24 |
+
"""Scale point coordinates from padded volume space to downsampled volume space."""
|
| 25 |
+
scale = np.array(volume_shape, dtype=float) / np.array(padded_shape, dtype=float)
|
| 26 |
+
return points * scale
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def create_isosurface_trace(
|
| 30 |
+
volume: np.ndarray,
|
| 31 |
+
level: float,
|
| 32 |
+
colorscale: str = "Gray",
|
| 33 |
+
opacity: float = 0.15,
|
| 34 |
+
name: str = "",
|
| 35 |
+
offset_x: float = 0.0,
|
| 36 |
+
) -> go.Mesh3d:
|
| 37 |
+
"""Create a Mesh3d trace from a volume isosurface via marching cubes.
|
| 38 |
+
|
| 39 |
+
Uses vertex intensity from the original volume for natural coloring.
|
| 40 |
+
"""
|
| 41 |
+
verts, faces, _, _ = marching_cubes(volume, level=level)
|
| 42 |
+
# Sample volume intensity at each vertex for natural coloring
|
| 43 |
+
vi = np.clip(verts.astype(int), 0, np.array(volume.shape) - 1)
|
| 44 |
+
intensities = volume[vi[:, 0], vi[:, 1], vi[:, 2]]
|
| 45 |
+
# Axis remap: data (0,1,2) -> Plotly (z,x,y); negate axis 0 so cone points up
|
| 46 |
+
return go.Mesh3d(
|
| 47 |
+
x=verts[:, 1] + offset_x,
|
| 48 |
+
y=verts[:, 2],
|
| 49 |
+
z=-verts[:, 0],
|
| 50 |
+
i=faces[:, 0],
|
| 51 |
+
j=faces[:, 1],
|
| 52 |
+
k=faces[:, 2],
|
| 53 |
+
intensity=intensities,
|
| 54 |
+
colorscale=colorscale,
|
| 55 |
+
opacity=opacity,
|
| 56 |
+
name=name,
|
| 57 |
+
showlegend=True,
|
| 58 |
+
showscale=False,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def create_keypoint_trace(
|
| 63 |
+
points: np.ndarray,
|
| 64 |
+
color: str,
|
| 65 |
+
size: float = 3.0,
|
| 66 |
+
opacity: float = 1.0,
|
| 67 |
+
name: str = "",
|
| 68 |
+
offset_x: float = 0.0,
|
| 69 |
+
) -> go.Scatter3d:
|
| 70 |
+
"""Create Scatter3d markers for keypoints."""
|
| 71 |
+
# Axis remap: data (0,1,2) -> Plotly (z,x,y); negate axis 0
|
| 72 |
+
return go.Scatter3d(
|
| 73 |
+
x=points[:, 1] + offset_x,
|
| 74 |
+
y=points[:, 2],
|
| 75 |
+
z=-points[:, 0],
|
| 76 |
+
mode="markers",
|
| 77 |
+
marker=dict(size=size, color=color, opacity=opacity),
|
| 78 |
+
name=name,
|
| 79 |
+
showlegend=True,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def create_match_lines(
|
| 84 |
+
src_pts: np.ndarray,
|
| 85 |
+
tgt_pts: np.ndarray,
|
| 86 |
+
color: str,
|
| 87 |
+
width: float = 2.0,
|
| 88 |
+
name: str = "",
|
| 89 |
+
offset_x: float = 0.0,
|
| 90 |
+
) -> go.Scatter3d:
|
| 91 |
+
"""Create lines connecting matched source points to offset target points."""
|
| 92 |
+
# Axis remap: data (0,1,2) -> Plotly (z,x,y); negate axis 0
|
| 93 |
+
lx, ly, lz = [], [], []
|
| 94 |
+
for s, t in zip(src_pts, tgt_pts):
|
| 95 |
+
lx.extend([float(s[1]), float(t[1]) + offset_x, None])
|
| 96 |
+
ly.extend([float(s[2]), float(t[2]), None])
|
| 97 |
+
lz.extend([-float(s[0]), -float(t[0]), None])
|
| 98 |
+
return go.Scatter3d(
|
| 99 |
+
x=lx, y=ly, z=lz,
|
| 100 |
+
mode="lines",
|
| 101 |
+
line=dict(color=color, width=width),
|
| 102 |
+
name=name,
|
| 103 |
+
showlegend=True,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def build_matching_figure(
|
| 108 |
+
volume_mr: np.ndarray,
|
| 109 |
+
volume_us: np.ndarray,
|
| 110 |
+
points_mr: np.ndarray,
|
| 111 |
+
points_us: np.ndarray,
|
| 112 |
+
padded_shape_mr: tuple,
|
| 113 |
+
padded_shape_us: tuple,
|
| 114 |
+
match_pairs: list,
|
| 115 |
+
metrics: dict,
|
| 116 |
+
evaluation_threshold: float = 5.0,
|
| 117 |
+
mr_level: float = 0.3,
|
| 118 |
+
us_level: float = 0.1,
|
| 119 |
+
) -> go.Figure:
|
| 120 |
+
"""Build the full 3D matching visualization."""
|
| 121 |
+
fig = go.Figure()
|
| 122 |
+
|
| 123 |
+
# Scale keypoints to match downsampled volume coordinates
|
| 124 |
+
pts_mr_viz = scale_points(points_mr, padded_shape_mr, volume_mr.shape)
|
| 125 |
+
pts_us_viz = scale_points(points_us, padded_shape_us, volume_us.shape)
|
| 126 |
+
|
| 127 |
+
# Side-by-side offset along Plotly x (= data axis 1)
|
| 128 |
+
gap = volume_mr.shape[1] * 0.3
|
| 129 |
+
offset_x = volume_mr.shape[1] + gap
|
| 130 |
+
|
| 131 |
+
# Volume isosurfaces with natural intensity coloring
|
| 132 |
+
try:
|
| 133 |
+
fig.add_trace(create_isosurface_trace(
|
| 134 |
+
volume_mr, level=mr_level, colorscale="Gray",
|
| 135 |
+
opacity=0.15, name="MR Surface",
|
| 136 |
+
))
|
| 137 |
+
except ValueError:
|
| 138 |
+
pass
|
| 139 |
+
|
| 140 |
+
try:
|
| 141 |
+
fig.add_trace(create_isosurface_trace(
|
| 142 |
+
volume_us, level=us_level, colorscale="Hot",
|
| 143 |
+
opacity=0.15, name="US Surface", offset_x=offset_x,
|
| 144 |
+
))
|
| 145 |
+
except ValueError:
|
| 146 |
+
pass
|
| 147 |
+
|
| 148 |
+
# Process matches
|
| 149 |
+
src_indices = [p[0] for p in match_pairs]
|
| 150 |
+
tgt_indices = [p[1] for p in match_pairs]
|
| 151 |
+
|
| 152 |
+
if match_pairs:
|
| 153 |
+
mr_matched = points_mr[src_indices]
|
| 154 |
+
us_matched = points_us[tgt_indices]
|
| 155 |
+
spatial_dist = np.linalg.norm(mr_matched - us_matched, axis=1)
|
| 156 |
+
correct = spatial_dist < evaluation_threshold
|
| 157 |
+
|
| 158 |
+
mr_matched_viz = pts_mr_viz[src_indices]
|
| 159 |
+
us_matched_viz = pts_us_viz[tgt_indices]
|
| 160 |
+
|
| 161 |
+
if correct.any():
|
| 162 |
+
fig.add_trace(create_match_lines(
|
| 163 |
+
mr_matched_viz[correct], us_matched_viz[correct],
|
| 164 |
+
color="rgba(0,200,0,0.6)", width=2,
|
| 165 |
+
name=f"Correct ({correct.sum()})", offset_x=offset_x,
|
| 166 |
+
))
|
| 167 |
+
|
| 168 |
+
if (~correct).any():
|
| 169 |
+
fig.add_trace(create_match_lines(
|
| 170 |
+
mr_matched_viz[~correct], us_matched_viz[~correct],
|
| 171 |
+
color="rgba(255,0,0,0.3)", width=1,
|
| 172 |
+
name=f"Incorrect ({(~correct).sum()})", offset_x=offset_x,
|
| 173 |
+
))
|
| 174 |
+
|
| 175 |
+
fig.add_trace(create_keypoint_trace(
|
| 176 |
+
mr_matched_viz, color="royalblue", size=4,
|
| 177 |
+
name=f"MR Matched ({len(mr_matched_viz)})",
|
| 178 |
+
))
|
| 179 |
+
fig.add_trace(create_keypoint_trace(
|
| 180 |
+
us_matched_viz, color="crimson", size=4,
|
| 181 |
+
name=f"US Matched ({len(us_matched_viz)})", offset_x=offset_x,
|
| 182 |
+
))
|
| 183 |
+
|
| 184 |
+
# Unmatched keypoints (faded)
|
| 185 |
+
matched_mr_set = set(src_indices)
|
| 186 |
+
matched_us_set = set(tgt_indices)
|
| 187 |
+
unmatched_mr = np.array([i not in matched_mr_set for i in range(len(pts_mr_viz))])
|
| 188 |
+
unmatched_us = np.array([i not in matched_us_set for i in range(len(pts_us_viz))])
|
| 189 |
+
|
| 190 |
+
if unmatched_mr.any():
|
| 191 |
+
fig.add_trace(create_keypoint_trace(
|
| 192 |
+
pts_mr_viz[unmatched_mr], color="royalblue",
|
| 193 |
+
size=1.5, opacity=0.2, name="MR Unmatched",
|
| 194 |
+
))
|
| 195 |
+
if unmatched_us.any():
|
| 196 |
+
fig.add_trace(create_keypoint_trace(
|
| 197 |
+
pts_us_viz[unmatched_us], color="crimson",
|
| 198 |
+
size=1.5, opacity=0.2, name="US Unmatched", offset_x=offset_x,
|
| 199 |
+
))
|
| 200 |
+
|
| 201 |
+
# Layout -- no fixed width so Plotly fills the Gradio container
|
| 202 |
+
fig.update_layout(
|
| 203 |
+
scene=dict(
|
| 204 |
+
xaxis=dict(visible=False),
|
| 205 |
+
yaxis=dict(visible=False),
|
| 206 |
+
zaxis=dict(visible=False),
|
| 207 |
+
aspectmode="data",
|
| 208 |
+
camera=dict(
|
| 209 |
+
up=dict(x=0, y=0, z=1),
|
| 210 |
+
eye=dict(x=0, y=-1.8, z=0.3),
|
| 211 |
+
),
|
| 212 |
+
),
|
| 213 |
+
height=700,
|
| 214 |
+
margin=dict(l=0, r=0, t=40, b=0),
|
| 215 |
+
legend=dict(
|
| 216 |
+
yanchor="top", y=0.99,
|
| 217 |
+
xanchor="left", x=0.01,
|
| 218 |
+
bgcolor="rgba(0,0,0,0.5)",
|
| 219 |
+
font=dict(color="white"),
|
| 220 |
+
),
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
return fig
|