|
|
""" |
|
|
CondensatePose Model Architecture |
|
|
================================= |
|
|
|
|
|
EfficientNetV2 encoder with Feature Pyramid Network and Style Modulation |
|
|
for detecting biomolecular condensates in fluorescence microscopy images. |
|
|
|
|
|
Architecture Components: |
|
|
- Encoder: EfficientNetV2 (pretrained, adapted for grayscale) |
|
|
- Decoder: Multi-scale FPN with style-based feature modulation |
|
|
- Outputs: Binary mask + flow fields for instance segmentation |
|
|
|
|
|
Paper: [Add your paper link here] |
|
|
GitHub: [Add your GitHub link here] |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from timm import create_model |
|
|
from typing import Dict |
|
|
|
|
|
class SparseAttention(nn.Module):
    """Spatial attention over channel statistics, emphasizing sparse condensate regions.

    Builds a single-channel attention map from the per-pixel channel mean and
    channel max, squashes it to (0, 1) with a sigmoid, and rescales the input.
    """

    def __init__(self, kernel_size=11):
        super().__init__()
        # One wide conv over the 2-channel [mean, max] summary, then sigmoid.
        self.attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Channel-wise mean and max summaries, each of shape (B, 1, H, W).
        mean_summary = x.mean(dim=1, keepdim=True)
        max_summary = x.max(dim=1, keepdim=True).values
        gate = self.attention(torch.cat([mean_summary, max_summary], dim=1))
        return x * gate
|
|
|
|
|
|
|
|
class NormProjection(nn.Module):
    """Batch-normalized 1x1 projection.

    Applies BatchNorm2d to the input, then maps it to the target channel
    count with a bias-free pointwise convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        normalized = self.bn(x)
        return self.conv(normalized)
|
|
|
|
|
|
|
|
class ConvBlock(nn.Module):
    """Pre-activation 3x3 convolution block: BatchNorm -> SiLU -> Conv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.swish = nn.SiLU(inplace=True)
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        activated = self.swish(self.bn(x))
        return self.conv(activated)
|
|
|
|
|
|
|
|
class StyleModulatedConv(nn.Module):
    """3x3 conv block whose output is shifted by a style-derived per-channel bias."""

    def __init__(self, channels, style_dim):
        super().__init__()
        self.conv_block = ConvBlock(channels, channels)
        self.style_projection = nn.Linear(style_dim, channels)

    def forward(self, x, style_vector):
        # Project the style vector to one bias per channel and broadcast it
        # over the spatial dimensions: (B, C) -> (B, C, 1, 1).
        bias = self.style_projection(style_vector)[:, :, None, None]
        return self.conv_block(x) + bias
|
|
|
|
|
|
|
|
class DualResidualBlock(nn.Module):
    """Two-stage residual fusion block with style conditioning.

    Merges a decoder feature map with a lateral skip connection, then
    refines the result through three style-modulated convolutions linked
    by two residual shortcuts.
    """

    def __init__(self, channels, style_dim):
        super().__init__()
        self.style_conv1 = StyleModulatedConv(channels, style_dim)
        self.style_conv2 = StyleModulatedConv(channels, style_dim)
        self.style_conv3 = StyleModulatedConv(channels, style_dim)
        self.projection = NormProjection(channels, channels)
        self.initial_conv = ConvBlock(channels, channels)

    def forward(self, x, lateral, style_vector):
        # Fuse the incoming features with the lateral connection.
        fused = self.initial_conv(x) + lateral
        # Stage 1: style conv on the fused features plus a normalized
        # projection of the raw input (first residual shortcut).
        stage1 = self.style_conv1(fused, style_vector) + self.projection(x)
        # Stage 2: two further style convs with a shortcut back to stage 1.
        stage2 = self.style_conv3(
            self.style_conv2(stage1, style_vector), style_vector
        )
        return stage2 + stage1
|
|
|
|
|
|
|
|
class MultiScaleEncoder(nn.Module):
    """EfficientNetV2 backbone with per-level channel adapters.

    Wraps a pretrained timm EfficientNetV2 (adapted to single-channel input)
    and projects each of its first four feature maps to the requested pyramid
    channel widths with a 1x1 conv + BatchNorm adapter.

    Args:
        variant (str): EfficientNetV2 variant suffix; the timm model name
            built from it is ``efficientnetv2_<variant>``.
        pyramid_channels (sequence of int): Target channel widths for the
            four pyramid levels.
    """

    # Default is a tuple (not a list) to avoid the shared
    # mutable-default-argument pitfall; any sequence is accepted.
    def __init__(self, variant='rw_s', pyramid_channels=(24, 48, 64, 160)):
        super().__init__()

        encoder_name = f"efficientnetv2_{variant}"
        self.base_encoder = create_model(
            encoder_name,
            features_only=True,
            pretrained=True,
            in_chans=1,               # grayscale microscopy input
            out_indices=[0, 1, 2, 3]  # first four feature stages
        )

        enc_channels = self.base_encoder.feature_info.channels()

        # One 1x1 conv + BN adapter per level, mapping backbone channel
        # counts to the requested pyramid widths.
        self.channel_adapters = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(enc_ch, target_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(target_ch)
            )
            for enc_ch, target_ch in zip(enc_channels[:4], pyramid_channels)
        ])

    def forward(self, x):
        """Return the four adapted feature maps (highest resolution first)."""
        features = self.base_encoder(x)[:4]
        return [adapter(feat) for feat, adapter in zip(features, self.channel_adapters)]
|
|
|
|
|
|
|
|
class CondensatePoseModel(nn.Module):
    """
    CondensatePose: Multi-scale segmentation model for biomolecular condensates.

    Architecture:
        - Encoder: EfficientNetV2 with channel adaptation
        - Decoder: Feature Pyramid Network with:
            * Style-based feature modulation
            * Dual residual blocks for feature fusion
            * Multi-scale upsampling refinement
        - Attention: Spatial attention for sparse object detection
        - Outputs: Binary mask logits + flow field vectors

    Args:
        encoder_variant (str): EfficientNetV2 variant (default: 'rw_s')
        pyramid_channels (sequence of int): Channel dimensions for pyramid levels
        use_spatial_attention (bool): Enable spatial attention module
        spatial_kernel_size (int): Kernel size for spatial attention
        dropout_rate (float): Dropout rate in decoder
    """

    def __init__(
        self,
        encoder_variant='rw_s',
        # Tuple default avoids the shared mutable-default-argument pitfall;
        # any sequence of four ints is accepted.
        pyramid_channels=(24, 48, 64, 160),
        use_spatial_attention=True,
        spatial_kernel_size=11,
        dropout_rate=0.15
    ):
        super().__init__()

        self.use_spatial_attention = use_spatial_attention
        # Normalize to a list so callers may pass any sequence type.
        self.pyramid_channels = list(pyramid_channels)

        # Multi-scale feature encoder (pretrained EfficientNetV2 backbone).
        self.encoder = MultiScaleEncoder(
            variant=encoder_variant,
            pyramid_channels=self.pyramid_channels
        )

        # The global style vector is pooled from the deepest pyramid level.
        style_dim = self.pyramid_channels[-1]
        pyramid_dim = 32  # common width of all decoder/FPN features

        # Top-down FPN fusion blocks (block4 handles the deepest level).
        self.pyramid_block4 = DualResidualBlock(pyramid_dim, style_dim)
        self.pyramid_block3 = DualResidualBlock(pyramid_dim, style_dim)
        self.pyramid_block2 = DualResidualBlock(pyramid_dim, style_dim)
        self.pyramid_block1 = DualResidualBlock(pyramid_dim, style_dim)

        # 1x1 lateral convs mapping encoder channels to the pyramid width.
        self.lateral_conv4 = nn.Conv2d(self.pyramid_channels[3], pyramid_dim, kernel_size=1, bias=False)
        self.lateral_conv3 = nn.Conv2d(self.pyramid_channels[2], pyramid_dim, kernel_size=1, bias=False)
        self.lateral_conv2 = nn.Conv2d(self.pyramid_channels[1], pyramid_dim, kernel_size=1, bias=False)
        self.lateral_conv1 = nn.Conv2d(self.pyramid_channels[0], pyramid_dim, kernel_size=1, bias=False)

        # Shared refinement blocks used during repeated 2x upsampling.
        self.upsample_blocks = nn.ModuleList([
            DualResidualBlock(pyramid_dim, style_dim) for _ in range(3)
        ])

        # Final projection of the summed pyramid levels.
        self.output_projection = NormProjection(pyramid_dim, pyramid_dim)

        if use_spatial_attention:
            self.spatial_attention = SparseAttention(spatial_kernel_size)

        self.dropout = nn.Dropout2d(dropout_rate)

        # Prediction heads: 1-channel mask logits and 2-channel flow field.
        self.mask_head = nn.Conv2d(pyramid_dim, 1, kernel_size=1)
        self.flow_head = nn.Conv2d(pyramid_dim, 2, kernel_size=1)

        self._init_weights()

    def _init_weights(self):
        """Initialize decoder weights for sparse condensate segmentation.

        The backbone inside ``self.encoder.base_encoder`` is loaded with
        pretrained weights, so it is explicitly excluded here: re-running
        Kaiming/constant init over it (as ``self.modules()`` would) silently
        discards the pretrained features the architecture relies on.
        """
        pretrained_modules = set(self.encoder.base_encoder.modules())
        for m in self.modules():
            if m in pretrained_modules:
                continue  # keep pretrained backbone weights intact
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Mask head: small weights plus a negative bias so the initial
        # foreground probability is low (condensates are sparse).
        nn.init.xavier_uniform_(self.mask_head.weight, gain=0.1)
        if self.mask_head.bias is not None:
            nn.init.constant_(self.mask_head.bias, -2.0)

        # Flow head starts at exactly zero flow.
        nn.init.zeros_(self.flow_head.weight)
        if self.flow_head.bias is not None:
            nn.init.zeros_(self.flow_head.bias)

    def upsample_and_refine(self, x, style_vector, block):
        """Upsample features 2x (nearest) and refine with a residual block.

        A zero tensor is passed as the lateral input so the block acts as a
        pure refinement step with no skip contribution.
        """
        x_up = F.interpolate(x, scale_factor=2, mode='nearest')
        return block(x_up, torch.zeros_like(x_up), style_vector)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, 1, H, W) - grayscale microscopy image

        Returns:
            Dictionary with keys:
            - 'mask': Shape (B, 1, H, W) - binary mask logits
            - 'flows': Shape (B, 2, H, W) - flow field vectors (dy, dx)

        Raises:
            ValueError: If the input contains NaN or Inf values.
        """
        # Fail fast on corrupted inputs rather than propagating NaNs.
        if torch.isnan(x).any() or torch.isinf(x).any():
            raise ValueError("Input contains NaN or Inf values")

        # Per-sample standardization (zero mean, unit variance over H, W).
        x_mean = x.mean(dim=[2, 3], keepdim=True)
        x_std = x.std(dim=[2, 3], keepdim=True) + 1e-8
        x_normalized = (x - x_mean) / x_std

        # Multi-scale encoder features, highest resolution first.
        features = self.encoder(x_normalized)
        C1, C2, C3, C4 = features

        # Global style vector: L2-normalized pooled deepest features.
        style_vector = F.adaptive_avg_pool2d(C4, 1).squeeze(-1).squeeze(-1)
        style_vector = F.normalize(style_vector, p=2, dim=1)

        # Top-down FPN pathway with style-conditioned fusion. The deepest
        # level uses its own reduced features as the lateral input.
        C4_reduced = self.lateral_conv4(C4)
        P4 = self.pyramid_block4(C4_reduced, C4_reduced, style_vector)

        P4_up = F.interpolate(P4, size=C3.shape[2:], mode='nearest')
        C3_reduced = self.lateral_conv3(C3)
        P3 = self.pyramid_block3(P4_up, C3_reduced, style_vector)

        P3_up = F.interpolate(P3, size=C2.shape[2:], mode='nearest')
        C2_reduced = self.lateral_conv2(C2)
        P2 = self.pyramid_block2(P3_up, C2_reduced, style_vector)

        P2_up = F.interpolate(P2, size=C1.shape[2:], mode='nearest')
        C1_reduced = self.lateral_conv1(C1)
        P1 = self.pyramid_block1(P2_up, C1_reduced, style_vector)

        # Bring each coarser level up to P1's resolution through repeated
        # 2x upsample + refinement passes.
        # NOTE(review): this assumes consecutive encoder levels differ by a
        # factor of exactly 2 in resolution - confirm for the chosen variant.
        P4_refined = self.upsample_and_refine(P4, style_vector, self.upsample_blocks[0])
        P4_refined = self.upsample_and_refine(P4_refined, style_vector, self.upsample_blocks[1])
        P4_refined = self.upsample_and_refine(P4_refined, style_vector, self.upsample_blocks[2])

        P3_refined = self.upsample_and_refine(P3, style_vector, self.upsample_blocks[0])
        P3_refined = self.upsample_and_refine(P3_refined, style_vector, self.upsample_blocks[1])

        P2_refined = self.upsample_and_refine(P2, style_vector, self.upsample_blocks[0])

        # Sum all pyramid levels and project.
        combined = P1 + P2_refined + P3_refined + P4_refined
        features_final = self.output_projection(combined)

        if self.use_spatial_attention:
            features_final = self.spatial_attention(features_final)

        features_final = self.dropout(features_final)

        # Prediction heads on the fused decoder features.
        mask_logits = self.mask_head(features_final)
        flow_vectors = self.flow_head(features_final)

        # Resize predictions back to the input resolution.
        target_size = x.shape[2:]
        mask_logits = F.interpolate(mask_logits, size=target_size, mode='bilinear', align_corners=False)
        flow_vectors = F.interpolate(flow_vectors, size=target_size, mode='bilinear', align_corners=False)

        return {
            'mask': mask_logits,
            'flows': flow_vectors
        }
|
|
|
|
|
|
|
|
def load_condensatepose_model(
    checkpoint_path: str,
    device: str = 'cuda'
) -> CondensatePoseModel:
    """
    Load a trained CondensatePose model from a checkpoint file.

    Reads the checkpoint's ``model_config`` dict (falling back to the
    architecture defaults for any missing key), rebuilds the model, loads
    its ``model_state_dict``, moves it to *device*, and returns it in
    eval mode.

    Args:
        checkpoint_path: Path to model checkpoint (.pth file)
        device: Device to load model on ('cuda' or 'cpu')

    Returns:
        Loaded model in eval mode

    Example:
        >>> model = load_condensatepose_model('model_weights.pth', device='cuda')
        >>> model.eval()
        >>>
        >>> # Run inference
        >>> with torch.no_grad():
        >>>     outputs = model(image_tensor)
        >>>     mask_logits = outputs['mask']
        >>>     flows = outputs['flows']
    """
    # NOTE(review): torch.load unpickles arbitrary objects - only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = checkpoint.get('model_config', {})

    # Fall back to the architecture defaults for any key the checkpoint omits.
    defaults = {
        'encoder_variant': 'rw_s',
        'pyramid_channels': [24, 48, 64, 160],
        'use_spatial_attention': True,
        'spatial_kernel_size': 11,
        'dropout_rate': 0.15,
    }
    kwargs = {key: config.get(key, fallback) for key, fallback in defaults.items()}

    model = CondensatePoseModel(**kwargs)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    return model
|
|
|
|
|
|
|
|
|
|
|
# Alias kept so external code can refer to the model by either name.
CondensateSegmentationNet = CondensatePoseModel
|
|
|