Upload src/model.py with huggingface_hub
Browse files- src/model.py +382 -0
src/model.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CenterNet with CEM500K-pretrained ResNet-50 backbone for immunogold detection.
|
| 3 |
+
|
| 4 |
+
Architecture:
|
| 5 |
+
Input: 1ch grayscale, variable size (padded to multiple of 32)
|
| 6 |
+
Encoder: CEM500K ResNet-50 (pretrained), conv1 adapted for 1ch input
|
| 7 |
+
Neck: BiFPN (2 rounds, 128ch)
|
| 8 |
+
Decoder: Transposed conv → stride-2 output
|
| 9 |
+
Heads: Heatmap (2ch sigmoid), Offset (2ch)
|
| 10 |
+
Output: Stride-2 maps → (H/2, W/2) resolution
|
| 11 |
+
|
| 12 |
+
Output stride is 2, NOT 4 or 8. At stride 4, a 6nm bead (4-6px radius)
|
| 13 |
+
collapses to 1px in feature space — insufficient for detection.
|
| 14 |
+
At stride 2, same bead occupies 2-3px, enough for Gaussian peak extraction.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import torchvision.models as models
|
| 22 |
+
from typing import List, Optional
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
# BiFPN: Bidirectional Feature Pyramid Network
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
|
| 29 |
+
class DepthwiseSeparableConv(nn.Module):
    """Depthwise conv → 1x1 pointwise conv → BN → ReLU (BiFPN building block)."""

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3,
                 stride: int = 1, padding: int = 1):
        super().__init__()
        # Per-channel spatial filtering (groups=in_ch), no channel mixing here.
        self.depthwise = nn.Conv2d(
            in_ch, in_ch, kernel_size, stride=stride,
            padding=padding, groups=in_ch, bias=False,
        )
        # 1x1 conv does the cross-channel mixing; bias folded into BN.
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise, pointwise, batch-norm, then ReLU."""
        out = self.depthwise(x)
        out = self.pointwise(out)
        out = self.bn(out)
        return self.act(out)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class BiFPNFusionNode(nn.Module):
    """
    Fuses several same-shape feature maps with fast normalized weighted fusion.

    Weights are kept non-negative via ReLU and normalized to sum to ~1:
        w_i = relu(w_i) / (sum_j relu(w_j) + eps)
    The weighted sum is then refined by a depthwise separable conv.
    """

    def __init__(self, channels: int, n_inputs: int = 2, eps: float = 1e-4):
        super().__init__()
        self.eps = eps
        # One learnable scalar per input, starting at equal contribution.
        self.weights = nn.Parameter(torch.ones(n_inputs, dtype=torch.float32))
        self.conv = DepthwiseSeparableConv(channels, channels)

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Weighted-sum `inputs` (all same shape) and refine with a conv."""
        pos = F.relu(self.weights)
        norm = pos / (pos.sum() + self.eps)

        fused = norm[0] * inputs[0]
        for coeff, feat in zip(norm[1:], inputs[1:]):
            fused = fused + coeff * feat
        return self.conv(fused)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class BiFPNLayer(nn.Module):
    """
    One round of BiFPN: a top-down pass followed by a bottom-up pass.

    Levels: P2 (stride 4), P3 (stride 8), P4 (stride 16), P5 (stride 32).
    All levels share the same channel count; each fusion node combines
    same-resolution maps with learned, normalized weights.
    """

    def __init__(self, channels: int):
        """
        Args:
            channels: channel width shared by all four pyramid levels.
        """
        super().__init__()
        # Top-down fusion nodes (P5 → P4_td, P4_td+P3 → P3_td, P3_td+P2 → P2_td)
        self.td_p4 = BiFPNFusionNode(channels, n_inputs=2)
        self.td_p3 = BiFPNFusionNode(channels, n_inputs=2)
        self.td_p2 = BiFPNFusionNode(channels, n_inputs=2)

        # Bottom-up fusion nodes (combine top-down outputs with originals)
        self.bu_p3 = BiFPNFusionNode(channels, n_inputs=3)  # p3_orig + p3_td + p2_out
        self.bu_p4 = BiFPNFusionNode(channels, n_inputs=3)  # p4_orig + p4_td + p3_out
        self.bu_p5 = BiFPNFusionNode(channels, n_inputs=2)  # p5_orig + p4_out

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Args:
            features: [P2, P3, P4, P5] with identical channel counts and
                decreasing spatial dims.

        Returns:
            [P2_out, P3_out, P4_out, P5_out] at the input resolutions.
        """
        p2, p3, p4, p5 = features

        # --- Top-down pathway: upsample coarse maps to the finer grid ---
        p5_up = F.interpolate(p5, size=p4.shape[2:], mode="nearest")
        p4_td = self.td_p4([p4, p5_up])

        p4_td_up = F.interpolate(p4_td, size=p3.shape[2:], mode="nearest")
        p3_td = self.td_p3([p3, p4_td_up])

        p3_td_up = F.interpolate(p3_td, size=p2.shape[2:], mode="nearest")
        p2_td = self.td_p2([p2, p3_td_up])

        # --- Bottom-up pathway ---
        p2_out = p2_td

        # Downsample with adaptive max-pool to the *exact* size of the next
        # level. When dims halve evenly this is identical to
        # max_pool2d(kernel_size=2) (non-overlapping 2x2 windows), but unlike
        # a fixed 2x2 pool it also produces matching shapes when a level has
        # odd spatial dims, instead of a broadcast error in the fusion sum.
        p2_down = F.adaptive_max_pool2d(p2_out, p3.shape[2:])
        p3_out = self.bu_p3([p3, p3_td, p2_down])

        p3_down = F.adaptive_max_pool2d(p3_out, p4.shape[2:])
        p4_out = self.bu_p4([p4, p4_td, p3_down])

        p4_down = F.adaptive_max_pool2d(p4_out, p5.shape[2:])
        p5_out = self.bu_p5([p5, p4_down])

        return [p2_out, p3_out, p4_out, p5_out]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class BiFPN(nn.Module):
    """Stack of BiFPN rounds preceded by 1x1 lateral projections."""

    def __init__(self, in_channels: List[int], out_channels: int = 128,
                 num_rounds: int = 2):
        super().__init__()
        # Project each backbone level to a common channel width before fusion.
        self.laterals = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_ch, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
            for in_ch in in_channels
        ])

        # Repeated bidirectional fusion rounds, all at the same width.
        self.rounds = nn.ModuleList([
            BiFPNLayer(out_channels) for _ in range(num_rounds)
        ])

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """Project `features` to a uniform width, then apply every round."""
        fused = [proj(level) for proj, level in zip(self.laterals, features)]
        for layer in self.rounds:
            fused = layer(fused)
        return fused
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ---------------------------------------------------------------------------
|
| 164 |
+
# Detection Heads
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
|
| 167 |
+
class HeatmapHead(nn.Module):
    """Per-class center-heatmap head operating on the stride-2 feature map."""

    def __init__(self, in_channels: int = 64, num_classes: int = 2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, num_classes, kernel_size=1)

        # Focal-loss prior initialization: bias the final logits so the
        # initial foreground probability is ~0.01, preventing a flood of
        # false positives early in training. bias = -log((1 - pi) / pi).
        prior = 0.01
        nn.init.constant_(self.conv2.bias, -math.log((1 - prior) / prior))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return (B, num_classes, H, W) sigmoid-activated heatmaps."""
        h = self.conv1(x)
        h = self.relu(self.bn1(h))
        return torch.sigmoid(self.conv2(h))
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class OffsetHead(nn.Module):
    """Regresses (dx, dy) sub-pixel center offsets at stride-2 resolution."""

    def __init__(self, in_channels: int = 64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 2, kernel_size=1)  # dx, dy

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return (B, 2, H, W) raw offset predictions (no activation)."""
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.relu(h)
        return self.conv2(h)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# ---------------------------------------------------------------------------
|
| 202 |
+
# Full CenterNet Model
|
| 203 |
+
# ---------------------------------------------------------------------------
|
| 204 |
+
|
| 205 |
+
class ImmunogoldCenterNet(nn.Module):
    """
    CenterNet with CEM500K-pretrained ResNet-50 backbone.

    Detects 6nm and 12nm immunogold particles at stride-2 resolution.

    Pipeline: ResNet-50 encoder (conv1 adapted to 1-channel grayscale) →
    BiFPN neck → transposed-conv decoder (stride 4 → stride 2) →
    heatmap + offset heads. See the module docstring for why the output
    stride is 2 rather than 4.
    """

    def __init__(
        self,
        pretrained_path: Optional[str] = None,
        bifpn_channels: int = 128,
        bifpn_rounds: int = 2,
        num_classes: int = 2,
    ):
        """
        Args:
            pretrained_path: path to a CEM500K MoCoV2 checkpoint. If None or
                empty, ImageNet weights are used instead.
                NOTE(review): the ImageNet fallback downloads weights via
                torchvision, so constructing the model may hit the network.
            bifpn_channels: channel width of the BiFPN neck.
            bifpn_rounds: number of BiFPN fusion rounds.
            num_classes: number of particle classes (heatmap output channels).
        """
        super().__init__()
        self.num_classes = num_classes

        # --- Encoder: ResNet-50 ---
        # Built without weights; pretrained weights are loaded below so the
        # replaced conv1 can be adapted first.
        backbone = models.resnet50(weights=None)
        # Adapt conv1 for 1-channel grayscale input
        backbone.conv1 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False,
        )

        # Load pretrained weights
        if pretrained_path:
            self._load_pretrained(backbone, pretrained_path)
        else:
            # Use ImageNet weights as fallback, adapting conv1
            imagenet_backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            state = imagenet_backbone.state_dict()
            # Mean-pool RGB conv1 weights → grayscale
            state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)
            # strict=False: fc.* of the classifier head is irrelevant here.
            backbone.load_state_dict(state, strict=False)

        # Extract encoder stages (the fc/avgpool head is intentionally dropped)
        self.stem = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
        )
        self.layer1 = backbone.layer1  # 256ch, stride 4
        self.layer2 = backbone.layer2  # 512ch, stride 8
        self.layer3 = backbone.layer3  # 1024ch, stride 16
        self.layer4 = backbone.layer4  # 2048ch, stride 32

        # --- BiFPN Neck ---
        # in_channels must match the ResNet-50 stage widths listed above.
        self.bifpn = BiFPN(
            in_channels=[256, 512, 1024, 2048],
            out_channels=bifpn_channels,
            num_rounds=bifpn_rounds,
        )

        # --- Decoder: upsample P2 (stride 4) → stride 2 ---
        # kernel=4, stride=2, padding=1 gives an exact 2x upsampling.
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(
                bifpn_channels, 64, kernel_size=4, stride=2, padding=1, bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # --- Detection Heads (at stride-2 resolution) ---
        self.heatmap_head = HeatmapHead(64, num_classes)
        self.offset_head = OffsetHead(64)

    def _load_pretrained(self, backbone: nn.Module, path: str):
        """
        Load CEM500K MoCoV2 pretrained weights into `backbone`.

        Strips MoCo key prefixes, mean-pools the 3-channel conv1 kernel to
        1 channel, then loads with strict=False so classifier-head and
        non-encoder keys are ignored.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects and
        # can execute code on load — only pass trusted checkpoint files.
        # Presumably required here for the legacy MoCo checkpoint format.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)

        state = {}
        # CEM500K uses MoCo format: keys prefixed with 'module.encoder_q.'
        src_state = ckpt.get("state_dict", ckpt)
        for k, v in src_state.items():
            # Strip MoCo prefix
            # Longest prefix is tried first; break after the first match so
            # 'module.encoder_q.' is not double-stripped by 'module.'.
            new_key = k
            for prefix in ["module.encoder_q.", "module.", "encoder_q."]:
                if new_key.startswith(prefix):
                    new_key = new_key[len(prefix):]
                    break
            state[new_key] = v

        # Adapt conv1: mean-pool 3ch RGB → 1ch grayscale
        if "conv1.weight" in state and state["conv1.weight"].shape[1] == 3:
            state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)

        # Load with strict=False (head layers won't match)
        missing, unexpected = backbone.load_state_dict(state, strict=False)
        # Expected: fc.weight, fc.bias will be missing/unexpected
        print(f"CEM500K loaded: {len(state)} keys, "
              f"{len(missing)} missing, {len(unexpected)} unexpected")

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: (B, 1, H, W) grayscale input. H and W are presumably padded to
               a multiple of 32 upstream (see module docstring) — the BiFPN
               downsampling assumes evenly divisible spatial dims.

        Returns:
            heatmap: (B, num_classes, H/2, W/2) sigmoid-activated class heatmaps
            offsets: (B, 2, H/2, W/2) sub-pixel offset predictions
        """
        # Encoder
        x0 = self.stem(x)      # stride 4
        p2 = self.layer1(x0)   # 256ch, stride 4
        p3 = self.layer2(p2)   # 512ch, stride 8
        p4 = self.layer3(p3)   # 1024ch, stride 16
        p5 = self.layer4(p4)   # 2048ch, stride 32

        # BiFPN neck
        features = self.bifpn([p2, p3, p4, p5])

        # Decoder: upsample P2 to stride 2 (only the finest level feeds the heads)
        x_up = self.upsample(features[0])

        # Detection heads
        heatmap = self.heatmap_head(x_up)  # (B, num_classes, H/2, W/2)
        offsets = self.offset_head(x_up)   # (B, 2, H/2, W/2)

        return heatmap, offsets

    def freeze_encoder(self):
        """Freeze entire encoder (Phase 1 training): neck/heads stay trainable."""
        for module in [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]:
            for param in module.parameters():
                param.requires_grad = False

    def unfreeze_deep_layers(self):
        """Unfreeze layer3 and layer4 (Phase 2 training); stem/layer1/layer2 stay as-is."""
        for module in [self.layer3, self.layer4]:
            for param in module.parameters():
                param.requires_grad = True

    def unfreeze_all(self):
        """Unfreeze all layers (Phase 3 training)."""
        for param in self.parameters():
            param.requires_grad = True

    def get_param_groups(self, phase: int, cfg: dict) -> list:
        """
        Get parameter groups with discriminative learning rates per phase.

        Args:
            phase: 1, 2, or 3. Any value other than 1 or 2 is treated as 3.
            cfg: training phase config from config.yaml. Required keys:
                phase 1: "lr";
                phase 2: "lr_layer3", "lr_layer4", "lr_decoder";
                phase 3: "lr_stem", "lr_layer1", "lr_layer2", "lr_layer3",
                         "lr_layer4", "lr_decoder".

        Returns:
            List of param group dicts for optimizer.
        """
        if phase == 1:
            # Only neck + heads trainable
            return [
                {"params": self.bifpn.parameters(), "lr": cfg["lr"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr"]},
            ]
        elif phase == 2:
            # Shallow encoder stages are included at lr 0 (kept in the
            # optimizer but effectively frozen).
            return [
                {"params": self.stem.parameters(), "lr": 0},
                {"params": self.layer1.parameters(), "lr": 0},
                {"params": self.layer2.parameters(), "lr": 0},
                {"params": self.layer3.parameters(), "lr": cfg["lr_layer3"]},
                {"params": self.layer4.parameters(), "lr": cfg["lr_layer4"]},
                {"params": self.bifpn.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr_decoder"]},
            ]
        else:  # phase 3
            return [
                {"params": self.stem.parameters(), "lr": cfg["lr_stem"]},
                {"params": self.layer1.parameters(), "lr": cfg["lr_layer1"]},
                {"params": self.layer2.parameters(), "lr": cfg["lr_layer2"]},
                {"params": self.layer3.parameters(), "lr": cfg["lr_layer3"]},
                {"params": self.layer4.parameters(), "lr": cfg["lr_layer4"]},
                {"params": self.bifpn.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr_decoder"]},
            ]
|