| """ |
Tiny U-Net for microbubble segmentation (student model).

Architecture: depthwise-separable U-Net with a 4-channel output.
Inspired by PicoSAM2 (arXiv:2506.18807) - optimized for speed on narrow domains.

Output channels (raw values from forward(); predict() applies the activations):
    [0] dY             - vertical gradient flow (Cellpose-compatible)
    [1] dX             - horizontal gradient flow (Cellpose-compatible)
    [2] cell_prob      - foreground logit (sigmoid -> probability in predict())
    [3] dist_transform - distance transform, encodes bubble radius (ReLU in predict())

Params: ~389K (base_ch=16), ~1.5M (base_ch=32), with use_depthwise=True.
| """ |
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
class DepthwiseSeparableConv(nn.Module):
    """Depthwise conv followed by a 1x1 pointwise conv, BatchNorm and ReLU.

    Factorising a dense KxK convolution into a per-channel (depthwise) KxK
    conv plus a 1x1 channel-mixing (pointwise) conv gives roughly 8-9x fewer
    parameters for 3x3 kernels at the same channel count.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Spatial size of the depthwise kernel.
        padding: Zero padding applied by the depthwise conv.
        stride: Stride of the depthwise conv.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1, stride=1):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the order in which weights are initialised.
        self.depthwise = nn.Conv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_ch,  # one filter per input channel
            bias=False,  # BatchNorm after the pointwise conv provides the shift
        )
        self.pointwise = nn.Conv2d(
            in_channels=in_ch, out_channels=out_ch, kernel_size=1, bias=False
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply depthwise -> pointwise -> BatchNorm -> ReLU to ``x``."""
        y = self.depthwise(x)
        y = self.pointwise(y)
        y = self.bn(y)
        return self.relu(y)
|
|
|
|
class ConvBlock(nn.Module):
    """Two stacked conv -> BatchNorm -> ReLU stages (the U-Net "double conv").

    Spatial size is preserved (3x3 kernels with padding 1, stride 1).

    Args:
        in_ch: Number of input channels (fed to the first stage).
        out_ch: Number of output channels of both stages.
        use_depthwise: If true, each stage is a ``DepthwiseSeparableConv``;
            otherwise each stage is a dense 3x3 ``nn.Conv2d`` + BatchNorm + ReLU,
            giving a flat 6-layer Sequential.
    """

    def __init__(self, in_ch, out_ch, use_depthwise=True):
        super().__init__()
        # Input channel count of stage 1 and stage 2, respectively.
        stage_inputs = (in_ch, out_ch)
        if use_depthwise:
            stages = [DepthwiseSeparableConv(c, out_ch) for c in stage_inputs]
        else:
            stages = []
            for c in stage_inputs:
                stages += [
                    nn.Conv2d(c, out_ch, 3, padding=1, bias=False),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                ]
        # A single Sequential keeps the state_dict keys as block.<i>.*.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run both stages on ``x``."""
        return self.block(x)
|
|
|
|
class TinyBubbleNet(nn.Module):
    """
    Tiny U-Net for microbubble instance segmentation.

    Encoder: four ConvBlock stages separated by 2x max-pooling, followed by a
    bottleneck at 1/16 of the input resolution. Decoder: four 2x transposed-
    conv upsamplings, each concatenated with the matching encoder features
    (skip connection) and refined by a ConvBlock. Three 1x1 heads then
    produce the flow (2 ch), foreground logit (1 ch) and distance (1 ch) maps.

    Args:
        in_channels: Input image channels (1 for grayscale, 3 for RGB).
        base_ch: Base channel count (16 -> ~389K params, 32 -> ~1.5M params,
            both with use_depthwise=True).
        out_channels: Output channels. Must be 4 (dY, dX, cell_prob,
            dist_transform); the output heads are fixed to that layout.
        use_depthwise: Use depthwise-separable convs (True for speed, False
            for accuracy). The first encoder stage always uses dense convs.

    Raises:
        ValueError: If ``out_channels`` is not 4.
    """

    # Total encoder downsampling (four 2x max-pools). Inputs are padded up to
    # a multiple of this so every skip connection matches its decoder stage.
    _SIZE_DIVISOR = 16

    def __init__(self, in_channels=1, base_ch=16, out_channels=4, use_depthwise=True):
        super().__init__()
        if out_channels != 4:
            # The heads below always emit 2 + 1 + 1 channels; accepting another
            # value would make self.out_channels disagree with the real output.
            raise ValueError(
                "out_channels must be 4 (dY, dX, cell_prob, dist_transform), "
                f"got {out_channels}"
            )
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Encoder. enc1 always uses dense convs, presumably because a depthwise
        # conv over only 1-3 image channels has too little capacity.
        self.enc1 = ConvBlock(in_channels, base_ch, use_depthwise=False)
        self.enc2 = ConvBlock(base_ch, base_ch * 2, use_depthwise)
        self.enc3 = ConvBlock(base_ch * 2, base_ch * 4, use_depthwise)
        self.enc4 = ConvBlock(base_ch * 4, base_ch * 8, use_depthwise)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck at 1/16 resolution.
        self.bottleneck = ConvBlock(base_ch * 8, base_ch * 16, use_depthwise)

        # Decoder: each ConvTranspose2d doubles H/W and halves the channels;
        # the following ConvBlock sees (upsampled + skip) = twice that count.
        self.up4 = nn.ConvTranspose2d(base_ch * 16, base_ch * 8, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(base_ch * 16, base_ch * 8, use_depthwise)

        self.up3 = nn.ConvTranspose2d(base_ch * 8, base_ch * 4, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(base_ch * 8, base_ch * 4, use_depthwise)

        self.up2 = nn.ConvTranspose2d(base_ch * 4, base_ch * 2, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(base_ch * 4, base_ch * 2, use_depthwise)

        self.up1 = nn.ConvTranspose2d(base_ch * 2, base_ch, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(base_ch * 2, base_ch, use_depthwise)

        # 1x1 output heads.
        self.flow_head = nn.Conv2d(base_ch, 2, 1)  # dY, dX
        self.prob_head = nn.Conv2d(base_ch, 1, 1)  # foreground logit
        self.dist_head = nn.Conv2d(base_ch, 1, 1)  # distance transform

    def forward(self, x):
        """Run the network and return raw (pre-activation) outputs.

        Inputs whose height/width are not multiples of 16 are zero-padded on
        the bottom/right so every pooling stage halves evenly and the skip
        connections line up; the output is cropped back to the input size.
        Inputs that are already multiples of 16 are processed unchanged.

        Args:
            x: Image batch, presumably (N, in_channels, H, W) -- confirm with callers.

        Returns:
            Tensor with 4 channels along dim 1 and the input's H, W:
            [dY, dX, cell_prob logit, dist_transform] (no activations applied).
        """
        height, width = x.shape[-2:]
        pad_h = (-height) % self._SIZE_DIVISOR
        pad_w = (-width) % self._SIZE_DIVISOR
        if pad_h or pad_w:
            # F.pad's tuple is (left, right, top, bottom) for the last two dims.
            x = F.pad(x, (0, pad_w, 0, pad_h))

        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Bottleneck
        b = self.bottleneck(self.pool(e4))

        # Decoder with skip connections (concatenate along channels)
        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        # Heads
        flows = self.flow_head(d1)
        prob = self.prob_head(d1)
        dist = self.dist_head(d1)

        out = torch.cat([flows, prob, dist], dim=1)
        # Drop any padding added above (a no-op slice when none was added).
        return out[..., :height, :width]

    def predict(self, x):
        """Inference mode: returns dict of named outputs.

        Runs with every submodule in eval mode (BatchNorm uses and does not
        update its running statistics) under ``torch.no_grad()``. Each
        submodule's previous train/eval flag is restored afterwards, even if
        the forward pass raises.

        Returns:
            dict with keys "dY" and "dX" (raw flows), "cell_prob" (sigmoid of
            the logit) and "dist_transform" (ReLU-clamped to >= 0); each value
            is the corresponding output channel, e.g. (N, H, W) for 4-D input.
        """
        prev_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                out = self.forward(x)
                return {
                    "dY": out[:, 0],
                    "dX": out[:, 1],
                    "cell_prob": torch.sigmoid(out[:, 2]),
                    "dist_transform": torch.relu(out[:, 3]),
                }
        finally:
            # Restore per-module flags (not just self.train(...)) so submodules
            # the caller deliberately froze in eval mode stay that way.
            for module, mode in prev_modes:
                module.training = mode

    @staticmethod
    def count_params(model):
        """Return the number of trainable (requires_grad) parameters in ``model``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)