"""
Tiny U-Net for microbubble segmentation (student model).

Architecture: Depthwise-separable U-Net with 4-channel output.
Inspired by PicoSAM2 (arxiv:2506.18807) — optimized for speed on narrow domains.

Output channels (raw head outputs returned by ``TinyBubbleNet.forward``):
    [0] dY — vertical gradient flow (Cellpose-compatible)
    [1] dX — horizontal gradient flow (Cellpose-compatible)
    [2] cell_prob — foreground logit (``predict`` applies a sigmoid)
    [3] dist_transform — distance transform, encodes bubble radius
        (``predict`` clamps it to be non-negative)

Params: ~389K (base_ch=16), ~1.5M (base_ch=32)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthwiseSeparableConv(nn.Module):
    """Depthwise separable convolution: depthwise + pointwise, then BatchNorm + ReLU.

    ~8-9x fewer parameters than a standard conv for the same channel count.

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        kernel_size: Spatial kernel size of the depthwise conv.
        padding: Padding of the depthwise conv (the default 1 preserves H, W
            for the default 3x3 kernel at stride 1).
        stride: Stride of the depthwise conv.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1, stride=1):
        super().__init__()
        # groups=in_ch -> one spatial filter per input channel (no channel mixing).
        self.depthwise = nn.Conv2d(
            in_ch, in_ch, kernel_size, stride=stride, padding=padding,
            groups=in_ch, bias=False,
        )
        # 1x1 conv mixes channels; no bias because BatchNorm follows.
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.pointwise(self.depthwise(x))))


class ConvBlock(nn.Module):
    """Double 3x3 convolution block (each conv followed by BatchNorm + ReLU).

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count of both convolutions.
        use_depthwise: If True, use two ``DepthwiseSeparableConv`` layers;
            otherwise use two standard 3x3 convolutions.
    """

    def __init__(self, in_ch, out_ch, use_depthwise=True):
        super().__init__()
        if use_depthwise:
            self.block = nn.Sequential(
                DepthwiseSeparableConv(in_ch, out_ch),
                DepthwiseSeparableConv(out_ch, out_ch),
            )
        else:
            self.block = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )

    def forward(self, x):
        return self.block(x)


class TinyBubbleNet(nn.Module):
    """
    Tiny U-Net for microbubble instance segmentation.

    Args:
        in_channels: Input image channels (1 for grayscale, 3 for RGB)
        base_ch: Base channel count (16 → ~389K params, 32 → ~1.5M params)
        out_channels: Output channels. Must be 4 (dY, dX, cell_prob,
            dist_transform); the heads are fixed to that layout.
        use_depthwise: Use depthwise-separable convs (True for speed, False for accuracy)

    Raises:
        ValueError: If ``out_channels`` is not 4.
    """

    # Four 2x max-pool stages -> spatial dims must be divisible by 2**4 for
    # the upsampled maps to line up with the skip connections.
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels=1, base_ch=16, out_channels=4, use_depthwise=True):
        super().__init__()
        if out_channels != 4:
            raise ValueError(
                "out_channels must be 4 (dY, dX, cell_prob, dist_transform); "
                f"got {out_channels}"
            )
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Encoder
        self.enc1 = ConvBlock(in_channels, base_ch, use_depthwise=False)  # first block: standard conv
        self.enc2 = ConvBlock(base_ch, base_ch * 2, use_depthwise)
        self.enc3 = ConvBlock(base_ch * 2, base_ch * 4, use_depthwise)
        self.enc4 = ConvBlock(base_ch * 4, base_ch * 8, use_depthwise)
        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = ConvBlock(base_ch * 8, base_ch * 16, use_depthwise)

        # Decoder: each dec block sees upsampled features concatenated with
        # the matching encoder skip, hence the doubled input channel count.
        self.up4 = nn.ConvTranspose2d(base_ch * 16, base_ch * 8, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(base_ch * 16, base_ch * 8, use_depthwise)
        self.up3 = nn.ConvTranspose2d(base_ch * 8, base_ch * 4, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(base_ch * 8, base_ch * 4, use_depthwise)
        self.up2 = nn.ConvTranspose2d(base_ch * 4, base_ch * 2, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(base_ch * 4, base_ch * 2, use_depthwise)
        self.up1 = nn.ConvTranspose2d(base_ch * 2, base_ch, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(base_ch * 2, base_ch, use_depthwise)

        # Output heads (separate for interpretability)
        self.flow_head = nn.Conv2d(base_ch, 2, 1)  # dY, dX
        self.prob_head = nn.Conv2d(base_ch, 1, 1)  # cell probability (logit)
        self.dist_head = nn.Conv2d(base_ch, 1, 1)  # distance transform

    def forward(self, x):
        """Compute the raw 4-channel prediction.

        Args:
            x: Input batch of shape (B, in_channels, H, W). H and W need not
                be multiples of 16: the input is replicate-padded on the
                bottom/right up to the next multiple and the output is
                cropped back, so arbitrary sizes are accepted.

        Returns:
            Tensor of shape (B, 4, H, W): [dY, dX, prob_logit, dist], no
            activation applied.
        """
        height, width = x.shape[-2:]
        pad_h = -height % self._SIZE_MULTIPLE
        pad_w = -width % self._SIZE_MULTIPLE
        padded = bool(pad_h or pad_w)
        if padded:
            # Replicate (not zero) padding avoids creating artificial edges.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        # Encoder path
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Bottleneck
        b = self.bottleneck(self.pool(e4))

        # Decoder path with skip connections
        d4 = self.dec4(torch.cat([self.up4(b), e4], dim=1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        # Output heads
        flows = self.flow_head(d1)  # (B, 2, H, W) — dY, dX
        prob = self.prob_head(d1)   # (B, 1, H, W) — cell probability logit
        dist = self.dist_head(d1)   # (B, 1, H, W) — distance transform

        # Concatenate: (B, 4, H, W) = [dY, dX, prob, dist]
        out = torch.cat([flows, prob, dist], dim=1)
        if padded:
            out = out[..., :height, :width]
        return out

    def predict(self, x):
        """Inference: run the model in eval mode without autograd.

        The module is switched to ``eval()`` for the call, so BatchNorm uses
        its running statistics and does not update them. The previous
        train/eval mode is restored afterwards, even if the call raises.

        Args:
            x: Input batch of shape (B, in_channels, H, W).

        Returns:
            dict with (B, H, W) tensors: "dY", "dX", "cell_prob" (sigmoid of
            the logit, in [0, 1]) and "dist_transform" (ReLU-clamped, >= 0).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                out = self.forward(x)
        finally:
            self.train(was_training)
        return {
            "dY": out[:, 0],
            "dX": out[:, 1],
            "cell_prob": torch.sigmoid(out[:, 2]),
            "dist_transform": torch.relu(out[:, 3]),  # distance is non-negative
        }

    @staticmethod
    def count_params(model):
        """Return the number of trainable (requires_grad) parameters of ``model``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)