File size: 14,907 Bytes
86c24cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a62959
86c24cb
 
 
 
 
 
 
 
 
 
 
 
 
 
2a62959
 
86c24cb
 
 
 
 
2a62959
86c24cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
"""
CenterNet with CEM500K-pretrained ResNet-50 backbone for immunogold detection.

Architecture:
    Input:   1ch grayscale, variable size (padded to multiple of 32)
    Encoder: CEM500K ResNet-50 (pretrained), conv1 adapted for 1ch input
    Neck:    BiFPN (2 rounds, 128ch)
    Decoder: Transposed conv β†’ stride-2 output
    Heads:   Heatmap (2ch sigmoid), Offset (2ch)
    Output:  Stride-2 maps β†’ (H/2, W/2) resolution

Output stride is 2, NOT 4 or 8. At stride 4, a 6nm bead (4-6px radius)
collapses to 1px in feature space β€” insufficient for detection.
At stride 2, same bead occupies 2-3px, enough for Gaussian peak extraction.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from typing import List, Optional


# ---------------------------------------------------------------------------
# BiFPN: Bidirectional Feature Pyramid Network
# ---------------------------------------------------------------------------

class DepthwiseSeparableConv(nn.Module):
    """Depthwise separable convolution as used in BiFPN.

    A per-channel (grouped) spatial conv followed by a 1x1 pointwise conv,
    then BatchNorm and ReLU.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        kernel_size: Spatial kernel size of the depthwise conv.
        padding: Depthwise conv padding. Defaults to ``kernel_size // 2``
            ("same" padding for odd kernels), so spatial dims are preserved
            for any kernel size, not just 3. Passing an explicit int keeps
            the old behavior.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3,
                 stride: int = 1, padding: Optional[int] = None):
        super().__init__()
        # Fix: the old default (padding=1) only preserved spatial size for
        # kernel_size=3; derive it from the kernel unless overridden.
        if padding is None:
            padding = kernel_size // 2
        self.depthwise = nn.Conv2d(
            in_ch, in_ch, kernel_size, stride=stride,
            padding=padding, groups=in_ch, bias=False,
        )
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise conv, pointwise conv, BN, then ReLU."""
        return self.act(self.bn(self.pointwise(self.depthwise(x))))


class BiFPNFusionNode(nn.Module):
    """
    Single BiFPN fusion node with fast normalized weighted fusion.

    Each input tensor gets a learnable non-negative scalar weight; the
    weights are normalized to (approximately) sum to 1 before the weighted
    sum is refined by a depthwise-separable conv:

        w_normalized = relu(w) / (sum(relu(w)) + eps)
        output = conv(sum(w_i * input_i))
    """

    def __init__(self, channels: int, n_inputs: int = 2, eps: float = 1e-4):
        super().__init__()
        self.eps = eps
        # One learnable scalar per input, initialized to equal contribution.
        self.weights = nn.Parameter(torch.ones(n_inputs, dtype=torch.float32))
        self.conv = DepthwiseSeparableConv(channels, channels)

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        # ReLU keeps fusion weights non-negative; eps guards divide-by-zero.
        positive = F.relu(self.weights)
        normalized = positive / (positive.sum() + self.eps)

        # Accumulate the weighted inputs explicitly.
        fused = normalized[0] * inputs[0]
        for idx in range(1, len(inputs)):
            fused = fused + normalized[idx] * inputs[idx]
        return self.conv(fused)


class BiFPNLayer(nn.Module):
    """
    One round of BiFPN: top-down + bottom-up bidirectional fusion.

    Input levels: P2 (stride 4), P3 (stride 8), P4 (stride 16), P5 (stride 32)
    """

    def __init__(self, channels: int):
        super().__init__()
        # Top-down nodes: each merges a level with the upsampled coarser
        # level above it.
        self.td_p4 = BiFPNFusionNode(channels, n_inputs=2)
        self.td_p3 = BiFPNFusionNode(channels, n_inputs=2)
        self.td_p2 = BiFPNFusionNode(channels, n_inputs=2)

        # Bottom-up nodes: merge the original input, the top-down
        # intermediate, and the downsampled finer-level output.
        self.bu_p3 = BiFPNFusionNode(channels, n_inputs=3)  # p3_orig + p3_td + p2_out
        self.bu_p4 = BiFPNFusionNode(channels, n_inputs=3)  # p4_orig + p4_td + p3_out
        self.bu_p5 = BiFPNFusionNode(channels, n_inputs=2)  # p5_orig + p4_out

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Args:
            features: [P2, P3, P4, P5] sharing one channel count, with
                spatial resolution halving at each coarser level.

        Returns:
            [P2_out, P3_out, P4_out, P5_out] fused feature maps.
        """
        p2, p3, p4, p5 = features

        def upsample_to(src: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
            # Nearest-neighbor resize of `src` to `ref`'s spatial dims.
            return F.interpolate(src, size=ref.shape[2:], mode="nearest")

        # Top-down pathway (coarse to fine).
        p4_td = self.td_p4([p4, upsample_to(p5, p4)])
        p3_td = self.td_p3([p3, upsample_to(p4_td, p3)])
        p2_td = self.td_p2([p2, upsample_to(p3_td, p2)])

        # Bottom-up pathway (fine to coarse), 2x max-pool as downsampler.
        p2_out = p2_td
        p3_out = self.bu_p3([p3, p3_td, F.max_pool2d(p2_out, kernel_size=2)])
        p4_out = self.bu_p4([p4, p4_td, F.max_pool2d(p3_out, kernel_size=2)])
        p5_out = self.bu_p5([p5, F.max_pool2d(p4_out, kernel_size=2)])

        return [p2_out, p3_out, p4_out, p5_out]


class BiFPN(nn.Module):
    """Multi-round BiFPN with lateral projections.

    1x1 lateral convs first map each backbone level to a common channel
    width, then ``num_rounds`` BiFPNLayer passes refine the pyramid.
    """

    def __init__(self, in_channels: List[int], out_channels: int = 128,
                 num_rounds: int = 2):
        super().__init__()
        # One 1x1 conv + BN + ReLU projection per pyramid level.
        self.laterals = nn.ModuleList()
        for in_ch in in_channels:
            self.laterals.append(
                nn.Sequential(
                    nn.Conv2d(in_ch, out_channels, 1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
            )

        # Stacked bidirectional fusion rounds.
        self.rounds = nn.ModuleList(
            BiFPNLayer(out_channels) for _ in range(num_rounds)
        )

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """Project every level to the shared width, then run all rounds."""
        levels = [project(feat) for project, feat in zip(self.laterals, features)]

        for fusion_round in self.rounds:
            levels = fusion_round(levels)

        return levels


# ---------------------------------------------------------------------------
# Detection Heads
# ---------------------------------------------------------------------------

class HeatmapHead(nn.Module):
    """Heatmap prediction head at stride-2 resolution.

    3x3 conv + BN + ReLU, then a 1x1 conv to ``num_classes`` channels,
    squashed through a sigmoid so each channel is a per-class heatmap.
    """

    def __init__(self, in_channels: int = 64, num_classes: int = 2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, num_classes, kernel_size=1)

        # Focal-loss prior initialization: bias = -log((1-pi)/pi), pi = 0.01.
        # Starts foreground probability near 1% so the network does not emit
        # a flood of false positives early in training.
        nn.init.constant_(self.conv2.bias, -math.log((1 - 0.01) / 0.01))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.conv1(x)
        hidden = self.relu(self.bn1(hidden))
        logits = self.conv2(hidden)
        return torch.sigmoid(logits)


class OffsetHead(nn.Module):
    """Sub-pixel offset regression head.

    Predicts a raw (dx, dy) pair per spatial location; no activation on
    the output.
    """

    def __init__(self, in_channels: int = 64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 2, kernel_size=1)  # dx, dy

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        return self.conv2(hidden)


# ---------------------------------------------------------------------------
# Full CenterNet Model
# ---------------------------------------------------------------------------

class ImmunogoldCenterNet(nn.Module):
    """
    CenterNet with CEM500K-pretrained ResNet-50 backbone.

    Detects 6nm and 12nm immunogold particles at stride-2 resolution.

    Pipeline: ResNet-50 encoder (adapted for 1ch grayscale) -> BiFPN neck
    -> transposed-conv decoder -> heatmap + offset heads at stride 2.
    """

    def __init__(
        self,
        pretrained_path: Optional[str] = None,
        bifpn_channels: int = 128,
        bifpn_rounds: int = 2,
        num_classes: int = 2,
        imagenet_encoder_fallback: bool = True,
    ):
        """
        Args:
            pretrained_path: Path to a CEM500K MoCoV2 checkpoint for the
                encoder; takes precedence over the ImageNet fallback.
            bifpn_channels: Channel width of the BiFPN neck.
            bifpn_rounds: Number of BiFPN top-down/bottom-up rounds.
            num_classes: Number of heatmap classes (one per bead size).
            imagenet_encoder_fallback: When True and no pretrained_path is
                given, initialize the encoder from ImageNet weights
                (triggers a ~100MB download on first use).
        """
        super().__init__()
        self.num_classes = num_classes

        # --- Encoder: ResNet-50 ---
        backbone = models.resnet50(weights=None)
        # Adapt conv1 for 1-channel grayscale input (replaces the 3ch conv).
        backbone.conv1 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False,
        )

        # Load pretrained weights. Priority: explicit CEM500K checkpoint,
        # then ImageNet fallback, then random init.
        if pretrained_path:
            self._load_pretrained(backbone, pretrained_path)
        elif imagenet_encoder_fallback:
            # Training: better init when CEM500K path is missing (downloads ~100MB).
            imagenet_backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            state = imagenet_backbone.state_dict()
            # Mean-pool RGB conv1 weights -> grayscale so the 1ch conv can
            # reuse the pretrained filters.
            state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)
            backbone.load_state_dict(state, strict=False)
        # else: random encoder init -- use when loading a full checkpoint immediately (Gradio, predict).

        # Extract encoder stages so forward() can tap intermediate features.
        self.stem = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
        )
        self.layer1 = backbone.layer1  # 256ch, stride 4
        self.layer2 = backbone.layer2  # 512ch, stride 8
        self.layer3 = backbone.layer3  # 1024ch, stride 16
        self.layer4 = backbone.layer4  # 2048ch, stride 32

        # --- BiFPN Neck ---
        # Channel counts match the four ResNet stage outputs above.
        self.bifpn = BiFPN(
            in_channels=[256, 512, 1024, 2048],
            out_channels=bifpn_channels,
            num_rounds=bifpn_rounds,
        )

        # --- Decoder: upsample P2 (stride 4) -> stride 2 ---
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(
                bifpn_channels, 64, kernel_size=4, stride=2, padding=1, bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # --- Detection Heads (at stride-2 resolution) ---
        self.heatmap_head = HeatmapHead(64, num_classes)
        self.offset_head = OffsetHead(64)

    def _load_pretrained(self, backbone: nn.Module, path: str):
        """Load CEM500K MoCoV2 pretrained weights.

        Strips MoCo key prefixes, adapts conv1 to 1ch, and loads with
        strict=False so non-matching keys (fc head, momentum-encoder
        leftovers) are simply reported rather than fatal.
        """
        # NOTE(review): weights_only=False runs arbitrary pickle code --
        # only load checkpoints from a trusted source.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)

        state = {}
        # CEM500K uses MoCo format: keys prefixed with 'module.encoder_q.'
        src_state = ckpt.get("state_dict", ckpt)
        for k, v in src_state.items():
            # Strip MoCo prefix; first matching prefix wins.
            new_key = k
            for prefix in ["module.encoder_q.", "module.", "encoder_q."]:
                if new_key.startswith(prefix):
                    new_key = new_key[len(prefix):]
                    break
            state[new_key] = v

        # Adapt conv1: mean-pool 3ch RGB -> 1ch grayscale (skip if the
        # checkpoint is already single-channel).
        if "conv1.weight" in state and state["conv1.weight"].shape[1] == 3:
            state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)

        # Load with strict=False (head layers won't match)
        missing, unexpected = backbone.load_state_dict(state, strict=False)
        # Expected: fc.weight, fc.bias will be missing/unexpected
        print(f"CEM500K loaded: {len(state)} keys, "
              f"{len(missing)} missing, {len(unexpected)} unexpected")

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: (B, 1, H, W) grayscale input; H and W are assumed to be
                padded to a multiple of 32 upstream (see module docstring).

        Returns:
            heatmap: (B, num_classes, H/2, W/2) sigmoid-activated class heatmaps
            offsets: (B, 2, H/2, W/2) sub-pixel offset predictions
        """
        # Encoder: four ResNet stages tapped as a feature pyramid.
        x0 = self.stem(x)        # stride 4
        p2 = self.layer1(x0)     # 256ch, stride 4
        p3 = self.layer2(p2)     # 512ch, stride 8
        p4 = self.layer3(p3)     # 1024ch, stride 16
        p5 = self.layer4(p4)     # 2048ch, stride 32

        # BiFPN neck: bidirectional multi-scale fusion.
        features = self.bifpn([p2, p3, p4, p5])

        # Decoder: upsample the finest level (P2, stride 4) to stride 2.
        x_up = self.upsample(features[0])

        # Detection heads share the stride-2 decoder output.
        heatmap = self.heatmap_head(x_up)   # (B, 2, H/2, W/2)
        offsets = self.offset_head(x_up)    # (B, 2, H/2, W/2)

        return heatmap, offsets

    def freeze_encoder(self):
        """Freeze entire encoder (Phase 1 training)."""
        for module in [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]:
            for param in module.parameters():
                param.requires_grad = False

    def unfreeze_deep_layers(self):
        """Unfreeze layer3 and layer4 (Phase 2 training).

        Stem, layer1, and layer2 stay frozen from freeze_encoder().
        """
        for module in [self.layer3, self.layer4]:
            for param in module.parameters():
                param.requires_grad = True

    def unfreeze_all(self):
        """Unfreeze all layers (Phase 3 training)."""
        for param in self.parameters():
            param.requires_grad = True

    def get_param_groups(self, phase: int, cfg: dict) -> list:
        """
        Get parameter groups with discriminative learning rates per phase.

        Phase 1 trains only the neck + heads; phase 2 adds layer3/layer4
        with their own LRs (frozen stages get lr=0 so the group layout is
        stable across phases); phase 3 trains everything with per-stage LRs.

        Args:
            phase: 1, 2, or 3
            cfg: training phase config from config.yaml; keys used depend
                on the phase (lr, lr_layer1..4, lr_stem, lr_decoder).

        Returns:
            List of param group dicts for optimizer.
        """
        if phase == 1:
            # Only neck + heads trainable
            return [
                {"params": self.bifpn.parameters(), "lr": cfg["lr"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr"]},
            ]
        elif phase == 2:
            # Frozen stages included with lr=0 to keep group indices stable.
            return [
                {"params": self.stem.parameters(), "lr": 0},
                {"params": self.layer1.parameters(), "lr": 0},
                {"params": self.layer2.parameters(), "lr": 0},
                {"params": self.layer3.parameters(), "lr": cfg["lr_layer3"]},
                {"params": self.layer4.parameters(), "lr": cfg["lr_layer4"]},
                {"params": self.bifpn.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr_decoder"]},
            ]
        else:  # phase 3
            return [
                {"params": self.stem.parameters(), "lr": cfg["lr_stem"]},
                {"params": self.layer1.parameters(), "lr": cfg["lr_layer1"]},
                {"params": self.layer2.parameters(), "lr": cfg["lr_layer2"]},
                {"params": self.layer3.parameters(), "lr": cfg["lr_layer3"]},
                {"params": self.layer4.parameters(), "lr": cfg["lr_layer4"]},
                {"params": self.bifpn.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr_decoder"]},
            ]