CrossKEY / inference.py
morozovdd's picture
Initial deploy: CrossKEY interactive 3D matching demo
ffbfad7
"""Descriptor extraction and matching for CrossKEY HF Space.
Provides functions for:
1. Re-running KNN matching with new parameters (CPU, fast)
2. Full inference from uploaded volumes + checkpoint (GPU)
"""
import json
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import torch
from src.data.datamodule import DescriptorDataModule
from src.model.descriptor import Descriptor
from src.model.matcher import KNNMatcher
from src.utils.utils import load_nifti
from visualization import downsample_volume
def load_precomputed(precomputed_dir: str = "precomputed") -> dict:
"""Load all pre-computed data for the default demo tab.
Returns:
Dict with keys: descriptors_mr, descriptors_us, points_mr, points_us,
volume_mr, volume_us, metadata
"""
d = Path(precomputed_dir)
with open(d / "metadata.json") as f:
metadata = json.load(f)
return {
"descriptors_mr": torch.load(d / "descriptors_mr.pt", weights_only=True),
"descriptors_us": torch.load(d / "descriptors_us.pt", weights_only=True),
"points_mr": torch.load(d / "points_mr.pt", weights_only=True).numpy(),
"points_us": torch.load(d / "points_us.pt", weights_only=True).numpy(),
"volume_mr": np.load(d / "volume_mr.npy"),
"volume_us": np.load(d / "volume_us.npy"),
"metadata": metadata,
}
def run_matching(
descriptors_mr: torch.Tensor,
descriptors_us: torch.Tensor,
points_mr: np.ndarray,
points_us: np.ndarray,
ratio_threshold: float = 0.75,
mutual: bool = True,
metric: str = "euclidean",
evaluation_threshold: float = 5.0,
) -> Tuple[List[Tuple[int, int, float]], Dict[str, float]]:
"""Run KNN matching with given parameters. CPU-only, fast (<1s).
Returns:
(match_pairs, metrics) -- same format as KNNMatcher.match_and_evaluate()
"""
matcher = KNNMatcher(
k=1,
distance_threshold=float("inf"),
ratio_threshold=ratio_threshold,
mutual=mutual,
metric=metric,
evaluation_threshold=evaluation_threshold,
)
return matcher.match_and_evaluate(
descriptors_mr, descriptors_us, points_mr, points_us,
)
def run_inference(
mr_path: str,
us_path: str,
heatmap_path: str,
checkpoint_path: str,
batch_size: int = 64,
grid_spacing: int = 8,
) -> dict:
"""Run full inference on uploaded volumes. Requires GPU.
Args:
mr_path: Path to uploaded MR NIfTI file.
us_path: Path to uploaded US NIfTI file.
heatmap_path: Path to uploaded heatmap NIfTI file.
checkpoint_path: Path to uploaded checkpoint.
batch_size: Inference batch size.
grid_spacing: Grid spacing for US keypoint generation.
Returns:
Dict with same keys as load_precomputed().
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model
model = Descriptor.load_from_checkpoint(checkpoint_path)
model.eval()
model.to(device)
# Create datamodule with custom paths
dm = DescriptorDataModule(
data_dir=".", # Not used when paths are specified
batch_size=batch_size,
num_workers=0,
patch_size=(32, 32, 32),
grid_spacing=grid_spacing,
mr_path=mr_path,
us_path=us_path,
heatmap_path=heatmap_path,
)
dm.setup(stage="test")
# Extract descriptors
all_desc, all_pts, all_mod = [], [], []
with torch.no_grad():
for batch in dm.test_dataloader():
desc = model(batch["patch"].to(device))
all_desc.append(desc.cpu())
all_pts.append(batch["point"].cpu())
all_mod.extend(batch["modality"])
all_desc = torch.cat(all_desc)
all_pts = torch.cat(all_pts)
mr_mask = torch.tensor([m == "mr" for m in all_mod])
# Downsample volumes for rendering
mr_vol = load_nifti(mr_path)
us_vol = load_nifti(us_path)
mr_norm = (mr_vol - mr_vol.min()) / (mr_vol.max() - mr_vol.min() + 1e-8)
us_norm = (us_vol - us_vol.min()) / (us_vol.max() - us_vol.min() + 1e-8)
return {
"descriptors_mr": all_desc[mr_mask],
"descriptors_us": all_desc[~mr_mask],
"points_mr": all_pts[mr_mask].numpy(),
"points_us": all_pts[~mr_mask].numpy(),
"volume_mr": downsample_volume(mr_norm),
"volume_us": downsample_volume(us_norm),
"metadata": {
"padded_shape_mr": list(dm._mr_volume.shape),
"padded_shape_us": list(dm._us_volume.shape),
},
}