AnikS22 commited on
Commit
88a76dc
·
verified ·
1 Parent(s): 3bb0c87

Upload src/model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/model.py +382 -0
src/model.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CenterNet with CEM500K-pretrained ResNet-50 backbone for immunogold detection.
3
+
4
+ Architecture:
5
+ Input: 1ch grayscale, variable size (padded to multiple of 32)
6
+ Encoder: CEM500K ResNet-50 (pretrained), conv1 adapted for 1ch input
7
+ Neck: BiFPN (2 rounds, 128ch)
8
+ Decoder: Transposed conv → stride-2 output
9
+ Heads: Heatmap (2ch sigmoid), Offset (2ch)
10
+ Output: Stride-2 maps → (H/2, W/2) resolution
11
+
12
+ Output stride is 2, NOT 4 or 8. At stride 4, a 6nm bead (4-6px radius)
13
+ collapses to 1px in feature space — insufficient for detection.
14
+ At stride 2, same bead occupies 2-3px, enough for Gaussian peak extraction.
15
+ """
16
+
17
+ import math
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torchvision.models as models
22
+ from typing import List, Optional
23
+
24
+
25
+ # ---------------------------------------------------------------------------
26
+ # BiFPN: Bidirectional Feature Pyramid Network
27
+ # ---------------------------------------------------------------------------
28
+
29
class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise-separable conv block: a depthwise KxK conv followed by a
    pointwise 1x1 conv, then BatchNorm and ReLU. This is the conv unit
    used inside each BiFPN fusion node.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3,
                 stride: int = 1, padding: int = 1):
        super().__init__()
        # One filter per input channel (groups == in_ch): spatial mixing only.
        self.depthwise = nn.Conv2d(
            in_ch, in_ch, kernel_size, stride=stride,
            padding=padding, groups=in_ch, bias=False,
        )
        # 1x1 projection: channel mixing only.
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.depthwise(x)
        out = self.pointwise(out)
        out = self.bn(out)
        return self.act(out)
45
+
46
+
47
class BiFPNFusionNode(nn.Module):
    """
    One weighted-fusion node of a BiFPN.

    Implements EfficientDet-style "fast normalized fusion": each input
    feature map gets a learnable non-negative scalar weight, the weights
    are normalized to approximately sum to one, and the weighted sum is
    refined by a depthwise-separable conv:

        w_i' = relu(w_i) / (sum_j relu(w_j) + eps)
        out  = conv(sum_i w_i' * x_i)
    """

    def __init__(self, channels: int, n_inputs: int = 2, eps: float = 1e-4):
        super().__init__()
        self.eps = eps
        # One learnable scalar per fused input, initialized to equal weight.
        self.weights = nn.Parameter(torch.ones(n_inputs, dtype=torch.float32))
        self.conv = DepthwiseSeparableConv(channels, channels)

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        # Clip weights to non-negative, then normalize (eps avoids div-by-zero).
        raw = F.relu(self.weights)
        normalized = raw / (raw.sum() + self.eps)

        # Weighted sum of the input maps, accumulated in input order.
        fused = normalized[0] * inputs[0]
        for idx in range(1, len(inputs)):
            fused = fused + normalized[idx] * inputs[idx]
        return self.conv(fused)
69
+
70
+
71
class BiFPNLayer(nn.Module):
    """
    A single bidirectional fusion round over a 4-level pyramid.

    Levels (finest → coarsest): P2 (stride 4), P3 (stride 8),
    P4 (stride 16), P5 (stride 32). A top-down pass propagates coarse
    context into fine levels, then a bottom-up pass propagates fine
    detail back up, fusing with the original features at each step.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Top-down nodes: each fuses a level with the upsampled level above it.
        self.td_p4 = BiFPNFusionNode(channels, n_inputs=2)
        self.td_p3 = BiFPNFusionNode(channels, n_inputs=2)
        self.td_p2 = BiFPNFusionNode(channels, n_inputs=2)

        # Bottom-up nodes: fuse original + top-down + downsampled lower level
        # (P5 has no top-down intermediate, so it takes only two inputs).
        self.bu_p3 = BiFPNFusionNode(channels, n_inputs=3)
        self.bu_p4 = BiFPNFusionNode(channels, n_inputs=3)
        self.bu_p5 = BiFPNFusionNode(channels, n_inputs=2)

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Args:
            features: [P2, P3, P4, P5], uniform channel count, spatial
                size halving at each successive level.

        Returns:
            [P2_out, P3_out, P4_out, P5_out] with the same shapes.
        """
        fine, lower_mid, upper_mid, coarse = features

        def grow(src: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
            # Nearest-neighbor upsample `src` to the spatial size of `like`.
            return F.interpolate(src, size=like.shape[2:], mode="nearest")

        def shrink(src: torch.Tensor) -> torch.Tensor:
            # 2x spatial downsample via max pooling.
            return F.max_pool2d(src, kernel_size=2)

        # --- Top-down pass (coarse → fine) ---
        upper_mid_td = self.td_p4([upper_mid, grow(coarse, upper_mid)])
        lower_mid_td = self.td_p3([lower_mid, grow(upper_mid_td, lower_mid)])
        fine_td = self.td_p2([fine, grow(lower_mid_td, fine)])

        # --- Bottom-up pass (fine → coarse) ---
        fine_out = fine_td
        lower_mid_out = self.bu_p3([lower_mid, lower_mid_td, shrink(fine_out)])
        upper_mid_out = self.bu_p4([upper_mid, upper_mid_td, shrink(lower_mid_out)])
        coarse_out = self.bu_p5([coarse, shrink(upper_mid_out)])

        return [fine_out, lower_mid_out, upper_mid_out, coarse_out]
129
+
130
+
131
class BiFPN(nn.Module):
    """
    BiFPN neck: per-level 1x1 lateral projections to a common channel
    width, followed by `num_rounds` bidirectional fusion rounds.
    """

    def __init__(self, in_channels: List[int], out_channels: int = 128,
                 num_rounds: int = 2):
        super().__init__()
        # One 1x1 conv + BN + ReLU per backbone level, unifying channel count.
        projections = []
        for ch in in_channels:
            projections.append(
                nn.Sequential(
                    nn.Conv2d(ch, out_channels, 1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                )
            )
        self.laterals = nn.ModuleList(projections)

        # Stacked fusion rounds, applied sequentially.
        self.rounds = nn.ModuleList(
            BiFPNLayer(out_channels) for _ in range(num_rounds)
        )

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        # Project every level to the shared width, then run each round.
        pyramid = [project(level) for project, level in zip(self.laterals, features)]
        for fusion_round in self.rounds:
            pyramid = fusion_round(pyramid)
        return pyramid
161
+
162
+
163
+ # ---------------------------------------------------------------------------
164
+ # Detection Heads
165
+ # ---------------------------------------------------------------------------
166
+
167
class HeatmapHead(nn.Module):
    """
    Per-class center heatmap head (sigmoid output) at stride-2 resolution.

    The final conv bias is initialized to -log((1 - pi) / pi) with
    pi = 0.01 — the standard focal-loss prior — so early-training
    predictions are biased toward background and training does not start
    with a flood of false positives.
    """

    def __init__(self, in_channels: int = 64, num_classes: int = 2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, num_classes, kernel_size=1)

        # Focal-loss prior initialization (pi = 0.01).
        nn.init.constant_(self.conv2.bias, -math.log((1 - 0.01) / 0.01))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.conv1(x)
        hidden = self.relu(self.bn1(hidden))
        logits = self.conv2(hidden)
        return torch.sigmoid(logits)
184
+
185
+
186
class OffsetHead(nn.Module):
    """Regresses the (dx, dy) sub-pixel offset at every output location."""

    def __init__(self, in_channels: int = 64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # Two output channels: dx and dy (no activation — raw regression).
        self.conv2 = nn.Conv2d(64, 2, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        return self.conv2(hidden)
199
+
200
+
201
+ # ---------------------------------------------------------------------------
202
+ # Full CenterNet Model
203
+ # ---------------------------------------------------------------------------
204
+
205
class ImmunogoldCenterNet(nn.Module):
    """
    CenterNet with CEM500K-pretrained ResNet-50 backbone.

    Detects 6nm and 12nm immunogold particles at stride-2 resolution.

    Pipeline: ResNet-50 encoder (adapted to 1-channel input) → BiFPN neck
    → transposed-conv upsample of the finest BiFPN level → heatmap and
    offset heads. Also provides phase-wise freezing helpers and per-phase
    optimizer parameter groups for staged fine-tuning.
    """

    def __init__(
        self,
        pretrained_path: Optional[str] = None,
        bifpn_channels: int = 128,
        bifpn_rounds: int = 2,
        num_classes: int = 2,
    ):
        # pretrained_path: CEM500K MoCoV2 checkpoint on disk. When None,
        # falls back to torchvision ImageNet weights (torchvision may
        # download them on first use).
        super().__init__()
        self.num_classes = num_classes

        # --- Encoder: ResNet-50 ---
        backbone = models.resnet50(weights=None)
        # Adapt conv1 for 1-channel grayscale input; pretrained 3-channel
        # conv1 weights are mean-pooled over RGB when loaded below.
        backbone.conv1 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False,
        )

        # Load pretrained weights
        if pretrained_path:
            self._load_pretrained(backbone, pretrained_path)
        else:
            # Use ImageNet weights as fallback, adapting conv1
            imagenet_backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            state = imagenet_backbone.state_dict()
            # Mean-pool RGB conv1 weights → grayscale
            state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)
            # strict=False: the replaced conv1 and classifier fc tolerate
            # any residual key mismatches.
            backbone.load_state_dict(state, strict=False)

        # Extract encoder stages (the fc classifier head is discarded).
        self.stem = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
        )
        self.layer1 = backbone.layer1  # 256ch, stride 4
        self.layer2 = backbone.layer2  # 512ch, stride 8
        self.layer3 = backbone.layer3  # 1024ch, stride 16
        self.layer4 = backbone.layer4  # 2048ch, stride 32

        # --- BiFPN Neck ---
        self.bifpn = BiFPN(
            in_channels=[256, 512, 1024, 2048],
            out_channels=bifpn_channels,
            num_rounds=bifpn_rounds,
        )

        # --- Decoder: upsample P2 (stride 4) → stride 2 ---
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(
                bifpn_channels, 64, kernel_size=4, stride=2, padding=1, bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # --- Detection Heads (at stride-2 resolution) ---
        self.heatmap_head = HeatmapHead(64, num_classes)
        self.offset_head = OffsetHead(64)

    def _load_pretrained(self, backbone: nn.Module, path: str) -> None:
        """Load CEM500K MoCoV2 pretrained weights into `backbone` in place."""
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only load checkpoints from a trusted source.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)

        state = {}
        # CEM500K uses MoCo format: keys prefixed with 'module.encoder_q.'
        # Fall back to the raw dict if there is no 'state_dict' wrapper.
        src_state = ckpt.get("state_dict", ckpt)
        for k, v in src_state.items():
            # Strip MoCo prefix — longest prefix first so
            # 'module.encoder_q.' wins over bare 'module.'.
            new_key = k
            for prefix in ["module.encoder_q.", "module.", "encoder_q."]:
                if new_key.startswith(prefix):
                    new_key = new_key[len(prefix):]
                    break
            state[new_key] = v

        # Adapt conv1: mean-pool 3ch RGB → 1ch grayscale
        if "conv1.weight" in state and state["conv1.weight"].shape[1] == 3:
            state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)

        # Load with strict=False (head layers won't match)
        missing, unexpected = backbone.load_state_dict(state, strict=False)
        # Expected: fc.weight, fc.bias will be missing/unexpected
        print(f"CEM500K loaded: {len(state)} keys, "
              f"{len(missing)} missing, {len(unexpected)} unexpected")

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: (B, 1, H, W) grayscale input. H and W are assumed to be
               multiples of 32 (see module docstring) so the pyramid
               levels downsample cleanly.

        Returns:
            heatmap: (B, num_classes, H/2, W/2) sigmoid-activated class heatmaps
            offsets: (B, 2, H/2, W/2) sub-pixel offset predictions
        """
        # Encoder
        x0 = self.stem(x)      # stride 4
        p2 = self.layer1(x0)   # 256ch, stride 4
        p3 = self.layer2(p2)   # 512ch, stride 8
        p4 = self.layer3(p3)   # 1024ch, stride 16
        p5 = self.layer4(p4)   # 2048ch, stride 32

        # BiFPN neck
        features = self.bifpn([p2, p3, p4, p5])

        # Decoder: upsample the finest level (P2, stride 4) to stride 2
        x_up = self.upsample(features[0])

        # Detection heads
        heatmap = self.heatmap_head(x_up)  # (B, num_classes, H/2, W/2)
        offsets = self.offset_head(x_up)   # (B, 2, H/2, W/2)

        return heatmap, offsets

    def freeze_encoder(self) -> None:
        """Freeze entire encoder (Phase 1 training)."""
        for module in [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]:
            for param in module.parameters():
                param.requires_grad = False

    def unfreeze_deep_layers(self) -> None:
        """Unfreeze layer3 and layer4 (Phase 2 training); earlier stages stay frozen."""
        for module in [self.layer3, self.layer4]:
            for param in module.parameters():
                param.requires_grad = True

    def unfreeze_all(self) -> None:
        """Unfreeze all layers (Phase 3 training)."""
        for param in self.parameters():
            param.requires_grad = True

    def get_param_groups(self, phase: int, cfg: dict) -> list:
        """
        Get parameter groups with discriminative learning rates per phase.

        Args:
            phase: 1, 2, or 3. Any value other than 1 or 2 is treated as
                phase 3.
            cfg: training phase config from config.yaml. Required keys by
                phase — 1: "lr"; 2: "lr_layer3", "lr_layer4", "lr_decoder";
                3: additionally "lr_stem", "lr_layer1", "lr_layer2".

        Returns:
            List of param group dicts for optimizer.
        """
        if phase == 1:
            # Only neck + heads trainable
            return [
                {"params": self.bifpn.parameters(), "lr": cfg["lr"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr"]},
            ]
        elif phase == 2:
            # Early stages listed with lr 0 so they stay registered with
            # the optimizer while contributing no updates.
            return [
                {"params": self.stem.parameters(), "lr": 0},
                {"params": self.layer1.parameters(), "lr": 0},
                {"params": self.layer2.parameters(), "lr": 0},
                {"params": self.layer3.parameters(), "lr": cfg["lr_layer3"]},
                {"params": self.layer4.parameters(), "lr": cfg["lr_layer4"]},
                {"params": self.bifpn.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr_decoder"]},
            ]
        else:  # phase 3
            return [
                {"params": self.stem.parameters(), "lr": cfg["lr_stem"]},
                {"params": self.layer1.parameters(), "lr": cfg["lr_layer1"]},
                {"params": self.layer2.parameters(), "lr": cfg["lr_layer2"]},
                {"params": self.layer3.parameters(), "lr": cfg["lr_layer3"]},
                {"params": self.layer4.parameters(), "lr": cfg["lr_layer4"]},
                {"params": self.bifpn.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr_decoder"]},
            ]