# bubble-distill / model.py
# Commit 88ec11f (verified): "Add model architecture" (author: callumtilbury)
"""
Tiny U-Net for microbubble segmentation (student model).
Architecture: Depthwise-separable U-Net with 4-channel output.
Inspired by PicoSAM2 (arxiv:2506.18807) — optimized for speed on narrow domains.
Output channels (raw network outputs; predict() applies the activations noted):
[0] dY — vertical gradient flow (Cellpose-compatible)
[1] dX — horizontal gradient flow (Cellpose-compatible)
[2] cell_prob — foreground probability logit (sigmoid applied in predict())
[3] dist_transform — distance transform (encodes bubble radius; ReLU applied in predict())
Params: ~389K (base_ch=16), ~1.5M (base_ch=32), with in_channels=1 and use_depthwise=True
"""
import torch
import torch.nn as nn
class DepthwiseSeparableConv(nn.Module):
    """Conv -> BatchNorm -> ReLU unit built from a depthwise and a pointwise conv.

    The depthwise stage (``groups=in_ch``) filters each input channel with its
    own ``kernel_size`` x ``kernel_size`` kernel. The 1x1 pointwise stage then
    mixes channels up to ``out_ch``. Neither conv has a bias, because the
    BatchNorm that follows supplies an affine shift.

    For a 3x3 kernel the conv weight count is ``in_ch * (9 + out_ch)``, versus
    ``9 * in_ch * out_ch`` for a standard conv. That is roughly 8-9x fewer
    parameters once ``out_ch`` is large.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Spatial size of the depthwise kernel.
        padding: Zero padding of the depthwise conv. The default of 1 keeps
            H and W unchanged only for kernel_size=3 with stride=1.
        stride: Stride of the depthwise conv.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1, stride=1):
        super().__init__()
        # Attribute names and creation order determine the state_dict keys and
        # the parameters() order, so both are kept fixed.
        self.depthwise = nn.Conv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_ch,
            bias=False,
        )
        self.pointwise = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply depthwise conv, pointwise conv, BatchNorm, then ReLU."""
        per_channel = self.depthwise(x)
        mixed = self.pointwise(per_channel)
        normalized = self.bn(mixed)
        return self.relu(normalized)
class ConvBlock(nn.Module):
    """Two stacked conv -> BatchNorm -> ReLU stages (the classic U-Net block).

    Args:
        in_ch: Channels entering the first stage.
        out_ch: Channels produced by both stages.
        use_depthwise: If True, each stage is a ``DepthwiseSeparableConv`` with
            its default 3x3 kernel, padding 1 and stride 1. Otherwise each stage
            is a bias-free 3x3 ``nn.Conv2d`` (padding 1) followed by
            ``nn.BatchNorm2d`` and an in-place ``nn.ReLU``.

    In both variants, spatial size (H, W) is preserved.
    """

    def __init__(self, in_ch, out_ch, use_depthwise=True):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch; second stage maps out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            if use_depthwise:
                layers.append(DepthwiseSeparableConv(stage_in, out_ch))
            else:
                layers.extend(
                    [
                        nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                        nn.BatchNorm2d(out_ch),
                        nn.ReLU(inplace=True),
                    ]
                )
        # A single Sequential named ``block`` keeps state_dict keys in the
        # form ``block.<index>.<...>``.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run both stages in order."""
        return self.block(x)
class TinyBubbleNet(nn.Module):
    """
    Tiny U-Net for microbubble instance segmentation.

    Four 2x max-pool encoder stages feed a bottleneck. Four 2x transposed-conv
    decoder stages follow, each concatenated with the matching encoder
    features (skip connections). Three 1x1 heads read the final decoder
    features.

    Args:
        in_channels: Input image channels (1 for grayscale, 3 for RGB)
        base_ch: Base channel count (16 → ~389K params, 32 → ~1.5M params,
            counted with in_channels=1 and use_depthwise=True)
        out_channels: Output channels (default 4: dY, dX, cell_prob, dist_transform).
            NOTE: this value is only stored on the instance. The heads below
            always produce exactly 4 channels, whatever is passed here.
        use_depthwise: Use depthwise-separable convs (True for speed, False for accuracy).
            The first encoder block always uses standard convs.
    """

    # 2 ** (number of pooling stages). forward() pads H and W up to a multiple
    # of this, so every upsampled tensor matches its skip connection.
    _DOWNSAMPLE_FACTOR = 16

    def __init__(self, in_channels=1, base_ch=16, out_channels=4, use_depthwise=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Encoder
        self.enc1 = ConvBlock(in_channels, base_ch, use_depthwise=False)  # first block: standard conv
        self.enc2 = ConvBlock(base_ch, base_ch * 2, use_depthwise)
        self.enc3 = ConvBlock(base_ch * 2, base_ch * 4, use_depthwise)
        self.enc4 = ConvBlock(base_ch * 4, base_ch * 8, use_depthwise)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = ConvBlock(base_ch * 8, base_ch * 16, use_depthwise)
        # Decoder: each up-conv halves channels and doubles H/W. The following
        # ConvBlock takes 2x channels because the skip tensor is concatenated in.
        self.up4 = nn.ConvTranspose2d(base_ch * 16, base_ch * 8, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(base_ch * 16, base_ch * 8, use_depthwise)
        self.up3 = nn.ConvTranspose2d(base_ch * 8, base_ch * 4, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(base_ch * 8, base_ch * 4, use_depthwise)
        self.up2 = nn.ConvTranspose2d(base_ch * 4, base_ch * 2, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(base_ch * 4, base_ch * 2, use_depthwise)
        self.up1 = nn.ConvTranspose2d(base_ch * 2, base_ch, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(base_ch * 2, base_ch, use_depthwise)
        # Output heads (separate for interpretability)
        self.flow_head = nn.Conv2d(base_ch, 2, 1)  # dY, dX
        self.prob_head = nn.Conv2d(base_ch, 1, 1)  # cell probability logit
        self.dist_head = nn.Conv2d(base_ch, 1, 1)  # distance transform

    def forward(self, x):
        """Run the U-Net and return the raw (pre-activation) outputs.

        Args:
            x: Input batch of shape (B, in_channels, H, W). H and W do not have
                to be multiples of 16. If they are not, the input is zero-padded
                on the bottom/right up to the next multiple, so the skip
                connections line up. The output is then cropped back to (H, W).

        Returns:
            Tensor of shape (B, 4, H, W), ordered [dY, dX, cell_prob logit,
            dist_transform]. No sigmoid or ReLU is applied here; see predict().
        """
        height, width = x.shape[-2:]
        factor = self._DOWNSAMPLE_FACTOR
        pad_h = (-height) % factor
        pad_w = (-width) % factor
        padded = bool(pad_h or pad_w)
        if padded:
            # Zero padding, matching the zero padding the convolutions already
            # apply at image borders. F.pad order is (left, right, top, bottom).
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h))

        # Encoder path
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        # Bottleneck
        b = self.bottleneck(self.pool(e4))
        # Decoder path with skip connections
        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        # Output heads
        flows = self.flow_head(d1)  # (B, 2, H, W) — dY, dX
        prob = self.prob_head(d1)  # (B, 1, H, W) — cell probability logit
        dist = self.dist_head(d1)  # (B, 1, H, W) — distance transform
        # Concatenate: (B, 4, H, W) = [dY, dX, prob, dist]
        out = torch.cat([flows, prob, dist], dim=1)
        if padded:
            # Drop the rows/columns added above so the output matches the input size.
            out = out[:, :, :height, :width]
        return out

    def predict(self, x):
        """Inference helper: return a dict of named, post-processed outputs.

        For the duration of the call, every submodule is put in eval mode, so
        BatchNorm uses (and does not update) its running statistics, and
        gradients are disabled. Each submodule's previous train/eval mode is
        restored afterwards, even if forward() raises.

        Args:
            x: Input batch, as accepted by forward().

        Returns:
            dict with keys:
                "dY", "dX": raw flow channels, each (B, H, W).
                "cell_prob": sigmoid of the probability logit, in [0, 1].
                "dist_transform": distance channel clamped to >= 0 by ReLU.
        """
        # Save per-module modes, not just self.training, so that submodules a
        # caller deliberately froze (e.g. BN in eval during training) stay that way.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                out = self.forward(x)
        finally:
            for module, mode in saved_modes:
                module.training = mode
        return {
            "dY": out[:, 0],
            "dX": out[:, 1],
            "cell_prob": torch.sigmoid(out[:, 2]),
            "dist_transform": torch.relu(out[:, 3]),  # distance is non-negative
        }

    @staticmethod
    def count_params(model):
        """Return the number of trainable (``requires_grad``) parameters of ``model``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)