# MidasMap / src/model.py
# Last commit 2a62959 (verified) by AnikS22:
#   "Deploy MidasMap Gradio app; weights downloaded from model repo at runtime"
"""
CenterNet with CEM500K-pretrained ResNet-50 backbone for immunogold detection.
Architecture:
Input: 1ch grayscale, variable size (padded to multiple of 32)
Encoder: CEM500K ResNet-50 (pretrained), conv1 adapted for 1ch input
Neck: BiFPN (2 rounds, 128ch)
Decoder: Transposed conv → stride-2 output
Heads: Heatmap (2ch sigmoid), Offset (2ch)
Output: Stride-2 maps → (H/2, W/2) resolution
Output stride is 2, NOT 4 or 8. At stride 4, a 6nm bead (4-6px radius)
collapses to ~1px in feature space — insufficient for detection.
At stride 2, the same bead occupies 2-3px, enough for Gaussian peak extraction.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from typing import List, Optional
# ---------------------------------------------------------------------------
# BiFPN: Bidirectional Feature Pyramid Network
# ---------------------------------------------------------------------------
class DepthwiseSeparableConv(nn.Module):
    """Depthwise conv -> pointwise 1x1 conv -> BatchNorm -> ReLU.

    The lightweight conv block used inside every BiFPN fusion node.

    Args:
        in_ch: Input channels; also the group count of the depthwise conv,
            so each input channel gets its own spatial filter.
        out_ch: Output channels produced by the pointwise (1x1) conv.
        kernel_size: Spatial kernel size of the depthwise conv (3 by default).
        stride: Stride of the depthwise conv.
        padding: Zero padding of the depthwise conv.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3,
                 stride: int = 1, padding: int = 1):
        super().__init__()
        # Per-channel spatial filtering; no bias because BatchNorm follows.
        self.depthwise = nn.Conv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_ch,
            bias=False,
        )
        # 1x1 channel mixing that sets the output width.
        self.pointwise = nn.Conv2d(
            in_channels=in_ch, out_channels=out_ch, kernel_size=1, bias=False,
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise, pointwise, BatchNorm and ReLU in sequence."""
        out = self.depthwise(x)
        out = self.pointwise(out)
        out = self.bn(out)
        return self.act(out)
class BiFPNFusionNode(nn.Module):
    """
    Single BiFPN fusion node with fast normalized weighted fusion.
    w_normalized = relu(w) / (sum(relu(w)) + eps)
    output = conv(sum(w_i * input_i))

    Args:
        channels: Channel count shared by every input and by the output.
        n_inputs: Number of feature maps this node fuses; one learnable
            scalar weight is kept per input.
        eps: Added to the weight sum so normalization never divides by zero
            (e.g. when every weight has been clamped to 0 by the ReLU).
    """

    def __init__(self, channels: int, n_inputs: int = 2, eps: float = 1e-4):
        super().__init__()
        self.eps = eps
        # Learnable fusion weights, initialized to equal contribution.
        self.weights = nn.Parameter(torch.ones(n_inputs, dtype=torch.float32))
        self.conv = DepthwiseSeparableConv(channels, channels)

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Fuse same-shaped ``inputs`` with normalized weights, then refine.

        Args:
            inputs: Exactly ``n_inputs`` tensors of identical shape.

        Returns:
            ``self.conv`` applied to the weighted sum of ``inputs``.

        Raises:
            ValueError: If ``len(inputs)`` differs from the number of fusion
                weights. Without this check ``zip`` would silently drop the
                unmatched tensors or weights.
        """
        n_expected = self.weights.numel()
        if len(inputs) != n_expected:
            raise ValueError(
                f"BiFPNFusionNode expects {n_expected} inputs, got {len(inputs)}"
            )
        # Fast normalized fusion: ReLU keeps weights non-negative, and dividing
        # by their sum (+eps) keeps the fused scale close to a convex combination.
        w = F.relu(self.weights)
        w_norm = w / (w.sum() + self.eps)
        fused = sum(w_i * inp for w_i, inp in zip(w_norm, inputs))
        return self.conv(fused)
class BiFPNLayer(nn.Module):
    """
    One round of BiFPN: top-down + bottom-up bidirectional fusion.
    Input levels: P2 (stride 4), P3 (stride 8), P4 (stride 16), P5 (stride 32)

    All levels must already share one channel count. Spatial sizes need not be
    exact powers of two apart:
      * top-down: each coarser map is nearest-upsampled to the finer map's
        exact size;
      * bottom-up: each finer map is 2x2 max-pooled with ``ceil_mode=True``,
        producing ceil(n / 2). That matches the rounding of the stride-2
        ResNet stages that build the pyramid (3x3/stride-2/pad-1 convs).
        For even sizes (input padded to a multiple of 32) this is identical
        to a plain 2x2 pool. For odd sizes a floor-mode pool would come out
        one pixel short, and the fusion sum would fail on mismatched shapes.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Top-down fusion nodes (P5 → P4_td, P4_td + P3 → P3_td, P3_td + P2 → P2_td)
        self.td_p4 = BiFPNFusionNode(channels, n_inputs=2)
        self.td_p3 = BiFPNFusionNode(channels, n_inputs=2)
        self.td_p2 = BiFPNFusionNode(channels, n_inputs=2)
        # Bottom-up fusion nodes (combine top-down outputs with original)
        self.bu_p3 = BiFPNFusionNode(channels, n_inputs=3)  # p3_orig + p3_td + p2_out
        self.bu_p4 = BiFPNFusionNode(channels, n_inputs=3)  # p4_orig + p4_td + p3_out
        self.bu_p5 = BiFPNFusionNode(channels, n_inputs=2)  # p5_orig + p4_out

    @staticmethod
    def _downsample(x: torch.Tensor) -> torch.Tensor:
        """2x2 / stride-2 max pool that rounds odd sizes up: n -> ceil(n / 2)."""
        return F.max_pool2d(x, kernel_size=2, ceil_mode=True)

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Args:
            features: [P2, P3, P4, P5] at channels ch, with decreasing spatial dims
        Returns:
            [P2_out, P3_out, P4_out, P5_out], each with the same shape as the
            corresponding input level.
        """
        p2, p3, p4, p5 = features
        # --- Top-down pathway ---
        # P5 → upscale → fuse with P4
        p5_up = F.interpolate(p5, size=p4.shape[2:], mode="nearest")
        p4_td = self.td_p4([p4, p5_up])
        # P4_td → upscale → fuse with P3
        p4_td_up = F.interpolate(p4_td, size=p3.shape[2:], mode="nearest")
        p3_td = self.td_p3([p3, p4_td_up])
        # P3_td → upscale → fuse with P2
        p3_td_up = F.interpolate(p3_td, size=p2.shape[2:], mode="nearest")
        p2_td = self.td_p2([p2, p3_td_up])
        # --- Bottom-up pathway ---
        p2_out = p2_td
        # P2_out → downsample → fuse with P3_td and P3_orig
        p2_down = self._downsample(p2_out)
        p3_out = self.bu_p3([p3, p3_td, p2_down])
        # P3_out → downsample → fuse with P4_td and P4_orig
        p3_down = self._downsample(p3_out)
        p4_out = self.bu_p4([p4, p4_td, p3_down])
        # P4_out → downsample → fuse with P5_orig
        p4_down = self._downsample(p4_out)
        p5_out = self.bu_p5([p5, p4_down])
        return [p2_out, p3_out, p4_out, p5_out]
class BiFPN(nn.Module):
    """Stack of BiFPN rounds preceded by per-level 1x1 channel projections.

    Args:
        in_channels: Channel count of each incoming pyramid level, finest first.
        out_channels: Common channel width after projection (and inside every
            BiFPN round).
        num_rounds: Number of ``BiFPNLayer`` rounds applied in sequence.
    """

    def __init__(self, in_channels: List[int], out_channels: int = 128,
                 num_rounds: int = 2):
        super().__init__()

        def _lateral(c_in: int) -> nn.Sequential:
            # 1x1 conv + BN + ReLU so every level ends up with out_channels.
            return nn.Sequential(
                nn.Conv2d(c_in, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        self.laterals = nn.ModuleList(map(_lateral, in_channels))
        self.rounds = nn.ModuleList(
            BiFPNLayer(out_channels) for _ in range(num_rounds)
        )

    def forward(self, features: List[torch.Tensor]) -> List[torch.Tensor]:
        """Project every level to a common width, then refine round by round."""
        levels = [proj(feat) for proj, feat in zip(self.laterals, features)]
        for layer in self.rounds:
            levels = layer(levels)
        return levels
# ---------------------------------------------------------------------------
# Detection Heads
# ---------------------------------------------------------------------------
class HeatmapHead(nn.Module):
    """Center-heatmap head: 3x3 conv-BN-ReLU (64 ch), then 1x1 conv + sigmoid.

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Number of heatmap output channels.

    The final conv bias is initialized to -log((1 - pi) / pi) with pi = 0.01,
    the usual focal-loss prior. It biases initial outputs toward ~0.01 so that
    early training does not produce a flood of false positives.
    """

    def __init__(self, in_channels: int = 64, num_classes: int = 2):
        super().__init__()
        hidden = 64
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(hidden, num_classes, kernel_size=1)
        prior = 0.01
        nn.init.constant_(self.conv2.bias, -math.log((1 - prior) / prior))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class probabilities with the same spatial size as ``x``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        logits = self.conv2(hidden)
        return logits.sigmoid()
class OffsetHead(nn.Module):
    """Sub-pixel offset regression head with 2 output channels (dx, dy).

    A 3x3 conv-BN-ReLU trunk with 64 hidden channels feeds a linear 1x1 conv.
    No activation is applied, so the offsets are unbounded.

    Args:
        in_channels: Channels of the incoming feature map.
    """

    def __init__(self, in_channels: int = 64):
        super().__init__()
        width = 64
        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width, 2, kernel_size=1)  # dx, dy

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (dx, dy) offsets at the spatial size of ``x``."""
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.relu(h)
        return self.conv2(h)
# ---------------------------------------------------------------------------
# Full CenterNet Model
# ---------------------------------------------------------------------------
class ImmunogoldCenterNet(nn.Module):
    """
    CenterNet with CEM500K-pretrained ResNet-50 backbone.
    Detects 6nm and 12nm immunogold particles at stride-2 resolution.

    Data flow (see ``forward``):
        (B, 1, H, W) -> stem + layer1..layer4 (P2..P5, strides 4..32)
        -> BiFPN neck -> 2x transposed-conv upsample of the finest level
        -> heatmap head (``num_classes`` sigmoid channels) + offset head (2 ch).

    Attribute names (``stem``, ``layer1``..``layer4``, ``bifpn``, ``upsample``,
    ``heatmap_head``, ``offset_head``) are the ``state_dict`` key prefixes.
    Renaming them breaks loading of existing checkpoints.
    """
    def __init__(
        self,
        pretrained_path: Optional[str] = None,
        bifpn_channels: int = 128,
        bifpn_rounds: int = 2,
        num_classes: int = 2,
        imagenet_encoder_fallback: bool = True,
    ) -> None:
        """
        Build the network and initialize the encoder.

        Args:
            pretrained_path: Checkpoint loaded into the encoder via
                ``_load_pretrained``. Any truthy value takes precedence over
                ``imagenet_encoder_fallback``; ``None`` or ``""`` skips it.
            bifpn_channels: Channel width of the BiFPN neck.
            bifpn_rounds: Number of BiFPN rounds.
            num_classes: Number of heatmap output channels.
            imagenet_encoder_fallback: Only consulted when ``pretrained_path``
                is falsy. If True, the encoder is initialized from torchvision's
                ImageNet ResNet-50 weights (torchvision fetches them if not
                cached). If False, the encoder keeps its random initialization.
        """
        super().__init__()
        self.num_classes = num_classes
        # --- Encoder: ResNet-50 ---
        # Built without weights; the initialization is chosen below.
        backbone = models.resnet50(weights=None)
        # Adapt conv1 for 1-channel grayscale input (7x7, stride 2, padding 3, 64 filters)
        backbone.conv1 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False,
        )
        # Load pretrained weights
        if pretrained_path:
            self._load_pretrained(backbone, pretrained_path)
        elif imagenet_encoder_fallback:
            # Training: better init when CEM500K path is missing (downloads ~100MB).
            imagenet_backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
            state = imagenet_backbone.state_dict()
            # Mean-pool RGB conv1 weights → grayscale (averaged over the input-channel dim)
            state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)
            # strict=False tolerates missing/unexpected keys silently;
            # tensor shape mismatches would still raise.
            backbone.load_state_dict(state, strict=False)
        # else: random encoder init — use when loading a full checkpoint immediately (Gradio, predict).
        # Extract encoder stages. backbone.avgpool and backbone.fc are not used.
        # The stem (conv1 stride 2 + maxpool) brings the resolution to stride 4.
        self.stem = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
        )
        self.layer1 = backbone.layer1  # 256ch, stride 4
        self.layer2 = backbone.layer2  # 512ch, stride 8
        self.layer3 = backbone.layer3  # 1024ch, stride 16
        self.layer4 = backbone.layer4  # 2048ch, stride 32
        # --- BiFPN Neck ---
        self.bifpn = BiFPN(
            in_channels=[256, 512, 1024, 2048],
            out_channels=bifpn_channels,
            num_rounds=bifpn_rounds,
        )
        # --- Decoder: upsample P2 (stride 4) → stride 2 ---
        # ConvTranspose2d(k=4, s=2, p=1) maps a size n to exactly 2n.
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(
                bifpn_channels, 64, kernel_size=4, stride=2, padding=1, bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # --- Detection Heads (at stride-2 resolution) ---
        self.heatmap_head = HeatmapHead(64, num_classes)
        self.offset_head = OffsetHead(64)
    def _load_pretrained(self, backbone: nn.Module, path: str) -> None:
        """
        Load CEM500K MoCoV2 pretrained weights into ``backbone`` in place.

        Key handling:
          * At most one prefix is stripped from each checkpoint key: the first
            of "module.encoder_q.", "module.", "encoder_q." that matches.
          * A 3-channel conv1 kernel is averaged down to 1 channel.
          * The result is loaded with ``strict=False``, and only the counts of
            missing/unexpected keys are printed.

        Args:
            backbone: The torchvision ResNet-50 built in ``__init__``.
            path: Checkpoint file. Either a dict holding a "state_dict" entry,
                or a state dict itself; an object without ``.get`` raises
                AttributeError.
        """
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only load checkpoint files from trusted sources.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        state = {}
        # CEM500K uses MoCo format: keys prefixed with 'module.encoder_q.'
        src_state = ckpt.get("state_dict", ckpt)
        for k, v in src_state.items():
            # Strip MoCo prefix (first match in list order only).
            # Keys of other submodules (e.g. "module.encoder_k.*" becomes
            # "encoder_k.*") do not match backbone names and are reported as
            # unexpected by load_state_dict below.
            new_key = k
            for prefix in ["module.encoder_q.", "module.", "encoder_q."]:
                if new_key.startswith(prefix):
                    new_key = new_key[len(prefix):]
                    break
            state[new_key] = v
        # Adapt conv1: mean-pool 3ch RGB → 1ch grayscale
        if "conv1.weight" in state and state["conv1.weight"].shape[1] == 3:
            state["conv1.weight"] = state["conv1.weight"].mean(dim=1, keepdim=True)
        # Load with strict=False (head layers won't match); shape mismatches still raise
        missing, unexpected = backbone.load_state_dict(state, strict=False)
        # Expected: fc.weight, fc.bias will be missing/unexpected
        # len(state) counts every checkpoint key, not only those actually loaded.
        print(f"CEM500K loaded: {len(state)} keys, "
              f"{len(missing)} missing, {len(unexpected)} unexpected")
    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: (B, 1, H, W) grayscale input
        Returns:
            heatmap: (B, num_classes, H/2, W/2) sigmoid-activated class heatmaps
                (num_classes defaults to 2)
            offsets: (B, 2, H/2, W/2) sub-pixel offset predictions
            The H/2, W/2 sizes hold for H and W that are multiples of 32, which
            is the padding the module docstring describes.
        """
        # Encoder
        x0 = self.stem(x)  # stride 4
        p2 = self.layer1(x0)  # 256ch, stride 4
        p3 = self.layer2(p2)  # 512ch, stride 8
        p4 = self.layer3(p3)  # 1024ch, stride 16
        p5 = self.layer4(p4)  # 2048ch, stride 32
        # BiFPN neck
        features = self.bifpn([p2, p3, p4, p5])
        # Decoder: upsample P2 to stride 2 (only the finest BiFPN level is used)
        x_up = self.upsample(features[0])
        # Detection heads
        heatmap = self.heatmap_head(x_up)  # (B, num_classes, H/2, W/2)
        offsets = self.offset_head(x_up)  # (B, 2, H/2, W/2)
        return heatmap, offsets
    def freeze_encoder(self) -> None:
        """Freeze entire encoder (Phase 1 training).

        Only ``requires_grad`` is changed.
        NOTE(review): the encoder's BatchNorm layers still update their running
        statistics whenever the model is in train mode; call ``.eval()`` on
        them if fully frozen statistics are intended.
        """
        for module in [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]:
            for param in module.parameters():
                param.requires_grad = False
    def unfreeze_deep_layers(self) -> None:
        """Unfreeze layer3 and layer4 (Phase 2 training); stem, layer1, layer2 are untouched."""
        for module in [self.layer3, self.layer4]:
            for param in module.parameters():
                param.requires_grad = True
    def unfreeze_all(self) -> None:
        """Unfreeze all layers (Phase 3 training): every parameter of the model."""
        for param in self.parameters():
            param.requires_grad = True
    def get_param_groups(self, phase: int, cfg: dict) -> list:
        """
        Get parameter groups with discriminative learning rates per phase.
        Args:
            phase: 1, 2, or 3. Any value other than 1 or 2 takes the phase-3
                branch; it is not validated.
            cfg: training phase config from config.yaml. Keys read:
                phase 1: "lr";
                phase 2: "lr_layer3", "lr_layer4", "lr_decoder";
                phase 3: "lr_stem", "lr_layer1", "lr_layer2", "lr_layer3",
                         "lr_layer4", "lr_decoder".
        Returns:
            List of param group dicts for optimizer. Each "params" value is the
            generator returned by ``Module.parameters()``.
        """
        if phase == 1:
            # Only neck + heads trainable; encoder parameters are left out entirely
            return [
                {"params": self.bifpn.parameters(), "lr": cfg["lr"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr"]},
            ]
        elif phase == 2:
            # stem/layer1/layer2 are listed with lr 0, so an optimizer step leaves them fixed
            return [
                {"params": self.stem.parameters(), "lr": 0},
                {"params": self.layer1.parameters(), "lr": 0},
                {"params": self.layer2.parameters(), "lr": 0},
                {"params": self.layer3.parameters(), "lr": cfg["lr_layer3"]},
                {"params": self.layer4.parameters(), "lr": cfg["lr_layer4"]},
                {"params": self.bifpn.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr_decoder"]},
            ]
        else:  # phase 3 (and any other value)
            return [
                {"params": self.stem.parameters(), "lr": cfg["lr_stem"]},
                {"params": self.layer1.parameters(), "lr": cfg["lr_layer1"]},
                {"params": self.layer2.parameters(), "lr": cfg["lr_layer2"]},
                {"params": self.layer3.parameters(), "lr": cfg["lr_layer3"]},
                {"params": self.layer4.parameters(), "lr": cfg["lr_layer4"]},
                {"params": self.bifpn.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.upsample.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.heatmap_head.parameters(), "lr": cfg["lr_decoder"]},
                {"params": self.offset_head.parameters(), "lr": cfg["lr_decoder"]},
            ]