File size: 8,318 Bytes
86c24cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a62959
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86c24cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a62959
 
 
 
 
 
 
 
86c24cb
 
 
2a62959
86c24cb
 
 
 
 
 
 
 
2a62959
 
86c24cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a62959
 
 
 
 
86c24cb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
"""
Test-time augmentation (D4 dihedral group) and model ensemble averaging.

D4 TTA: 4 rotations x 2 reflections = 8 geometric views
+ 2 intensity variants = 10 total forward passes.
Gold beads are rotationally invariant — D4 TTA is maximally effective.
Expected F1 gain: +1-3% over single forward pass.
"""

import numpy as np
import torch
import torch.nn.functional as F
from typing import List, Optional

from src.model import ImmunogoldCenterNet


def d4_tta_predict(
    model: ImmunogoldCenterNet,
    image: np.ndarray,
    device: torch.device = torch.device("cpu"),
) -> tuple:
    """
    Test-time augmentation over the D4 dihedral group + intensity variants.

    Runs 8 geometric views (4 rotations x 2 reflections) plus 2 brightness
    variants, inverts each geometric transform on the model outputs (both
    the spatial grid and the offset vector components), and averages.

    Args:
        model: trained CenterNet model
        image: (H, W) uint8 preprocessed image
        device: torch device

    Returns:
        averaged_heatmap: (2, H/2, W/2) numpy array
        averaged_offsets: (2, H/2, W/2) numpy array
    """
    model.eval()
    heatmaps = []
    offsets_list = []

    def _forward(img_np):
        """Run model on a numpy image, return heatmap and offsets.

        Padding and cropping are derived from THIS image's shape: a 90/270
        degree rotation of a non-square image swaps H and W, so using the
        original image's dimensions here would pad by the wrong amounts and
        crop the outputs to the wrong shape (shapes would then mismatch when
        the views are averaged).
        """
        ih, iw = img_np.shape[:2]
        # Pad to multiple of 32 (encoder requirement)
        pad_h = (32 - ih % 32) % 32
        pad_w = (32 - iw % 32) % 32
        if pad_h > 0 or pad_w > 0:
            img_np = np.pad(img_np, ((0, pad_h), (0, pad_w)), mode="reflect")

        tensor = (
            torch.from_numpy(img_np)
            .float()
            .unsqueeze(0)
            .unsqueeze(0)  # (1, 1, H, W)
            / 255.0
        ).to(device)

        with torch.no_grad():
            hm, off = model(tensor)

        hm = hm.squeeze(0).cpu().numpy()   # (2, H/2, W/2)
        off = off.squeeze(0).cpu().numpy()  # (2, H/2, W/2)

        # Remove padding from output (model output stride is 2)
        return hm[:, : ih // 2, : iw // 2], off[:, : ih // 2, : iw // 2]

    # D4 group: 4 rotations x 2 reflections = 8 geometric views.
    # Forward transform is flip∘rot90(k); inverse applied below is
    # rot90(-k)∘flip, in that order.
    for k in range(4):
        for flip in [False, True]:
            aug = np.rot90(image, k).copy()
            if flip:
                aug = np.fliplr(aug).copy()

            hm, off = _forward(aug)

            # Inverse transforms on heatmap and offsets
            if flip:
                hm = np.flip(hm, axis=2).copy()   # flip W axis
                off = np.flip(off, axis=2).copy()
                off[0] = -off[0]  # negate x offset for horizontal flip

            if k > 0:
                hm = np.rot90(hm, -k, axes=(1, 2)).copy()
                off = np.rot90(off, -k, axes=(1, 2)).copy()
                # Rotate offset vectors (off[0] = x/col, off[1] = y/row)
                if k == 1:  # 90° CCW undo
                    off = np.stack([-off[1], off[0]], axis=0)
                elif k == 2:  # 180°
                    off = np.stack([-off[0], -off[1]], axis=0)
                elif k == 3:  # 270° CCW undo
                    off = np.stack([off[1], -off[0]], axis=0)

            heatmaps.append(hm)
            offsets_list.append(off)

    # 2 intensity variants (identity geometry, so no inverse needed)
    for factor in [0.9, 1.1]:
        aug = np.clip(image.astype(np.float32) * factor, 0, 255).astype(np.uint8)
        hm, off = _forward(aug)
        heatmaps.append(hm)
        offsets_list.append(off)

    # Average all 10 views
    avg_heatmap = np.mean(heatmaps, axis=0)
    avg_offsets = np.mean(offsets_list, axis=0)

    return avg_heatmap, avg_offsets


def ensemble_predict(
    models: List[ImmunogoldCenterNet],
    image: np.ndarray,
    device: torch.device = torch.device("cpu"),
    use_tta: bool = True,
) -> tuple:
    """
    Ensemble prediction: average heatmaps from N models.

    Args:
        models: list of trained models (e.g., 5 seeds x 3 snapshots = 15)
        image: (H, W) uint8 preprocessed image
        device: torch device
        use_tta: whether to apply D4 TTA per model

    Returns:
        averaged_heatmap: (2, H/2, W/2) numpy array
        averaged_offsets: (2, H/2, W/2) numpy array
    """
    heatmap_stack = []
    offset_stack = []

    h, w = image.shape[:2]
    out_h, out_w = h // 2, w // 2

    for net in models:
        net.eval()
        net.to(device)

        if use_tta:
            heat, offs = d4_tta_predict(net, image, device)
        else:
            # Single forward pass: pad bottom/right to a multiple of 32
            # for the encoder, then crop the stride-2 outputs back.
            pad_h = (32 - h % 32) % 32
            pad_w = (32 - w % 32) % 32
            padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode="reflect")

            batch = (
                torch.from_numpy(padded).float()[None, None] / 255.0
            ).to(device)

            with torch.no_grad():
                heat_t, offs_t = net(batch)

            heat = heat_t.squeeze(0).cpu().numpy()[:, :out_h, :out_w]
            offs = offs_t.squeeze(0).cpu().numpy()[:, :out_h, :out_w]

        heatmap_stack.append(heat)
        offset_stack.append(offs)

    return np.mean(heatmap_stack, axis=0), np.mean(offset_stack, axis=0)


def _tile_origins(axis_len: int, patch: int, stride_step: int) -> list:
    """
    Starting indices for sliding windows along one axis so the last window
    flush-aligns with the far edge. Plain range(0, n-patch+1, step) misses
    the bottom/right of most image sizes (e.g. 2048 with patch 512, step 384),
    leaving heatmap strips at zero.
    """
    if axis_len <= patch:
        return [0]
    last = axis_len - patch
    starts = list(range(0, last + 1, stride_step))
    if starts[-1] != last:
        starts.append(last)
    return starts


def sliding_window_inference(
    model: ImmunogoldCenterNet,
    image: np.ndarray,
    patch_size: int = 512,
    overlap: int = 128,
    device: torch.device = torch.device("cpu"),
) -> tuple:
    """
    Full-image inference via sliding window with overlap stitching.

    Tiles the image into overlapping patches, runs the model on each,
    and stitches heatmaps using max in overlap regions; offsets are
    averaged over all tiles covering a pixel.

    Args:
        model: trained model
        image: (H, W) uint8 preprocessed image
        patch_size: tile size; must be a multiple of 32 (encoder constraint)
        overlap: overlap between tiles; must be smaller than patch_size
        device: torch device

    Returns:
        heatmap: (2, H/2, W/2) numpy array
        offsets: (2, H/2, W/2) numpy array

    Raises:
        ValueError: if patch_size is not a multiple of 32, or
            overlap >= patch_size (which would make the stride non-positive).
    """
    if patch_size % 32 != 0:
        raise ValueError(f"patch_size must be a multiple of 32, got {patch_size}")
    if overlap >= patch_size:
        raise ValueError(f"overlap ({overlap}) must be < patch_size ({patch_size})")

    model.eval()
    orig_h, orig_w = image.shape[:2]
    # Pad bottom/right so each dim >= patch_size; otherwise the tile origin
    # list is empty and heatmaps stay all zeros (looks like a "broken"
    # heatmap in the UI).
    pad_h = max(0, patch_size - orig_h)
    pad_w = max(0, patch_size - orig_w)
    if pad_h > 0 or pad_w > 0:
        image = np.pad(image, ((0, pad_h), (0, pad_w)), mode="reflect")

    h, w = image.shape[:2]
    stride_step = patch_size - overlap

    # Output dimensions at model stride 2 (padded image)
    out_h = h // 2
    out_w = w // 2
    out_patch = patch_size // 2

    heatmap = np.zeros((2, out_h, out_w), dtype=np.float32)
    offsets = np.zeros((2, out_h, out_w), dtype=np.float32)
    count = np.zeros((out_h, out_w), dtype=np.float32)

    for y0 in _tile_origins(h, patch_size, stride_step):
        # Snap origins down to even so each tile's stride-2 output grid lines
        # up with the global output grid: placing an odd-origin tile at
        # y0 // 2 would shift all of its output pixels by one input pixel.
        # The flush-aligned last origin can be odd when the padded dim is odd.
        y0 -= y0 % 2
        for x0 in _tile_origins(w, patch_size, stride_step):
            x0 -= x0 % 2
            patch = image[y0 : y0 + patch_size, x0 : x0 + patch_size]
            tensor = (
                torch.from_numpy(patch)
                .float()
                .unsqueeze(0)
                .unsqueeze(0)
                / 255.0
            ).to(device)

            with torch.no_grad():
                hm, off = model(tensor)

            hm_np = hm.squeeze(0).cpu().numpy()
            off_np = off.squeeze(0).cpu().numpy()

            # Tile origin in output coordinates (even y0/x0, so exact)
            oy0 = y0 // 2
            ox0 = x0 // 2

            # Max-stitch heatmap, average-stitch offsets
            heatmap[:, oy0 : oy0 + out_patch, ox0 : ox0 + out_patch] = np.maximum(
                heatmap[:, oy0 : oy0 + out_patch, ox0 : ox0 + out_patch],
                hm_np,
            )
            offsets[:, oy0 : oy0 + out_patch, ox0 : ox0 + out_patch] += off_np
            count[oy0 : oy0 + out_patch, ox0 : ox0 + out_patch] += 1

    # Average offsets where counted (clamp avoids divide-by-zero on any
    # output pixel no tile covered)
    count = np.maximum(count, 1)
    offsets /= count[np.newaxis, :, :]

    # Crop back to original (pre-pad) spatial extent in heatmap space
    crop_h, crop_w = orig_h // 2, orig_w // 2
    heatmap = heatmap[:, :crop_h, :crop_w]
    offsets = offsets[:, :crop_h, :crop_w]

    return heatmap, offsets