File size: 4,585 Bytes
ffbfad7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Descriptor extraction and matching for CrossKEY HF Space.

Provides functions for:
1. Re-running KNN matching with new parameters (CPU, fast)
2. Full inference from uploaded volumes + checkpoint (GPU)
"""

import json
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch

from src.data.datamodule import DescriptorDataModule
from src.model.descriptor import Descriptor
from src.model.matcher import KNNMatcher
from src.utils.utils import load_nifti

from visualization import downsample_volume


def load_precomputed(precomputed_dir: str = "precomputed") -> dict:
    """Load all pre-computed data for the default demo tab.

    Tensors are always materialised on the CPU (``map_location="cpu"``) so
    that artefacts produced on a GPU machine can still be loaded on a
    CPU-only host, and so the ``.numpy()`` conversion of the point tensors
    cannot fail on a CUDA tensor.

    Args:
        precomputed_dir: Directory containing ``metadata.json``, the
            ``descriptors_*.pt`` / ``points_*.pt`` tensor files and the
            ``volume_*.npy`` arrays.

    Returns:
        Dict with keys: descriptors_mr, descriptors_us (torch tensors),
        points_mr, points_us (numpy arrays), volume_mr, volume_us (numpy
        arrays), metadata (parsed JSON).
    """
    d = Path(precomputed_dir)
    # Explicit encoding: JSON is UTF-8 by spec; don't depend on the locale.
    with open(d / "metadata.json", encoding="utf-8") as f:
        metadata = json.load(f)

    def _load_tensor(name: str) -> torch.Tensor:
        # weights_only=True keeps torch.load restricted to plain tensor data.
        return torch.load(d / name, map_location="cpu", weights_only=True)

    return {
        "descriptors_mr": _load_tensor("descriptors_mr.pt"),
        "descriptors_us": _load_tensor("descriptors_us.pt"),
        "points_mr": _load_tensor("points_mr.pt").numpy(),
        "points_us": _load_tensor("points_us.pt").numpy(),
        "volume_mr": np.load(d / "volume_mr.npy"),
        "volume_us": np.load(d / "volume_us.npy"),
        "metadata": metadata,
    }


def run_matching(
    descriptors_mr: torch.Tensor,
    descriptors_us: torch.Tensor,
    points_mr: np.ndarray,
    points_us: np.ndarray,
    ratio_threshold: float = 0.75,
    mutual: bool = True,
    metric: str = "euclidean",
    evaluation_threshold: float = 5.0,
) -> Tuple[List[Tuple[int, int, float]], Dict[str, float]]:
    """Match MR descriptors against US descriptors with a 1-NN matcher.

    This is the cheap path used to re-match already-extracted descriptors
    with new parameters; no model is involved.

    Args:
        descriptors_mr: Descriptors of the MR keypoints.
        descriptors_us: Descriptors of the US keypoints.
        points_mr: Keypoint coordinates belonging to ``descriptors_mr``.
        points_us: Keypoint coordinates belonging to ``descriptors_us``.
        ratio_threshold: Forwarded to ``KNNMatcher`` as ``ratio_threshold``.
        mutual: Forwarded to ``KNNMatcher`` as ``mutual``.
        metric: Forwarded to ``KNNMatcher`` as ``metric``.
        evaluation_threshold: Forwarded to ``KNNMatcher`` as
            ``evaluation_threshold``.

    Returns:
        Whatever ``KNNMatcher.match_and_evaluate()`` returns, annotated here
        as ``(match_pairs, metrics)``.
    """
    # k=1 and an infinite distance threshold: no absolute distance cut-off
    # is applied here, so any rejection presumably comes from the ratio /
    # mutual settings inside KNNMatcher (TODO confirm against the matcher).
    matcher_settings = dict(
        k=1,
        distance_threshold=float("inf"),
        ratio_threshold=ratio_threshold,
        mutual=mutual,
        metric=metric,
        evaluation_threshold=evaluation_threshold,
    )
    knn = KNNMatcher(**matcher_settings)

    outcome = knn.match_and_evaluate(
        descriptors_mr,
        descriptors_us,
        points_mr,
        points_us,
    )
    return outcome


def _minmax_normalize(volume):
    """Linearly rescale *volume* to roughly [0, 1]; the epsilon avoids a
    division by zero for constant volumes."""
    lo = volume.min()
    return (volume - lo) / (volume.max() - lo + 1e-8)


def run_inference(
    mr_path: str,
    us_path: str,
    heatmap_path: str,
    checkpoint_path: str,
    batch_size: int = 64,
    grid_spacing: int = 8,
) -> dict:
    """Run full inference on uploaded volumes.

    Intended for a GPU: CUDA is used when available, otherwise the model
    runs on the CPU (functional, but much slower).

    Args:
        mr_path: Path to uploaded MR NIfTI file.
        us_path: Path to uploaded US NIfTI file.
        heatmap_path: Path to uploaded heatmap NIfTI file.
        checkpoint_path: Path to uploaded checkpoint.
        batch_size: Inference batch size.
        grid_spacing: Grid spacing for US keypoint generation.

    Returns:
        Dict with the same keys as load_precomputed(): descriptors_mr,
        descriptors_us, points_mr, points_us, volume_mr, volume_us and
        metadata. Here metadata holds the padded MR/US volume shapes
        reported by the datamodule.

    Raises:
        RuntimeError: If the datamodule yields no patches at all.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model
    model = Descriptor.load_from_checkpoint(checkpoint_path)
    model.eval()
    model.to(device)

    # Create datamodule with custom paths
    dm = DescriptorDataModule(
        data_dir=".",  # Not used when paths are specified
        batch_size=batch_size,
        num_workers=0,
        patch_size=(32, 32, 32),
        grid_spacing=grid_spacing,
        mr_path=mr_path,
        us_path=us_path,
        heatmap_path=heatmap_path,
    )
    dm.setup(stage="test")

    # Extract descriptors batch by batch; results are moved back to the CPU
    # immediately so only one batch's worth of device memory is held.
    desc_chunks, point_chunks, modalities = [], [], []
    with torch.no_grad():
        for batch in dm.test_dataloader():
            desc = model(batch["patch"].to(device))
            desc_chunks.append(desc.cpu())
            point_chunks.append(batch["point"].cpu())
            modalities.extend(batch["modality"])

    # torch.cat([]) would fail with an opaque message; report the real cause.
    if not desc_chunks:
        raise RuntimeError(
            "No patches were extracted from the uploaded volumes; "
            "cannot compute descriptors."
        )

    all_desc = torch.cat(desc_chunks)
    all_pts = torch.cat(point_chunks)
    # Explicit bool dtype so the mask always performs boolean indexing.
    mr_mask = torch.tensor([m == "mr" for m in modalities], dtype=torch.bool)

    # Normalise and downsample volumes for rendering
    mr_norm = _minmax_normalize(load_nifti(mr_path))
    us_norm = _minmax_normalize(load_nifti(us_path))

    return {
        "descriptors_mr": all_desc[mr_mask],
        "descriptors_us": all_desc[~mr_mask],
        "points_mr": all_pts[mr_mask].numpy(),
        "points_us": all_pts[~mr_mask].numpy(),
        "volume_mr": downsample_volume(mr_norm),
        "volume_us": downsample_volume(us_norm),
        "metadata": {
            "padded_shape_mr": list(dm._mr_volume.shape),
            "padded_shape_us": list(dm._us_volume.shape),
        },
    }