| """Descriptor extraction and matching for CrossKEY HF Space. |
| |
| Provides functions for: |
| 1. Re-running KNN matching with new parameters (CPU, fast) |
| 2. Full inference from uploaded volumes + checkpoint (GPU) |
| """ |
|
|
| import json |
| from pathlib import Path |
| from typing import Dict, List, Tuple |
|
|
| import numpy as np |
| import torch |
|
|
| from src.data.datamodule import DescriptorDataModule |
| from src.model.descriptor import Descriptor |
| from src.model.matcher import KNNMatcher |
| from src.utils.utils import load_nifti |
|
|
| from visualization import downsample_volume |
|
|
|
|
| def load_precomputed(precomputed_dir: str = "precomputed") -> dict: |
| """Load all pre-computed data for the default demo tab. |
| |
| Returns: |
| Dict with keys: descriptors_mr, descriptors_us, points_mr, points_us, |
| volume_mr, volume_us, metadata |
| """ |
| d = Path(precomputed_dir) |
| with open(d / "metadata.json") as f: |
| metadata = json.load(f) |
|
|
| return { |
| "descriptors_mr": torch.load(d / "descriptors_mr.pt", weights_only=True), |
| "descriptors_us": torch.load(d / "descriptors_us.pt", weights_only=True), |
| "points_mr": torch.load(d / "points_mr.pt", weights_only=True).numpy(), |
| "points_us": torch.load(d / "points_us.pt", weights_only=True).numpy(), |
| "volume_mr": np.load(d / "volume_mr.npy"), |
| "volume_us": np.load(d / "volume_us.npy"), |
| "metadata": metadata, |
| } |
|
|
|
|
| def run_matching( |
| descriptors_mr: torch.Tensor, |
| descriptors_us: torch.Tensor, |
| points_mr: np.ndarray, |
| points_us: np.ndarray, |
| ratio_threshold: float = 0.75, |
| mutual: bool = True, |
| metric: str = "euclidean", |
| evaluation_threshold: float = 5.0, |
| ) -> Tuple[List[Tuple[int, int, float]], Dict[str, float]]: |
| """Run KNN matching with given parameters. CPU-only, fast (<1s). |
| |
| Returns: |
| (match_pairs, metrics) -- same format as KNNMatcher.match_and_evaluate() |
| """ |
| matcher = KNNMatcher( |
| k=1, |
| distance_threshold=float("inf"), |
| ratio_threshold=ratio_threshold, |
| mutual=mutual, |
| metric=metric, |
| evaluation_threshold=evaluation_threshold, |
| ) |
| return matcher.match_and_evaluate( |
| descriptors_mr, descriptors_us, points_mr, points_us, |
| ) |
|
|
|
|
| def run_inference( |
| mr_path: str, |
| us_path: str, |
| heatmap_path: str, |
| checkpoint_path: str, |
| batch_size: int = 64, |
| grid_spacing: int = 8, |
| ) -> dict: |
| """Run full inference on uploaded volumes. Requires GPU. |
| |
| Args: |
| mr_path: Path to uploaded MR NIfTI file. |
| us_path: Path to uploaded US NIfTI file. |
| heatmap_path: Path to uploaded heatmap NIfTI file. |
| checkpoint_path: Path to uploaded checkpoint. |
| batch_size: Inference batch size. |
| grid_spacing: Grid spacing for US keypoint generation. |
| |
| Returns: |
| Dict with same keys as load_precomputed(). |
| """ |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
| |
| model = Descriptor.load_from_checkpoint(checkpoint_path) |
| model.eval() |
| model.to(device) |
|
|
| |
| dm = DescriptorDataModule( |
| data_dir=".", |
| batch_size=batch_size, |
| num_workers=0, |
| patch_size=(32, 32, 32), |
| grid_spacing=grid_spacing, |
| mr_path=mr_path, |
| us_path=us_path, |
| heatmap_path=heatmap_path, |
| ) |
| dm.setup(stage="test") |
|
|
| |
| all_desc, all_pts, all_mod = [], [], [] |
| with torch.no_grad(): |
| for batch in dm.test_dataloader(): |
| desc = model(batch["patch"].to(device)) |
| all_desc.append(desc.cpu()) |
| all_pts.append(batch["point"].cpu()) |
| all_mod.extend(batch["modality"]) |
|
|
| all_desc = torch.cat(all_desc) |
| all_pts = torch.cat(all_pts) |
| mr_mask = torch.tensor([m == "mr" for m in all_mod]) |
|
|
| |
| mr_vol = load_nifti(mr_path) |
| us_vol = load_nifti(us_path) |
| mr_norm = (mr_vol - mr_vol.min()) / (mr_vol.max() - mr_vol.min() + 1e-8) |
| us_norm = (us_vol - us_vol.min()) / (us_vol.max() - us_vol.min() + 1e-8) |
|
|
| return { |
| "descriptors_mr": all_desc[mr_mask], |
| "descriptors_us": all_desc[~mr_mask], |
| "points_mr": all_pts[mr_mask].numpy(), |
| "points_us": all_pts[~mr_mask].numpy(), |
| "volume_mr": downsample_volume(mr_norm), |
| "volume_us": downsample_volume(us_norm), |
| "metadata": { |
| "padded_shape_mr": list(dm._mr_volume.shape), |
| "padded_shape_us": list(dm._us_volume.shape), |
| }, |
| } |
|
|