"""
CondensatePose Model Architecture
=================================

EfficientNetV2 encoder with Feature Pyramid Network and Style Modulation
for detecting biomolecular condensates in fluorescence microscopy images.

Architecture Components:
- Encoder: EfficientNetV2 (pretrained, adapted for grayscale)
- Decoder: Multi-scale FPN with style-based feature modulation
- Outputs: Binary mask + flow fields for instance segmentation

Paper: [Add your paper link here]
GitHub: [Add your GitHub link here]
"""

from typing import Dict, Optional, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model


# Default per-level channel widths after the encoder's channel adapters.
# Kept as a tuple so it can never be mutated through a shared default arg.
_DEFAULT_PYRAMID_CHANNELS = (24, 48, 64, 160)


class SparseAttention(nn.Module):
    """Spatial attention module for focusing on sparse condensate regions.

    A 2-channel descriptor is built from the channel-wise mean and max of
    the input. It is convolved to a single-channel map and squashed with a
    sigmoid, and every input channel is then rescaled by that map.

    Args:
        kernel_size: Size of the attention convolution. Padding is
            ``kernel_size // 2``, so odd sizes preserve the spatial size.
    """

    def __init__(self, kernel_size=11):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attention_input = torch.cat([avg_out, max_out], dim=1)
        attention_map = self.attention(attention_input)
        return x * attention_map


class NormProjection(nn.Module):
    """Normalized 1x1 convolution for feature projection.

    Applies BatchNorm to the input, then a bias-free 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        return self.conv(self.bn(x))


class ConvBlock(nn.Module):
    """Pre-activation convolution block: BatchNorm -> SiLU -> 3x3 conv.

    The convolution uses padding 1 and stride 1, so spatial size is kept.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.swish = nn.SiLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.conv(self.swish(self.bn(x)))


class StyleModulatedConv(nn.Module):
    """Convolution with style-based feature modulation.

    Runs a :class:`ConvBlock`, then adds a per-channel bias obtained by
    linearly projecting a style vector. The style vector is assumed to be
    shaped ``(B, style_dim)``; the projection is broadcast over H and W.
    """

    def __init__(self, channels, style_dim):
        super().__init__()
        self.conv_block = ConvBlock(channels, channels)
        self.style_projection = nn.Linear(style_dim, channels)

    def forward(self, x, style_vector):
        feat = self.conv_block(x)
        # (B, channels) -> (B, channels, 1, 1) so it broadcasts over space.
        style_bias = self.style_projection(style_vector).unsqueeze(-1).unsqueeze(-1)
        return feat + style_bias


class DualResidualBlock(nn.Module):
    """Two-stage residual block for deep feature fusion with style conditioning.

    Stage 1 fuses ``initial_conv(x)`` with ``lateral``, style-modulates the
    result and adds a projected skip of ``x``. Stage 2 applies two more
    style-modulated convolutions with a residual connection around them.
    ``x`` and ``lateral`` must have the same shape, because they are summed.
    """

    def __init__(self, channels, style_dim):
        super().__init__()
        self.style_conv1 = StyleModulatedConv(channels, style_dim)
        self.style_conv2 = StyleModulatedConv(channels, style_dim)
        self.style_conv3 = StyleModulatedConv(channels, style_dim)
        self.projection = NormProjection(channels, channels)
        self.initial_conv = ConvBlock(channels, channels)

    def forward(self, x, lateral, style_vector):
        combined = self.initial_conv(x) + lateral
        x_intermediate = self.style_conv1(combined, style_vector) + self.projection(x)
        refined = self.style_conv2(x_intermediate, style_vector)
        return self.style_conv3(refined, style_vector) + x_intermediate


class MultiScaleEncoder(nn.Module):
    """EfficientNetV2-based encoder for multi-scale feature extraction.

    Wraps a timm ``efficientnetv2_<variant>`` feature extractor (1 input
    channel, ``out_indices`` 0-3). Each returned level is mapped to the
    requested width with a 1x1 conv + BatchNorm adapter.

    Args:
        variant: Suffix of the timm model name (``efficientnetv2_<variant>``).
        pyramid_channels: Target channel count for each encoder level.
        pretrained: Whether timm should load pretrained backbone weights.
            Pass False when the weights will come from a checkpoint anyway.
    """

    def __init__(self, variant='rw_s', pyramid_channels=_DEFAULT_PYRAMID_CHANNELS, pretrained=True):
        super().__init__()
        pyramid_channels = list(pyramid_channels)
        encoder_name = f"efficientnetv2_{variant}"
        self.base_encoder = create_model(
            encoder_name,
            features_only=True,
            pretrained=pretrained,
            in_chans=1,
            out_indices=[0, 1, 2, 3],
        )
        enc_channels = self.base_encoder.feature_info.channels()
        self.channel_adapters = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(enc_ch, target_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(target_ch),
            )
            for enc_ch, target_ch in zip(enc_channels[:4], pyramid_channels)
        ])

    def forward(self, x):
        features = self.base_encoder(x)[:4]
        return [adapter(feat) for feat, adapter in zip(features, self.channel_adapters)]


class CondensatePoseModel(nn.Module):
    """
    CondensatePose: Multi-scale segmentation model for biomolecular condensates.

    Architecture:
        - Encoder: EfficientNetV2 with channel adaptation
        - Decoder: Feature Pyramid Network with:
            * Style-based feature modulation
            * Dual residual blocks for feature fusion
            * Multi-scale upsampling refinement
        - Attention: Spatial attention for sparse object detection
        - Outputs: Binary mask logits + flow field vectors

    Args:
        encoder_variant (str): EfficientNetV2 variant (default: 'rw_s')
        pyramid_channels (sequence of 4 ints): Channel dimensions for the
            four pyramid levels (finest first). The last entry is also the
            style-vector dimension.
        use_spatial_attention (bool): Enable spatial attention module
        spatial_kernel_size (int): Kernel size for spatial attention
        dropout_rate (float): Dropout rate in decoder
        pretrained_encoder (bool): Load pretrained backbone weights through
            timm (default: True). Set False when a checkpoint is loaded next.

    Raises:
        ValueError: If ``pyramid_channels`` does not have exactly 4 entries.
    """

    def __init__(
        self,
        encoder_variant='rw_s',
        pyramid_channels=_DEFAULT_PYRAMID_CHANNELS,
        use_spatial_attention=True,
        spatial_kernel_size=11,
        dropout_rate=0.15,
        pretrained_encoder=True,
    ):
        super().__init__()
        # Copy so the model never aliases (or gets mutated through) the
        # caller's list; forward() unpacks exactly four encoder levels.
        pyramid_channels = list(pyramid_channels)
        if len(pyramid_channels) != 4:
            raise ValueError(
                f"pyramid_channels must have exactly 4 entries, got {len(pyramid_channels)}"
            )
        self.use_spatial_attention = use_spatial_attention
        self.pyramid_channels = pyramid_channels

        # NOTE: submodule attribute names below are state-dict keys; renaming
        # any of them breaks loading of existing checkpoints.

        # Multi-scale encoder
        self.encoder = MultiScaleEncoder(
            variant=encoder_variant,
            pyramid_channels=pyramid_channels,
            pretrained=pretrained_encoder,
        )

        # Style vector dimension = channels of the coarsest level (C4).
        style_dim = pyramid_channels[-1]
        pyramid_dim = 32

        # Pyramid processing blocks
        self.pyramid_block4 = DualResidualBlock(pyramid_dim, style_dim)
        self.pyramid_block3 = DualResidualBlock(pyramid_dim, style_dim)
        self.pyramid_block2 = DualResidualBlock(pyramid_dim, style_dim)
        self.pyramid_block1 = DualResidualBlock(pyramid_dim, style_dim)

        # Channel reduction for all levels
        self.lateral_conv4 = nn.Conv2d(pyramid_channels[3], pyramid_dim, kernel_size=1, bias=False)
        self.lateral_conv3 = nn.Conv2d(pyramid_channels[2], pyramid_dim, kernel_size=1, bias=False)
        self.lateral_conv2 = nn.Conv2d(pyramid_channels[1], pyramid_dim, kernel_size=1, bias=False)
        self.lateral_conv1 = nn.Conv2d(pyramid_channels[0], pyramid_dim, kernel_size=1, bias=False)

        # Upsampling refinement blocks (shared across the P2/P3/P4 paths)
        self.upsample_blocks = nn.ModuleList([
            DualResidualBlock(pyramid_dim, style_dim) for _ in range(3)
        ])

        # Final feature projection
        self.output_projection = NormProjection(pyramid_dim, pyramid_dim)

        # Spatial attention
        if use_spatial_attention:
            self.spatial_attention = SparseAttention(spatial_kernel_size)

        # Dropout
        self.dropout = nn.Dropout2d(dropout_rate)

        # Output heads
        self.mask_head = nn.Conv2d(pyramid_dim, 1, kernel_size=1)  # Binary segmentation
        self.flow_head = nn.Conv2d(pyramid_dim, 2, kernel_size=1)  # Flow vectors

        self._init_weights()

    def _init_weights(self):
        """Initialize the decoder and head weights for sparse condensate segmentation.

        Modules that belong to the timm backbone are skipped. They keep
        timm's initialization, including pretrained weights when requested.
        Re-initializing them here would silently discard the pretrained
        encoder. The encoder's channel adapters are not part of the backbone
        and are initialized like the rest of the decoder.
        """
        backbone_module_ids = {id(m) for m in self.encoder.base_encoder.modules()}
        for m in self.modules():
            if id(m) in backbone_module_ids:
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Mask head: small weights, negative bias -> initial predictions lean
        # toward background (sigmoid(-2) ~= 0.12).
        nn.init.xavier_uniform_(self.mask_head.weight, gain=0.1)
        if self.mask_head.bias is not None:
            nn.init.constant_(self.mask_head.bias, -2.0)

        # Flow head: initialize to zero so initial flow predictions are zero.
        nn.init.zeros_(self.flow_head.weight)
        if self.flow_head.bias is not None:
            nn.init.zeros_(self.flow_head.bias)

    def upsample_and_refine(self, x, style_vector, block, size=None):
        """Upsample features (nearest) and apply a refinement block.

        Args:
            x: Feature map to upsample.
            style_vector: Style vector forwarded to ``block``.
            block: A :class:`DualResidualBlock`. It gets an all-zero lateral
                input, so only the upsampled features contribute.
            size: Optional exact output spatial size. When None (the
                original behavior), the map is upsampled by a factor of 2.
        """
        if size is None:
            x_up = F.interpolate(x, scale_factor=2, mode='nearest')
        else:
            x_up = F.interpolate(x, size=size, mode='nearest')
        return block(x_up, torch.zeros_like(x_up), style_vector)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, 1, H, W) - grayscale microscopy image

        Returns:
            Dictionary with keys:
                - 'mask': Shape (B, 1, H, W) - binary mask logits
                - 'flows': Shape (B, 2, H, W) - flow field vectors (dy, dx)

        Raises:
            ValueError: If ``x`` contains NaN or Inf values.
        """
        # Input validation (isfinite is False for both NaN and +/-Inf).
        if not torch.isfinite(x).all():
            raise ValueError("Input contains NaN or Inf values")

        # Per-image standardization; the epsilon guards constant images.
        x_mean = x.mean(dim=[2, 3], keepdim=True)
        x_std = x.std(dim=[2, 3], keepdim=True) + 1e-8
        x_normalized = (x - x_mean) / x_std

        # Extract multi-scale features (finest C1 ... coarsest C4)
        C1, C2, C3, C4 = self.encoder(x_normalized)

        # Global style vector: spatially pooled C4, L2-normalized -> (B, C)
        style_vector = F.adaptive_avg_pool2d(C4, 1).squeeze(-1).squeeze(-1)
        style_vector = F.normalize(style_vector, p=2, dim=1)

        # Build feature pyramid (top-down, coarse to fine)
        C4_reduced = self.lateral_conv4(C4)
        P4 = self.pyramid_block4(C4_reduced, C4_reduced, style_vector)
        P4_up = F.interpolate(P4, size=C3.shape[2:], mode='nearest')

        C3_reduced = self.lateral_conv3(C3)
        P3 = self.pyramid_block3(P4_up, C3_reduced, style_vector)
        P3_up = F.interpolate(P3, size=C2.shape[2:], mode='nearest')

        C2_reduced = self.lateral_conv2(C2)
        P2 = self.pyramid_block2(P3_up, C2_reduced, style_vector)
        P2_up = F.interpolate(P2, size=C1.shape[2:], mode='nearest')

        C1_reduced = self.lateral_conv1(C1)
        P1 = self.pyramid_block1(P2_up, C1_reduced, style_vector)

        # Multi-scale upsampling refinement. Each step upsamples to the exact
        # size of the next finer pyramid level. When every level is exactly
        # twice the next coarser one, this equals the old fixed x2 upsample.
        # For other input sizes it keeps the refined maps aligned with P1, so
        # the sum below no longer fails on a shape mismatch.
        size3, size2, size1 = P3.shape[2:], P2.shape[2:], P1.shape[2:]

        P4_refined = self.upsample_and_refine(P4, style_vector, self.upsample_blocks[0], size=size3)
        P4_refined = self.upsample_and_refine(P4_refined, style_vector, self.upsample_blocks[1], size=size2)
        P4_refined = self.upsample_and_refine(P4_refined, style_vector, self.upsample_blocks[2], size=size1)

        P3_refined = self.upsample_and_refine(P3, style_vector, self.upsample_blocks[0], size=size2)
        P3_refined = self.upsample_and_refine(P3_refined, style_vector, self.upsample_blocks[1], size=size1)

        P2_refined = self.upsample_and_refine(P2, style_vector, self.upsample_blocks[0], size=size1)

        # Combine all scales at P1 resolution
        combined = P1 + P2_refined + P3_refined + P4_refined
        features_final = self.output_projection(combined)

        # Apply spatial attention
        if self.use_spatial_attention:
            features_final = self.spatial_attention(features_final)

        # Apply dropout
        features_final = self.dropout(features_final)

        # Generate outputs
        mask_logits = self.mask_head(features_final)
        flow_vectors = self.flow_head(features_final)

        # Upsample to match input size
        target_size = x.shape[2:]
        mask_logits = F.interpolate(mask_logits, size=target_size, mode='bilinear', align_corners=False)
        flow_vectors = F.interpolate(flow_vectors, size=target_size, mode='bilinear', align_corners=False)

        return {
            'mask': mask_logits,
            'flows': flow_vectors,
        }


def load_condensatepose_model(
    checkpoint_path: str,
    device: str = 'cuda',
) -> CondensatePoseModel:
    """
    Load a trained CondensatePose model from checkpoint.

    The checkpoint may be either:
      - a dict with ``'model_state_dict'`` and an optional ``'model_config'``
        dict of constructor arguments, or
      - a bare ``state_dict`` (the default architecture is then assumed).

    Args:
        checkpoint_path: Path to model checkpoint (.pth file)
        device: Device to load model on ('cuda' or 'cpu')

    Returns:
        Loaded model in eval mode

    Example:
        >>> model = load_condensatepose_model('model_weights.pth', device='cuda')
        >>>
        >>> # Run inference
        >>> with torch.no_grad():
        >>>     outputs = model(image_tensor)
        >>>     mask_logits = outputs['mask']
        >>>     flows = outputs['flows']
    """
    # SECURITY NOTE(review): depending on the installed torch version,
    # torch.load may unpickle arbitrary objects. Only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = checkpoint.get('model_config', {})

    model = CondensatePoseModel(
        encoder_variant=config.get('encoder_variant', 'rw_s'),
        pyramid_channels=config.get('pyramid_channels', list(_DEFAULT_PYRAMID_CHANNELS)),
        use_spatial_attention=config.get('use_spatial_attention', True),
        spatial_kernel_size=config.get('spatial_kernel_size', 11),
        dropout_rate=config.get('dropout_rate', 0.15),
        # Every backbone weight is overwritten by load_state_dict below, so
        # skip timm's pretrained-weight download entirely.
        pretrained_encoder=False,
    )

    # Fall back to treating the whole checkpoint as a bare state_dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    return model


# Alias for compatibility
CondensateSegmentationNet = CondensatePoseModel