morozovdd Claude Opus 4.6 (1M context) commited on
Commit
ffbfad7
·
1 Parent(s): 1ceb67f

Initial deploy: CrossKEY interactive 3D matching demo

Browse files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

README.md CHANGED
@@ -1,12 +1,17 @@
1
  ---
2
  title: CrossKEY
3
- emoji: 📚
4
- colorFrom: yellow
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.12.0
8
  app_file: app.py
9
- pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
  ---
2
  title: CrossKEY
3
+ emoji: "\U0001f9e0"
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: "5.29.0"
8
  app_file: app.py
9
+ license: mit
10
+ short_description: 3D Cross-modal Keypoint Matching (MR-US)
11
  ---
12
 
13
+ # CrossKEY Demo
14
+
15
+ Interactive 3D visualization of cross-modal keypoint matching between MRI and intraoperative ultrasound.
16
+
17
+ [Paper (IEEE TMI)](https://doi.org/10.1109/TMI.2026.3680352) | [Code (GitHub)](https://github.com/morozovdd/CrossKEY) | [arXiv](https://arxiv.org/abs/2507.18551)
app.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CrossKEY HuggingFace Space -- Interactive 3D Keypoint Matching Demo.
2
+
3
+ Two-tab Gradio app:
4
+ Tab 1 (Explore): Pre-computed results with adjustable matching parameters.
5
+ Tab 2 (Your Data): Upload volumes + checkpoint for live inference.
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ import sys
11
+
12
+ # Add space/ to path so local imports work both locally and on HF
13
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
14
+
15
+ import gradio as gr
16
+ import numpy as np
17
+
18
+ from inference import load_precomputed, run_inference, run_matching
19
+ from visualization import build_matching_figure
20
+
21
+ # ZeroGPU decorator -- no-op when running locally
22
+ try:
23
+ import spaces
24
+ gpu_decorator = spaces.GPU
25
+ except ImportError:
26
+ gpu_decorator = lambda fn: fn
27
+
28
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
29
+ logger = logging.getLogger("crosskey.app")
30
+
31
+ # -- Load pre-computed data on startup --
32
+ logger.info("Loading pre-computed demo data...")
33
+ DEMO_DATA = load_precomputed(
34
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), "precomputed")
35
+ )
36
+ logger.info(
37
+ "Loaded: %d MR descriptors, %d US descriptors",
38
+ len(DEMO_DATA["descriptors_mr"]),
39
+ len(DEMO_DATA["descriptors_us"]),
40
+ )
41
+
42
+
43
+ def update_demo(
44
+ ratio_threshold: float,
45
+ evaluation_threshold: float,
46
+ mutual: bool,
47
+ metric: str,
48
+ ) -> tuple:
49
+ """Re-run matching with new parameters and rebuild the figure."""
50
+ match_pairs, metrics = run_matching(
51
+ DEMO_DATA["descriptors_mr"],
52
+ DEMO_DATA["descriptors_us"],
53
+ DEMO_DATA["points_mr"],
54
+ DEMO_DATA["points_us"],
55
+ ratio_threshold=ratio_threshold,
56
+ mutual=mutual,
57
+ metric=metric,
58
+ evaluation_threshold=evaluation_threshold,
59
+ )
60
+
61
+ fig = build_matching_figure(
62
+ volume_mr=DEMO_DATA["volume_mr"],
63
+ volume_us=DEMO_DATA["volume_us"],
64
+ points_mr=DEMO_DATA["points_mr"],
65
+ points_us=DEMO_DATA["points_us"],
66
+ padded_shape_mr=tuple(DEMO_DATA["metadata"]["padded_shape_mr"]),
67
+ padded_shape_us=tuple(DEMO_DATA["metadata"]["padded_shape_us"]),
68
+ match_pairs=match_pairs,
69
+ metrics=metrics,
70
+ evaluation_threshold=evaluation_threshold,
71
+ )
72
+
73
+ return (
74
+ fig,
75
+ metrics['num_matches'],
76
+ metrics['num_correct'],
77
+ round(metrics['precision'], 1),
78
+ round(metrics['matching_score'], 4),
79
+ )
80
+
81
+
82
+ @gpu_decorator
83
+ def run_custom_inference(mr_file, us_file, heatmap_file, ckpt_file):
84
+ """Run inference on uploaded data. Uses ZeroGPU on HF Spaces."""
85
+ if any(f is None for f in [mr_file, us_file, heatmap_file, ckpt_file]):
86
+ raise gr.Error("Please upload all four files: MR volume, US volume, heatmap, and checkpoint.")
87
+
88
+ logger.info("Running inference on uploaded data...")
89
+ data = run_inference(
90
+ mr_path=mr_file.name,
91
+ us_path=us_file.name,
92
+ heatmap_path=heatmap_file.name,
93
+ checkpoint_path=ckpt_file.name,
94
+ )
95
+ return data
96
+
97
+
98
+ def update_custom(
99
+ data: dict,
100
+ ratio_threshold: float,
101
+ evaluation_threshold: float,
102
+ mutual: bool,
103
+ metric: str,
104
+ ) -> tuple:
105
+ """Re-run matching on custom data with new parameters."""
106
+ if data is None:
107
+ raise gr.Error("Run inference first.")
108
+
109
+ match_pairs, metrics = run_matching(
110
+ data["descriptors_mr"],
111
+ data["descriptors_us"],
112
+ data["points_mr"],
113
+ data["points_us"],
114
+ ratio_threshold=ratio_threshold,
115
+ mutual=mutual,
116
+ metric=metric,
117
+ evaluation_threshold=evaluation_threshold,
118
+ )
119
+
120
+ fig = build_matching_figure(
121
+ volume_mr=data["volume_mr"],
122
+ volume_us=data["volume_us"],
123
+ points_mr=data["points_mr"],
124
+ points_us=data["points_us"],
125
+ padded_shape_mr=tuple(data["metadata"]["padded_shape_mr"]),
126
+ padded_shape_us=tuple(data["metadata"]["padded_shape_us"]),
127
+ match_pairs=match_pairs,
128
+ metrics=metrics,
129
+ evaluation_threshold=evaluation_threshold,
130
+ )
131
+
132
+ return (
133
+ fig,
134
+ metrics['num_matches'],
135
+ metrics['num_correct'],
136
+ round(metrics['precision'], 1),
137
+ round(metrics['matching_score'], 4),
138
+ )
139
+
140
+
141
+ # -- Build Gradio UI --
142
+
143
+ with gr.Blocks(
144
+ title="CrossKEY -- 3D Cross-modal Keypoint Matching",
145
+ ) as demo:
146
+ gr.Markdown(
147
+ "# CrossKEY\n"
148
+ "**3D Cross-modal Keypoint Descriptor for MR-US Matching and Registration**"
149
+ )
150
+
151
+ with gr.Tabs():
152
+ # ---- Tab 1: Explore ----
153
+ with gr.Tab("Explore"):
154
+ with gr.Row():
155
+ with gr.Column(scale=1, min_width=260):
156
+ gr.Markdown("### Matching Parameters")
157
+ demo_ratio = gr.Slider(
158
+ minimum=0.5, maximum=1.0, value=0.75, step=0.05,
159
+ label="Ratio Threshold",
160
+ )
161
+ demo_eval_thresh = gr.Slider(
162
+ minimum=1.0, maximum=10.0, value=5.0, step=0.5,
163
+ label="Evaluation Threshold (mm)",
164
+ )
165
+ demo_mutual = gr.Checkbox(value=True, label="Mutual Nearest Neighbor")
166
+ demo_metric = gr.Dropdown(
167
+ choices=["euclidean", "cosine"], value="euclidean",
168
+ label="Distance Metric",
169
+ )
170
+
171
+ gr.Markdown("### Results")
172
+ with gr.Group():
173
+ with gr.Row():
174
+ demo_n_matches = gr.Number(label="Matches", interactive=False)
175
+ demo_n_correct = gr.Number(label="Correct", interactive=False)
176
+ with gr.Row():
177
+ demo_precision = gr.Number(label="Precision (%)", interactive=False)
178
+ demo_match_score = gr.Number(label="Match Score", interactive=False)
179
+
180
+ with gr.Column(scale=3):
181
+ demo_plot = gr.Plot(label="3D Matching Visualization")
182
+
183
+ demo_inputs = [demo_ratio, demo_eval_thresh, demo_mutual, demo_metric]
184
+ demo_outputs = [demo_plot, demo_n_matches, demo_n_correct, demo_precision, demo_match_score]
185
+
186
+ # Update on any parameter change
187
+ for inp in demo_inputs:
188
+ inp.change(fn=update_demo, inputs=demo_inputs, outputs=demo_outputs)
189
+
190
+ # Load initial results
191
+ demo.load(fn=update_demo, inputs=demo_inputs, outputs=demo_outputs)
192
+
193
+ # ---- Tab 2: Your Data ----
194
+ with gr.Tab("Your Data"):
195
+ gr.Markdown(
196
+ "Upload your own MR volume, US volume, heatmap, and a trained CrossKEY checkpoint.\n\n"
197
+ "Inference runs on GPU and may take 30-60 seconds."
198
+ )
199
+
200
+ with gr.Row():
201
+ custom_mr = gr.File(label="MR Volume (.nii.gz)", file_types=[".nii.gz"])
202
+ custom_us = gr.File(label="US Volume (.nii.gz)", file_types=[".nii.gz"])
203
+ with gr.Row():
204
+ custom_heatmap = gr.File(label="Heatmap (.nii.gz)", file_types=[".nii.gz"])
205
+ custom_ckpt = gr.File(label="Checkpoint (.ckpt)", file_types=[".ckpt"])
206
+
207
+ custom_run_btn = gr.Button("Run Inference (GPU)", variant="primary")
208
+
209
+ with gr.Row():
210
+ with gr.Column(scale=1, min_width=260):
211
+ gr.Markdown("### Matching Parameters")
212
+ custom_ratio = gr.Slider(
213
+ minimum=0.5, maximum=1.0, value=0.75, step=0.05,
214
+ label="Ratio Threshold",
215
+ )
216
+ custom_eval_thresh = gr.Slider(
217
+ minimum=1.0, maximum=10.0, value=5.0, step=0.5,
218
+ label="Evaluation Threshold (mm)",
219
+ )
220
+ custom_mutual = gr.Checkbox(value=True, label="Mutual Nearest Neighbor")
221
+ custom_metric = gr.Dropdown(
222
+ choices=["euclidean", "cosine"], value="euclidean",
223
+ label="Distance Metric",
224
+ )
225
+
226
+ gr.Markdown("### Results")
227
+ with gr.Group():
228
+ with gr.Row():
229
+ custom_n_matches = gr.Number(label="Matches", interactive=False)
230
+ custom_n_correct = gr.Number(label="Correct", interactive=False)
231
+ with gr.Row():
232
+ custom_precision = gr.Number(label="Precision (%)", interactive=False)
233
+ custom_match_score = gr.Number(label="Match Score", interactive=False)
234
+
235
+ with gr.Column(scale=3):
236
+ custom_plot = gr.Plot(label="3D Matching Visualization")
237
+
238
+ # State to hold inference results
239
+ custom_data_state = gr.State(value=None)
240
+
241
+ custom_param_inputs = [custom_ratio, custom_eval_thresh, custom_mutual, custom_metric]
242
+ custom_outputs = [custom_plot, custom_n_matches, custom_n_correct, custom_precision, custom_match_score]
243
+
244
+ # Inference button: run model, then update visualization
245
+ def infer_and_display(mr_file, us_file, heatmap_file, ckpt_file, ratio, eval_thresh, mutual, metric):
246
+ data = run_custom_inference(mr_file, us_file, heatmap_file, ckpt_file)
247
+ fig, n_m, n_c, prec, ms = update_custom(data, ratio, eval_thresh, mutual, metric)
248
+ return data, fig, n_m, n_c, prec, ms
249
+
250
+ custom_run_btn.click(
251
+ fn=infer_and_display,
252
+ inputs=[custom_mr, custom_us, custom_heatmap, custom_ckpt] + custom_param_inputs,
253
+ outputs=[custom_data_state] + custom_outputs,
254
+ )
255
+
256
+ # Re-match on parameter change (no re-inference)
257
+ for inp in custom_param_inputs:
258
+ inp.change(
259
+ fn=update_custom,
260
+ inputs=[custom_data_state] + custom_param_inputs,
261
+ outputs=custom_outputs,
262
+ )
263
+
264
+
265
+ if __name__ == "__main__":
266
+ demo.launch(
267
+ theme=gr.themes.Soft(),
268
+ css="footer {display: none !important;}",
269
+ )
inference.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Descriptor extraction and matching for CrossKEY HF Space.
2
+
3
+ Provides functions for:
4
+ 1. Re-running KNN matching with new parameters (CPU, fast)
5
+ 2. Full inference from uploaded volumes + checkpoint (GPU)
6
+ """
7
+
8
+ import json
9
+ from pathlib import Path
10
+ from typing import Dict, List, Tuple
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ from src.data.datamodule import DescriptorDataModule
16
+ from src.model.descriptor import Descriptor
17
+ from src.model.matcher import KNNMatcher
18
+ from src.utils.utils import load_nifti
19
+
20
+ from visualization import downsample_volume
21
+
22
+
23
+ def load_precomputed(precomputed_dir: str = "precomputed") -> dict:
24
+ """Load all pre-computed data for the default demo tab.
25
+
26
+ Returns:
27
+ Dict with keys: descriptors_mr, descriptors_us, points_mr, points_us,
28
+ volume_mr, volume_us, metadata
29
+ """
30
+ d = Path(precomputed_dir)
31
+ with open(d / "metadata.json") as f:
32
+ metadata = json.load(f)
33
+
34
+ return {
35
+ "descriptors_mr": torch.load(d / "descriptors_mr.pt", weights_only=True),
36
+ "descriptors_us": torch.load(d / "descriptors_us.pt", weights_only=True),
37
+ "points_mr": torch.load(d / "points_mr.pt", weights_only=True).numpy(),
38
+ "points_us": torch.load(d / "points_us.pt", weights_only=True).numpy(),
39
+ "volume_mr": np.load(d / "volume_mr.npy"),
40
+ "volume_us": np.load(d / "volume_us.npy"),
41
+ "metadata": metadata,
42
+ }
43
+
44
+
45
+ def run_matching(
46
+ descriptors_mr: torch.Tensor,
47
+ descriptors_us: torch.Tensor,
48
+ points_mr: np.ndarray,
49
+ points_us: np.ndarray,
50
+ ratio_threshold: float = 0.75,
51
+ mutual: bool = True,
52
+ metric: str = "euclidean",
53
+ evaluation_threshold: float = 5.0,
54
+ ) -> Tuple[List[Tuple[int, int, float]], Dict[str, float]]:
55
+ """Run KNN matching with given parameters. CPU-only, fast (<1s).
56
+
57
+ Returns:
58
+ (match_pairs, metrics) -- same format as KNNMatcher.match_and_evaluate()
59
+ """
60
+ matcher = KNNMatcher(
61
+ k=1,
62
+ distance_threshold=float("inf"),
63
+ ratio_threshold=ratio_threshold,
64
+ mutual=mutual,
65
+ metric=metric,
66
+ evaluation_threshold=evaluation_threshold,
67
+ )
68
+ return matcher.match_and_evaluate(
69
+ descriptors_mr, descriptors_us, points_mr, points_us,
70
+ )
71
+
72
+
73
+ def run_inference(
74
+ mr_path: str,
75
+ us_path: str,
76
+ heatmap_path: str,
77
+ checkpoint_path: str,
78
+ batch_size: int = 64,
79
+ grid_spacing: int = 8,
80
+ ) -> dict:
81
+ """Run full inference on uploaded volumes. Requires GPU.
82
+
83
+ Args:
84
+ mr_path: Path to uploaded MR NIfTI file.
85
+ us_path: Path to uploaded US NIfTI file.
86
+ heatmap_path: Path to uploaded heatmap NIfTI file.
87
+ checkpoint_path: Path to uploaded checkpoint.
88
+ batch_size: Inference batch size.
89
+ grid_spacing: Grid spacing for US keypoint generation.
90
+
91
+ Returns:
92
+ Dict with same keys as load_precomputed().
93
+ """
94
+ device = "cuda" if torch.cuda.is_available() else "cpu"
95
+
96
+ # Load model
97
+ model = Descriptor.load_from_checkpoint(checkpoint_path)
98
+ model.eval()
99
+ model.to(device)
100
+
101
+ # Create datamodule with custom paths
102
+ dm = DescriptorDataModule(
103
+ data_dir=".", # Not used when paths are specified
104
+ batch_size=batch_size,
105
+ num_workers=0,
106
+ patch_size=(32, 32, 32),
107
+ grid_spacing=grid_spacing,
108
+ mr_path=mr_path,
109
+ us_path=us_path,
110
+ heatmap_path=heatmap_path,
111
+ )
112
+ dm.setup(stage="test")
113
+
114
+ # Extract descriptors
115
+ all_desc, all_pts, all_mod = [], [], []
116
+ with torch.no_grad():
117
+ for batch in dm.test_dataloader():
118
+ desc = model(batch["patch"].to(device))
119
+ all_desc.append(desc.cpu())
120
+ all_pts.append(batch["point"].cpu())
121
+ all_mod.extend(batch["modality"])
122
+
123
+ all_desc = torch.cat(all_desc)
124
+ all_pts = torch.cat(all_pts)
125
+ mr_mask = torch.tensor([m == "mr" for m in all_mod])
126
+
127
+ # Downsample volumes for rendering
128
+ mr_vol = load_nifti(mr_path)
129
+ us_vol = load_nifti(us_path)
130
+ mr_norm = (mr_vol - mr_vol.min()) / (mr_vol.max() - mr_vol.min() + 1e-8)
131
+ us_norm = (us_vol - us_vol.min()) / (us_vol.max() - us_vol.min() + 1e-8)
132
+
133
+ return {
134
+ "descriptors_mr": all_desc[mr_mask],
135
+ "descriptors_us": all_desc[~mr_mask],
136
+ "points_mr": all_pts[mr_mask].numpy(),
137
+ "points_us": all_pts[~mr_mask].numpy(),
138
+ "volume_mr": downsample_volume(mr_norm),
139
+ "volume_us": downsample_volume(us_norm),
140
+ "metadata": {
141
+ "padded_shape_mr": list(dm._mr_volume.shape),
142
+ "padded_shape_us": list(dm._us_volume.shape),
143
+ },
144
+ }
precomputed/descriptors_mr.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9507f55c6c0666346b8ab93a954c89e8df192d937fbc18452a2eacc5b47e1fd2
3
+ size 2098778
precomputed/descriptors_us.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:393b7b33a75392322615d4523c1b4b6ff7b6664491b2d060aef9df95e317952e
3
+ size 4443738
precomputed/metadata.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "padded_shape_mr": [
3
+ 224,
4
+ 224,
5
+ 198
6
+ ],
7
+ "padded_shape_us": [
8
+ 224,
9
+ 224,
10
+ 198
11
+ ],
12
+ "num_mr_descriptors": 1024,
13
+ "num_us_descriptors": 2169,
14
+ "descriptor_dim": 512,
15
+ "grid_spacing": 8
16
+ }
precomputed/points_mr.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86dee258740fe89673a551a7a1611f39b9566c5e6071521f7424ff0a71ecec54
3
+ size 13879
precomputed/points_us.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6b8d0ece568a4552e6f7f88c37d688ffeb2daf8ab559baaab128e826d28e0f2
3
+ size 27575
precomputed/volume_mr.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:feb4d465b3e378693f731a27e0d53544db16ddac9e5e507b5ca073f6531e1bf0
3
+ size 1048704
precomputed/volume_us.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42fb079ee925a7a71d8c76a41c5a69963a67c2b7e585e747d7414c4a17f0162c
3
+ size 1048704
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ crosskey @ git+https://github.com/morozovdd/CrossKEY.git
2
+ gradio>=5.0,<6.0
3
+ plotly>=6.0,<7.0
4
+ scikit-image>=0.24,<1.0
visualization.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Plotly 3D visualization for CrossKEY matching results.
2
+
3
+ Builds side-by-side volume isosurfaces with keypoints and match lines.
4
+ MR volume on the left, US volume on the right, offset along the X axis.
5
+ """
6
+
7
+ import numpy as np
8
+ import plotly.graph_objects as go
9
+ from scipy.ndimage import zoom
10
+ from skimage.measure import marching_cubes
11
+
12
+
13
+ def downsample_volume(volume: np.ndarray, target_size: int = 64) -> np.ndarray:
14
+ """Downsample volume to target_size^3 for browser-friendly rendering."""
15
+ factors = [target_size / s for s in volume.shape]
16
+ return zoom(volume, factors, order=1).astype(np.float32)
17
+
18
+
19
+ def scale_points(
20
+ points: np.ndarray,
21
+ padded_shape: tuple,
22
+ volume_shape: tuple,
23
+ ) -> np.ndarray:
24
+ """Scale point coordinates from padded volume space to downsampled volume space."""
25
+ scale = np.array(volume_shape, dtype=float) / np.array(padded_shape, dtype=float)
26
+ return points * scale
27
+
28
+
29
+ def create_isosurface_trace(
30
+ volume: np.ndarray,
31
+ level: float,
32
+ colorscale: str = "Gray",
33
+ opacity: float = 0.15,
34
+ name: str = "",
35
+ offset_x: float = 0.0,
36
+ ) -> go.Mesh3d:
37
+ """Create a Mesh3d trace from a volume isosurface via marching cubes.
38
+
39
+ Uses vertex intensity from the original volume for natural coloring.
40
+ """
41
+ verts, faces, _, _ = marching_cubes(volume, level=level)
42
+ # Sample volume intensity at each vertex for natural coloring
43
+ vi = np.clip(verts.astype(int), 0, np.array(volume.shape) - 1)
44
+ intensities = volume[vi[:, 0], vi[:, 1], vi[:, 2]]
45
+ # Axis remap: data (0,1,2) -> Plotly (z,x,y); negate axis 0 so cone points up
46
+ return go.Mesh3d(
47
+ x=verts[:, 1] + offset_x,
48
+ y=verts[:, 2],
49
+ z=-verts[:, 0],
50
+ i=faces[:, 0],
51
+ j=faces[:, 1],
52
+ k=faces[:, 2],
53
+ intensity=intensities,
54
+ colorscale=colorscale,
55
+ opacity=opacity,
56
+ name=name,
57
+ showlegend=True,
58
+ showscale=False,
59
+ )
60
+
61
+
62
+ def create_keypoint_trace(
63
+ points: np.ndarray,
64
+ color: str,
65
+ size: float = 3.0,
66
+ opacity: float = 1.0,
67
+ name: str = "",
68
+ offset_x: float = 0.0,
69
+ ) -> go.Scatter3d:
70
+ """Create Scatter3d markers for keypoints."""
71
+ # Axis remap: data (0,1,2) -> Plotly (z,x,y); negate axis 0
72
+ return go.Scatter3d(
73
+ x=points[:, 1] + offset_x,
74
+ y=points[:, 2],
75
+ z=-points[:, 0],
76
+ mode="markers",
77
+ marker=dict(size=size, color=color, opacity=opacity),
78
+ name=name,
79
+ showlegend=True,
80
+ )
81
+
82
+
83
+ def create_match_lines(
84
+ src_pts: np.ndarray,
85
+ tgt_pts: np.ndarray,
86
+ color: str,
87
+ width: float = 2.0,
88
+ name: str = "",
89
+ offset_x: float = 0.0,
90
+ ) -> go.Scatter3d:
91
+ """Create lines connecting matched source points to offset target points."""
92
+ # Axis remap: data (0,1,2) -> Plotly (z,x,y); negate axis 0
93
+ lx, ly, lz = [], [], []
94
+ for s, t in zip(src_pts, tgt_pts):
95
+ lx.extend([float(s[1]), float(t[1]) + offset_x, None])
96
+ ly.extend([float(s[2]), float(t[2]), None])
97
+ lz.extend([-float(s[0]), -float(t[0]), None])
98
+ return go.Scatter3d(
99
+ x=lx, y=ly, z=lz,
100
+ mode="lines",
101
+ line=dict(color=color, width=width),
102
+ name=name,
103
+ showlegend=True,
104
+ )
105
+
106
+
107
+ def build_matching_figure(
108
+ volume_mr: np.ndarray,
109
+ volume_us: np.ndarray,
110
+ points_mr: np.ndarray,
111
+ points_us: np.ndarray,
112
+ padded_shape_mr: tuple,
113
+ padded_shape_us: tuple,
114
+ match_pairs: list,
115
+ metrics: dict,
116
+ evaluation_threshold: float = 5.0,
117
+ mr_level: float = 0.3,
118
+ us_level: float = 0.1,
119
+ ) -> go.Figure:
120
+ """Build the full 3D matching visualization."""
121
+ fig = go.Figure()
122
+
123
+ # Scale keypoints to match downsampled volume coordinates
124
+ pts_mr_viz = scale_points(points_mr, padded_shape_mr, volume_mr.shape)
125
+ pts_us_viz = scale_points(points_us, padded_shape_us, volume_us.shape)
126
+
127
+ # Side-by-side offset along Plotly x (= data axis 1)
128
+ gap = volume_mr.shape[1] * 0.3
129
+ offset_x = volume_mr.shape[1] + gap
130
+
131
+ # Volume isosurfaces with natural intensity coloring
132
+ try:
133
+ fig.add_trace(create_isosurface_trace(
134
+ volume_mr, level=mr_level, colorscale="Gray",
135
+ opacity=0.15, name="MR Surface",
136
+ ))
137
+ except ValueError:
138
+ pass
139
+
140
+ try:
141
+ fig.add_trace(create_isosurface_trace(
142
+ volume_us, level=us_level, colorscale="Hot",
143
+ opacity=0.15, name="US Surface", offset_x=offset_x,
144
+ ))
145
+ except ValueError:
146
+ pass
147
+
148
+ # Process matches
149
+ src_indices = [p[0] for p in match_pairs]
150
+ tgt_indices = [p[1] for p in match_pairs]
151
+
152
+ if match_pairs:
153
+ mr_matched = points_mr[src_indices]
154
+ us_matched = points_us[tgt_indices]
155
+ spatial_dist = np.linalg.norm(mr_matched - us_matched, axis=1)
156
+ correct = spatial_dist < evaluation_threshold
157
+
158
+ mr_matched_viz = pts_mr_viz[src_indices]
159
+ us_matched_viz = pts_us_viz[tgt_indices]
160
+
161
+ if correct.any():
162
+ fig.add_trace(create_match_lines(
163
+ mr_matched_viz[correct], us_matched_viz[correct],
164
+ color="rgba(0,200,0,0.6)", width=2,
165
+ name=f"Correct ({correct.sum()})", offset_x=offset_x,
166
+ ))
167
+
168
+ if (~correct).any():
169
+ fig.add_trace(create_match_lines(
170
+ mr_matched_viz[~correct], us_matched_viz[~correct],
171
+ color="rgba(255,0,0,0.3)", width=1,
172
+ name=f"Incorrect ({(~correct).sum()})", offset_x=offset_x,
173
+ ))
174
+
175
+ fig.add_trace(create_keypoint_trace(
176
+ mr_matched_viz, color="royalblue", size=4,
177
+ name=f"MR Matched ({len(mr_matched_viz)})",
178
+ ))
179
+ fig.add_trace(create_keypoint_trace(
180
+ us_matched_viz, color="crimson", size=4,
181
+ name=f"US Matched ({len(us_matched_viz)})", offset_x=offset_x,
182
+ ))
183
+
184
+ # Unmatched keypoints (faded)
185
+ matched_mr_set = set(src_indices)
186
+ matched_us_set = set(tgt_indices)
187
+ unmatched_mr = np.array([i not in matched_mr_set for i in range(len(pts_mr_viz))])
188
+ unmatched_us = np.array([i not in matched_us_set for i in range(len(pts_us_viz))])
189
+
190
+ if unmatched_mr.any():
191
+ fig.add_trace(create_keypoint_trace(
192
+ pts_mr_viz[unmatched_mr], color="royalblue",
193
+ size=1.5, opacity=0.2, name="MR Unmatched",
194
+ ))
195
+ if unmatched_us.any():
196
+ fig.add_trace(create_keypoint_trace(
197
+ pts_us_viz[unmatched_us], color="crimson",
198
+ size=1.5, opacity=0.2, name="US Unmatched", offset_x=offset_x,
199
+ ))
200
+
201
+ # Layout -- no fixed width so Plotly fills the Gradio container
202
+ fig.update_layout(
203
+ scene=dict(
204
+ xaxis=dict(visible=False),
205
+ yaxis=dict(visible=False),
206
+ zaxis=dict(visible=False),
207
+ aspectmode="data",
208
+ camera=dict(
209
+ up=dict(x=0, y=0, z=1),
210
+ eye=dict(x=0, y=-1.8, z=0.3),
211
+ ),
212
+ ),
213
+ height=700,
214
+ margin=dict(l=0, r=0, t=40, b=0),
215
+ legend=dict(
216
+ yanchor="top", y=0.99,
217
+ xanchor="left", x=0.01,
218
+ bgcolor="rgba(0,0,0,0.5)",
219
+ font=dict(color="white"),
220
+ ),
221
+ )
222
+
223
+ return fig