AnikS22 commited on
Commit
6dd4c34
·
verified ·
1 Parent(s): 72f2556

Upload src/ensemble.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/ensemble.py +236 -0
src/ensemble.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test-time augmentation (D4 dihedral group) and model ensemble averaging.
3
+
4
+ D4 TTA: 4 rotations x 2 reflections = 8 geometric views
5
+ + 2 intensity variants = 10 total forward passes.
6
+ Gold beads are rotationally invariant — D4 TTA is maximally effective.
7
+ Expected F1 gain: +1-3% over single forward pass.
8
+ """
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from typing import List, Optional
14
+
15
+ from src.model import ImmunogoldCenterNet
16
+
17
+
18
+ def d4_tta_predict(
19
+ model: ImmunogoldCenterNet,
20
+ image: np.ndarray,
21
+ device: torch.device = torch.device("cpu"),
22
+ ) -> tuple:
23
+ """
24
+ Test-time augmentation over D4 dihedral group + intensity variants.
25
+
26
+ Args:
27
+ model: trained CenterNet model
28
+ image: (H, W) uint8 preprocessed image
29
+ device: torch device
30
+
31
+ Returns:
32
+ averaged_heatmap: (2, H/2, W/2) numpy array
33
+ averaged_offsets: (2, H/2, W/2) numpy array
34
+ """
35
+ model.eval()
36
+ heatmaps = []
37
+ offsets_list = []
38
+
39
+ # Ensure image dimensions are divisible by 32 for the encoder
40
+ h, w = image.shape[:2]
41
+ pad_h = (32 - h % 32) % 32
42
+ pad_w = (32 - w % 32) % 32
43
+
44
+ def _forward(img_np):
45
+ """Run model on numpy image, return heatmap and offsets."""
46
+ # Pad to multiple of 32
47
+ if pad_h > 0 or pad_w > 0:
48
+ img_np = np.pad(img_np, ((0, pad_h), (0, pad_w)), mode="reflect")
49
+
50
+ tensor = (
51
+ torch.from_numpy(img_np)
52
+ .float()
53
+ .unsqueeze(0)
54
+ .unsqueeze(0) # (1, 1, H, W)
55
+ / 255.0
56
+ ).to(device)
57
+
58
+ with torch.no_grad():
59
+ hm, off = model(tensor)
60
+
61
+ hm = hm.squeeze(0).cpu().numpy() # (2, H/2, W/2)
62
+ off = off.squeeze(0).cpu().numpy() # (2, H/2, W/2)
63
+
64
+ # Remove padding from output
65
+ hm_h = h // 2
66
+ hm_w = w // 2
67
+ return hm[:, :hm_h, :hm_w], off[:, :hm_h, :hm_w]
68
+
69
+ # D4 group: 4 rotations x 2 reflections = 8 geometric views
70
+ for k in range(4):
71
+ for flip in [False, True]:
72
+ aug = np.rot90(image, k).copy()
73
+ if flip:
74
+ aug = np.fliplr(aug).copy()
75
+
76
+ hm, off = _forward(aug)
77
+
78
+ # Inverse transforms on heatmap and offsets
79
+ if flip:
80
+ hm = np.flip(hm, axis=2).copy() # flip W axis
81
+ off = np.flip(off, axis=2).copy()
82
+ off[0] = -off[0] # negate x offset for horizontal flip
83
+
84
+ if k > 0:
85
+ hm = np.rot90(hm, -k, axes=(1, 2)).copy()
86
+ off = np.rot90(off, -k, axes=(1, 2)).copy()
87
+ # Rotate offset vectors
88
+ if k == 1: # 90° CCW undo
89
+ off = np.stack([-off[1], off[0]], axis=0)
90
+ elif k == 2: # 180°
91
+ off = np.stack([-off[0], -off[1]], axis=0)
92
+ elif k == 3: # 270° CCW undo
93
+ off = np.stack([off[1], -off[0]], axis=0)
94
+
95
+ heatmaps.append(hm)
96
+ offsets_list.append(off)
97
+
98
+ # 2 intensity variants
99
+ for factor in [0.9, 1.1]:
100
+ aug = np.clip(image.astype(np.float32) * factor, 0, 255).astype(np.uint8)
101
+ hm, off = _forward(aug)
102
+ heatmaps.append(hm)
103
+ offsets_list.append(off)
104
+
105
+ # Average all views
106
+ avg_heatmap = np.mean(heatmaps, axis=0)
107
+ avg_offsets = np.mean(offsets_list, axis=0)
108
+
109
+ return avg_heatmap, avg_offsets
110
+
111
+
112
+ def ensemble_predict(
113
+ models: List[ImmunogoldCenterNet],
114
+ image: np.ndarray,
115
+ device: torch.device = torch.device("cpu"),
116
+ use_tta: bool = True,
117
+ ) -> tuple:
118
+ """
119
+ Ensemble prediction: average heatmaps from N models.
120
+
121
+ Args:
122
+ models: list of trained models (e.g., 5 seeds x 3 snapshots = 15)
123
+ image: (H, W) uint8 preprocessed image
124
+ device: torch device
125
+ use_tta: whether to apply D4 TTA per model
126
+
127
+ Returns:
128
+ averaged_heatmap: (2, H/2, W/2) numpy array
129
+ averaged_offsets: (2, H/2, W/2) numpy array
130
+ """
131
+ all_heatmaps = []
132
+ all_offsets = []
133
+
134
+ for model in models:
135
+ model.eval()
136
+ model.to(device)
137
+
138
+ if use_tta:
139
+ hm, off = d4_tta_predict(model, image, device)
140
+ else:
141
+ h, w = image.shape[:2]
142
+ pad_h = (32 - h % 32) % 32
143
+ pad_w = (32 - w % 32) % 32
144
+ img_padded = np.pad(image, ((0, pad_h), (0, pad_w)), mode="reflect")
145
+
146
+ tensor = (
147
+ torch.from_numpy(img_padded)
148
+ .float()
149
+ .unsqueeze(0)
150
+ .unsqueeze(0)
151
+ / 255.0
152
+ ).to(device)
153
+
154
+ with torch.no_grad():
155
+ hm_t, off_t = model(tensor)
156
+
157
+ hm = hm_t.squeeze(0).cpu().numpy()[:, : h // 2, : w // 2]
158
+ off = off_t.squeeze(0).cpu().numpy()[:, : h // 2, : w // 2]
159
+
160
+ all_heatmaps.append(hm)
161
+ all_offsets.append(off)
162
+
163
+ return np.mean(all_heatmaps, axis=0), np.mean(all_offsets, axis=0)
164
+
165
+
166
+ def sliding_window_inference(
167
+ model: ImmunogoldCenterNet,
168
+ image: np.ndarray,
169
+ patch_size: int = 512,
170
+ overlap: int = 128,
171
+ device: torch.device = torch.device("cpu"),
172
+ ) -> tuple:
173
+ """
174
+ Full-image inference via sliding window with overlap stitching.
175
+
176
+ Tiles the image into overlapping patches, runs the model on each,
177
+ and stitches heatmaps using max in overlap regions.
178
+
179
+ Args:
180
+ model: trained model
181
+ image: (H, W) uint8 preprocessed image
182
+ patch_size: tile size
183
+ overlap: overlap between tiles
184
+ device: torch device
185
+
186
+ Returns:
187
+ heatmap: (2, H/2, W/2) numpy array
188
+ offsets: (2, H/2, W/2) numpy array
189
+ """
190
+ model.eval()
191
+ h, w = image.shape[:2]
192
+ stride_step = patch_size - overlap
193
+
194
+ # Output dimensions at model stride
195
+ out_h = h // 2
196
+ out_w = w // 2
197
+ out_patch = patch_size // 2
198
+
199
+ heatmap = np.zeros((2, out_h, out_w), dtype=np.float32)
200
+ offsets = np.zeros((2, out_h, out_w), dtype=np.float32)
201
+ count = np.zeros((out_h, out_w), dtype=np.float32)
202
+
203
+ for y0 in range(0, h - patch_size + 1, stride_step):
204
+ for x0 in range(0, w - patch_size + 1, stride_step):
205
+ patch = image[y0 : y0 + patch_size, x0 : x0 + patch_size]
206
+ tensor = (
207
+ torch.from_numpy(patch)
208
+ .float()
209
+ .unsqueeze(0)
210
+ .unsqueeze(0)
211
+ / 255.0
212
+ ).to(device)
213
+
214
+ with torch.no_grad():
215
+ hm, off = model(tensor)
216
+
217
+ hm_np = hm.squeeze(0).cpu().numpy()
218
+ off_np = off.squeeze(0).cpu().numpy()
219
+
220
+ # Output coordinates
221
+ oy0 = y0 // 2
222
+ ox0 = x0 // 2
223
+
224
+ # Max-stitch heatmap, average-stitch offsets
225
+ heatmap[:, oy0 : oy0 + out_patch, ox0 : ox0 + out_patch] = np.maximum(
226
+ heatmap[:, oy0 : oy0 + out_patch, ox0 : ox0 + out_patch],
227
+ hm_np,
228
+ )
229
+ offsets[:, oy0 : oy0 + out_patch, ox0 : ox0 + out_patch] += off_np
230
+ count[oy0 : oy0 + out_patch, ox0 : ox0 + out_patch] += 1
231
+
232
+ # Average offsets where counted
233
+ count = np.maximum(count, 1)
234
+ offsets /= count[np.newaxis, :, :]
235
+
236
+ return heatmap, offsets