| """CrossKEY HuggingFace Space -- Interactive 3D Keypoint Matching Demo. |
| |
| Two-tab Gradio app: |
| Tab 1 (Explore): Pre-computed results with adjustable matching parameters. |
| Tab 2 (Your Data): Upload volumes + checkpoint for live inference. |
| """ |
|
|
| import logging |
| import os |
| import sys |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| import gradio as gr |
| import numpy as np |
|
|
| from inference import load_precomputed, run_inference, run_matching |
| from visualization import build_matching_figure |
|
|
| |
# On Hugging Face Spaces the `spaces` package supplies the ZeroGPU decorator.
# When it is not installed (e.g. running locally) fall back to an identity
# decorator so decorated functions still run unchanged.
try:
    import spaces
except ImportError:
    def gpu_decorator(fn):
        """Identity stand-in for ``spaces.GPU``: return *fn* untouched."""
        return fn
else:
    gpu_decorator = spaces.GPU
|
|
# Root logging config for the app: timestamp, level and message on each line.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("crosskey.app")
|
|
| |
# Load the bundled "precomputed" directory (next to this file) once at import
# time; the Explore tab re-matches this data on every parameter change.
logger.info("Loading pre-computed demo data...")
DEMO_DATA = load_precomputed(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "precomputed")
)
# Log how many descriptors were loaded for each modality.
logger.info(
    "Loaded: %d MR descriptors, %d US descriptors",
    *(len(DEMO_DATA[key]) for key in ("descriptors_mr", "descriptors_us")),
)
|
|
|
|
def update_demo(
    ratio_threshold: float,
    evaluation_threshold: float,
    mutual: bool,
    metric: str,
) -> tuple:
    """Re-run matching on the bundled pre-computed data and rebuild the figure.

    Explore-tab counterpart of ``update_custom``. It delegates to that function
    with ``DEMO_DATA`` so both tabs share a single matching-and-plotting code
    path instead of two hand-maintained copies.

    Args:
        ratio_threshold: Ratio threshold forwarded to ``run_matching``.
        evaluation_threshold: Distance threshold (mm, per the UI label) used
            when scoring matches and when drawing the figure.
        mutual: Whether ``run_matching`` should require mutual nearest neighbors.
        metric: Descriptor distance metric name ("euclidean" or "cosine" in the UI).

    Returns:
        ``(figure, num_matches, num_correct, precision, matching_score)``,
        with precision rounded to 1 decimal and matching score to 4 decimals.
    """
    # DEMO_DATA is loaded at import time and is never None, so update_custom's
    # "Run inference first." guard cannot trigger here.
    return update_custom(
        DEMO_DATA,
        ratio_threshold,
        evaluation_threshold,
        mutual,
        metric,
    )
|
|
|
|
@gpu_decorator
def run_custom_inference(mr_file, us_file, heatmap_file, ckpt_file):
    """Run CrossKEY inference on user uploads (ZeroGPU on HF Spaces).

    Args:
        mr_file: Uploaded MR volume; its ``.name`` is used as the file path.
        us_file: Uploaded US volume; its ``.name`` is used as the file path.
        heatmap_file: Uploaded heatmap; its ``.name`` is used as the file path.
        ckpt_file: Uploaded model checkpoint; its ``.name`` is used as the file path.

    Returns:
        The result of ``run_inference``. ``update_custom`` consumes it as a
        dict with descriptor, point, volume and metadata entries.

    Raises:
        gr.Error: If any of the four uploads is missing.
    """
    uploads = (mr_file, us_file, heatmap_file, ckpt_file)
    if not all(upload is not None for upload in uploads):
        raise gr.Error("Please upload all four files: MR volume, US volume, heatmap, and checkpoint.")

    logger.info("Running inference on uploaded data...")
    mr_path, us_path, heatmap_path, ckpt_path = (upload.name for upload in uploads)
    return run_inference(
        mr_path=mr_path,
        us_path=us_path,
        heatmap_path=heatmap_path,
        checkpoint_path=ckpt_path,
    )
|
|
|
|
def update_custom(
    data: dict,
    ratio_threshold: float,
    evaluation_threshold: float,
    mutual: bool,
    metric: str,
) -> tuple:
    """Match descriptors from a user inference run and redraw the figure.

    Args:
        data: Result of ``run_custom_inference`` (kept in a ``gr.State``), or
            ``None`` if inference has not been run yet.
        ratio_threshold: Ratio threshold forwarded to ``run_matching``.
        evaluation_threshold: Distance threshold (mm, per the UI label) used
            when scoring matches and when drawing the figure.
        mutual: Whether ``run_matching`` should require mutual nearest neighbors.
        metric: Descriptor distance metric name ("euclidean" or "cosine" in the UI).

    Returns:
        ``(figure, num_matches, num_correct, precision, matching_score)``,
        with precision rounded to 1 decimal and matching score to 4 decimals.

    Raises:
        gr.Error: If ``data`` is ``None``.
    """
    if data is None:
        raise gr.Error("Run inference first.")

    pts_mr = data["points_mr"]
    pts_us = data["points_us"]

    match_pairs, metrics = run_matching(
        data["descriptors_mr"],
        data["descriptors_us"],
        pts_mr,
        pts_us,
        ratio_threshold=ratio_threshold,
        mutual=mutual,
        metric=metric,
        evaluation_threshold=evaluation_threshold,
    )

    # The figure builder wants the padded volume shapes as tuples.
    meta = data["metadata"]
    fig = build_matching_figure(
        volume_mr=data["volume_mr"],
        volume_us=data["volume_us"],
        points_mr=pts_mr,
        points_us=pts_us,
        padded_shape_mr=tuple(meta["padded_shape_mr"]),
        padded_shape_us=tuple(meta["padded_shape_us"]),
        match_pairs=match_pairs,
        metrics=metrics,
        evaluation_threshold=evaluation_threshold,
    )

    return (
        fig,
        metrics["num_matches"],
        metrics["num_correct"],
        round(metrics["precision"], 1),
        round(metrics["matching_score"], 4),
    )
|
|
|
|
| |
|
|
def _build_param_controls_and_results():
    """Create the matching-parameter widgets and the result read-outs.

    Must be called inside an open Gradio layout context (here, the left-hand
    column of a tab). Both tabs use this helper so their panels stay identical.

    Returns:
        A pair ``(param_inputs, result_numbers)``:
        ``param_inputs`` is ``[ratio, eval_thresh, mutual, metric]``, the
        argument order ``update_demo`` / ``update_custom`` expect.
        ``result_numbers`` is ``[n_matches, n_correct, precision, match_score]``,
        the order of the values those functions return after the figure.
    """
    gr.Markdown("### Matching Parameters")
    ratio = gr.Slider(
        minimum=0.5, maximum=1.0, value=0.75, step=0.05,
        label="Ratio Threshold",
    )
    eval_thresh = gr.Slider(
        minimum=1.0, maximum=10.0, value=5.0, step=0.5,
        label="Evaluation Threshold (mm)",
    )
    mutual = gr.Checkbox(value=True, label="Mutual Nearest Neighbor")
    metric = gr.Dropdown(
        choices=["euclidean", "cosine"], value="euclidean",
        label="Distance Metric",
    )

    gr.Markdown("### Results")
    with gr.Group():
        with gr.Row():
            n_matches = gr.Number(label="Matches", interactive=False)
            n_correct = gr.Number(label="Correct", interactive=False)
        with gr.Row():
            precision = gr.Number(label="Precision (%)", interactive=False)
            match_score = gr.Number(label="Match Score", interactive=False)

    return (
        [ratio, eval_thresh, mutual, metric],
        [n_matches, n_correct, precision, match_score],
    )


with gr.Blocks(
    title="CrossKEY -- 3D Cross-modal Keypoint Matching",
    theme=gr.themes.Soft(),
    css="footer {display: none !important;}",
) as demo:
    gr.Markdown(
        "# CrossKEY\n"
        "**3D Cross-modal Keypoint Descriptor for MR-US Matching and Registration**"
    )

    with gr.Tabs():
        # --- Tab 1: pre-computed demo data, re-matched whenever a parameter changes.
        with gr.Tab("Explore"):
            with gr.Row():
                with gr.Column(scale=1, min_width=260):
                    demo_inputs, demo_numbers = _build_param_controls_and_results()
                with gr.Column(scale=3):
                    demo_plot = gr.Plot(label="3D Matching Visualization")

            demo_outputs = [demo_plot, *demo_numbers]

            # Any parameter change re-runs matching on the cached demo descriptors.
            for inp in demo_inputs:
                inp.change(fn=update_demo, inputs=demo_inputs, outputs=demo_outputs)

            # Populate the figure and metrics when the page first loads.
            demo.load(fn=update_demo, inputs=demo_inputs, outputs=demo_outputs)

        # --- Tab 2: user uploads, GPU inference, then the same re-matching loop.
        with gr.Tab("Your Data"):
            gr.Markdown(
                "Upload your own MR volume, US volume, heatmap, and a trained CrossKEY checkpoint.\n\n"
                "Inference runs on GPU and may take 30-60 seconds."
            )

            # NOTE(review): ".gz" is accepted alongside ".nii.gz" because some
            # Gradio versions may validate uploads against only the last
            # suffix (".gz"), which would reject ".nii.gz" files. Confirm
            # against the pinned Gradio version; drop ".gz" if unnecessary.
            with gr.Row():
                custom_mr = gr.File(label="MR Volume (.nii.gz)", file_types=[".nii.gz", ".gz"])
                custom_us = gr.File(label="US Volume (.nii.gz)", file_types=[".nii.gz", ".gz"])
            with gr.Row():
                custom_heatmap = gr.File(label="Heatmap (.nii.gz)", file_types=[".nii.gz", ".gz"])
                custom_ckpt = gr.File(label="Checkpoint (.ckpt)", file_types=[".ckpt"])

            custom_run_btn = gr.Button("Run Inference (GPU)", variant="primary")

            with gr.Row():
                with gr.Column(scale=1, min_width=260):
                    custom_param_inputs, custom_numbers = _build_param_controls_and_results()
                with gr.Column(scale=3):
                    custom_plot = gr.Plot(label="3D Matching Visualization")

            # Holds the inference result (descriptors, points, volumes, metadata)
            # so parameter changes can re-match without re-running the network.
            custom_data_state = gr.State(value=None)

            custom_outputs = [custom_plot, *custom_numbers]

            def infer_and_display(mr_file, us_file, heatmap_file, ckpt_file, ratio, eval_thresh, mutual, metric):
                """Run inference on the uploads, then match and plot with the current parameters.

                Returns the raw inference data (stored into ``custom_data_state``)
                followed by the figure and the four metric values.
                """
                data = run_custom_inference(mr_file, us_file, heatmap_file, ckpt_file)
                fig, n_m, n_c, prec, ms = update_custom(data, ratio, eval_thresh, mutual, metric)
                return data, fig, n_m, n_c, prec, ms

            custom_run_btn.click(
                fn=infer_and_display,
                inputs=[custom_mr, custom_us, custom_heatmap, custom_ckpt] + custom_param_inputs,
                outputs=[custom_data_state] + custom_outputs,
            )

            # Parameter changes re-match the stored inference result. Before
            # inference has run, update_custom raises "Run inference first."
            for inp in custom_param_inputs:
                inp.change(
                    fn=update_custom,
                    inputs=[custom_data_state] + custom_param_inputs,
                    outputs=custom_outputs,
                )
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server when this file is executed as a script.
    demo.launch()
|
|