# Hugging Face file-page residue (not part of the module), kept as comments:
# ebasone's picture
# Upload folder using huggingface_hub
# 90038de verified
"""
CondensatePose Model Architecture
=================================
EfficientNetV2 encoder with Feature Pyramid Network and Style Modulation
for detecting biomolecular condensates in fluorescence microscopy images.
Architecture Components:
- Encoder: EfficientNetV2 (pretrained, adapted for grayscale)
- Decoder: Multi-scale FPN with style-based feature modulation
- Outputs: Binary mask + flow fields for instance segmentation
Paper: [Add your paper link here]
GitHub: [Add your GitHub link here]
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
from typing import Dict
class SparseAttention(nn.Module):
    """Spatial attention module for focusing on sparse condensate regions.

    The channel-wise mean and channel-wise max of the input are stacked into
    a 2-channel map, passed through a single bias-free convolution and a
    sigmoid, and the resulting 1-channel map rescales every input channel.

    Args:
        kernel_size (int): Side length of the square attention kernel.
            Must be a positive odd number so that ``padding=kernel_size // 2``
            keeps the attention map the same spatial size as the input.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=11):
        super().__init__()
        # An even kernel with padding k // 2 produces an (H + 1, W + 1) map,
        # which cannot be multiplied with the (H, W) input in forward().
        # Fail here with a clear message instead of a late shape error.
        if kernel_size < 1 or kernel_size % 2 != 1:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled by a spatial attention map broadcast over channels."""
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attention_input = torch.cat([avg_out, max_out], dim=1)
        attention_map = self.attention(attention_input)
        return x * attention_map
class NormProjection(nn.Module):
    """Batch-normalize the input, then project channels with a 1x1 convolution.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the projection.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        # Bias-free pointwise convolution; spatial size is left untouched.
        self.conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        """Apply normalization followed by the pointwise projection."""
        normalized = self.bn(x)
        projected = self.conv(normalized)
        return projected
class ConvBlock(nn.Module):
    """Pre-activation 3x3 convolution: BatchNorm -> SiLU -> Conv.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        # In-place activation operates on the BatchNorm output, not on the
        # caller's tensor.
        self.swish = nn.SiLU(inplace=True)
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)

    def forward(self, x):
        """Normalize, activate, then convolve ``x``."""
        normalized = self.bn(x)
        activated = self.swish(normalized)
        return self.conv(activated)
class StyleModulatedConv(nn.Module):
    """Convolution block whose output is shifted by a style-derived bias.

    Args:
        channels: Channel count of both the input and output feature maps.
        style_dim: Length of the style vector fed to the linear projection.
    """

    def __init__(self, channels, style_dim):
        super().__init__()
        self.conv_block = ConvBlock(channels, channels)
        self.style_projection = nn.Linear(style_dim, channels)

    def forward(self, x, style_vector):
        """Convolve ``x`` and add a per-channel bias projected from ``style_vector``."""
        # Two trailing singleton axes let the per-channel bias broadcast over
        # the spatial dimensions of the convolved features.
        per_channel_bias = self.style_projection(style_vector)[..., None, None]
        convolved = self.conv_block(x)
        return convolved + per_channel_bias
class DualResidualBlock(nn.Module):
    """Two residual stages of style-modulated convolutions for feature fusion.

    Stage one fuses the input with a lateral feature map and adds a projected
    skip of the input; stage two refines that result with two further
    style-modulated convolutions and adds the stage-one output back.

    Args:
        channels: Channel count shared by the input, lateral, and output maps.
        style_dim: Length of the style vector used for modulation.
    """

    def __init__(self, channels, style_dim):
        super().__init__()
        self.style_conv1 = StyleModulatedConv(channels, style_dim)
        self.style_conv2 = StyleModulatedConv(channels, style_dim)
        self.style_conv3 = StyleModulatedConv(channels, style_dim)
        self.projection = NormProjection(channels, channels)
        self.initial_conv = ConvBlock(channels, channels)

    def forward(self, x, lateral, style_vector):
        """Fuse ``x`` with ``lateral`` under style conditioning and return the result."""
        # Stage 1: merge with the lateral map, then residual via projected input.
        fused = self.initial_conv(x) + lateral
        stage_one = self.style_conv1(fused, style_vector)
        stage_one = stage_one + self.projection(x)

        # Stage 2: two modulated convolutions with a skip from stage one.
        stage_two = self.style_conv2(stage_one, style_vector)
        stage_two = self.style_conv3(stage_two, style_vector)
        return stage_two + stage_one
class MultiScaleEncoder(nn.Module):
    """EfficientNetV2-based encoder for multi-scale feature extraction.

    A timm ``efficientnetv2_<variant>`` backbone is created in
    ``features_only`` mode for single-channel input, and each of its first
    four feature maps is passed through a 1x1 conv + BatchNorm adapter that
    maps it to the requested pyramid channel count.

    Args:
        variant (str): Suffix of the timm model name (``efficientnetv2_<variant>``).
        pyramid_channels (sequence of int): Target channel count for each
            pyramid level, shallowest first. Levels are paired with backbone
            feature maps via ``zip``, so extra entries are ignored and fewer
            than four entries yield fewer output levels.
        pretrained (bool): Whether timm should load pretrained backbone
            weights. Defaults to ``True`` (the previous hard-coded behavior);
            pass ``False`` to build the architecture without fetching
            weights, e.g. when a checkpoint will be loaded afterwards.
    """

    def __init__(self, variant='rw_s', pyramid_channels=(24, 48, 64, 160), pretrained=True):
        super().__init__()
        encoder_name = f"efficientnetv2_{variant}"
        self.base_encoder = create_model(
            encoder_name,
            features_only=True,
            pretrained=pretrained,
            in_chans=1,
            out_indices=[0, 1, 2, 3]
        )
        enc_channels = self.base_encoder.feature_info.channels()
        self.channel_adapters = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(enc_ch, target_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(target_ch)
            )
            for enc_ch, target_ch in zip(enc_channels[:4], pyramid_channels)
        ])

    def forward(self, x):
        """Return the list of channel-adapted feature maps, shallowest first."""
        features = self.base_encoder(x)[:4]
        return [
            adapter(feat)
            for feat, adapter in zip(features, self.channel_adapters)
        ]
class CondensatePoseModel(nn.Module):
    """
    CondensatePose: Multi-scale segmentation model for biomolecular condensates.

    Architecture:
        - Encoder: EfficientNetV2 with channel adaptation
        - Decoder: Feature Pyramid Network with:
            * Style-based feature modulation
            * Dual residual blocks for feature fusion
            * Multi-scale upsampling refinement
        - Attention: Spatial attention for sparse object detection
        - Outputs: Binary mask logits + flow field vectors

    Args:
        encoder_variant (str): EfficientNetV2 variant (default: 'rw_s')
        pyramid_channels (sequence of int): Channel dimensions for the four
            pyramid levels, shallowest first
        use_spatial_attention (bool): Enable spatial attention module
        spatial_kernel_size (int): Kernel size for spatial attention
        dropout_rate (float): Dropout rate in decoder
    """

    def __init__(
        self,
        encoder_variant='rw_s',
        pyramid_channels=(24, 48, 64, 160),
        use_spatial_attention=True,
        spatial_kernel_size=11,
        dropout_rate=0.15
    ):
        super().__init__()
        # Copy into a fresh list: the default is immutable and a caller's
        # sequence is never shared with (or mutated through) the model.
        pyramid_channels = list(pyramid_channels)
        self.use_spatial_attention = use_spatial_attention
        self.pyramid_channels = pyramid_channels

        # Multi-scale encoder
        self.encoder = MultiScaleEncoder(
            variant=encoder_variant,
            pyramid_channels=pyramid_channels
        )

        # The style vector is pooled from the deepest level (C4), so its
        # length is that level's channel count (default: 160).
        style_dim = pyramid_channels[3]
        pyramid_dim = 32

        # Pyramid processing blocks
        self.pyramid_block4 = DualResidualBlock(pyramid_dim, style_dim)
        self.pyramid_block3 = DualResidualBlock(pyramid_dim, style_dim)
        self.pyramid_block2 = DualResidualBlock(pyramid_dim, style_dim)
        self.pyramid_block1 = DualResidualBlock(pyramid_dim, style_dim)

        # Channel reduction for all levels
        self.lateral_conv4 = nn.Conv2d(pyramid_channels[3], pyramid_dim, kernel_size=1, bias=False)
        self.lateral_conv3 = nn.Conv2d(pyramid_channels[2], pyramid_dim, kernel_size=1, bias=False)
        self.lateral_conv2 = nn.Conv2d(pyramid_channels[1], pyramid_dim, kernel_size=1, bias=False)
        self.lateral_conv1 = nn.Conv2d(pyramid_channels[0], pyramid_dim, kernel_size=1, bias=False)

        # Upsampling refinement blocks (shared across pyramid levels in forward)
        self.upsample_blocks = nn.ModuleList([
            DualResidualBlock(pyramid_dim, style_dim) for _ in range(3)
        ])

        # Final feature projection
        self.output_projection = NormProjection(pyramid_dim, pyramid_dim)

        # Spatial attention
        if use_spatial_attention:
            self.spatial_attention = SparseAttention(spatial_kernel_size)

        # Dropout
        self.dropout = nn.Dropout2d(dropout_rate)

        # Output heads
        self.mask_head = nn.Conv2d(pyramid_dim, 1, kernel_size=1)  # Binary segmentation
        self.flow_head = nn.Conv2d(pyramid_dim, 2, kernel_size=1)  # Flow vectors

        self._init_weights()

    def _init_weights(self):
        """Initialize the newly created layers for sparse condensate segmentation.

        Modules belonging to the timm backbone are skipped. Re-initializing
        them would overwrite the pretrained weights the encoder was created
        with. The channel adapters, decoder, and heads are still initialized.
        """
        backbone_ids = {id(m) for m in self.encoder.base_encoder.modules()}
        for m in self.modules():
            if id(m) in backbone_ids:
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Mask head initialization
        nn.init.xavier_uniform_(self.mask_head.weight, gain=0.1)
        if self.mask_head.bias is not None:
            nn.init.constant_(self.mask_head.bias, -2.0)  # Bias toward background

        # Flow head: initialize to zero
        nn.init.zeros_(self.flow_head.weight)
        if self.flow_head.bias is not None:
            nn.init.zeros_(self.flow_head.bias)

    def upsample_and_refine(self, x, style_vector, block, size=None):
        """Upsample features and apply a refinement block.

        Args:
            x: Feature map to upsample.
            style_vector: Style vector forwarded to ``block``.
            block: Refinement block called as ``block(x_up, zeros, style_vector)``.
            size: Optional target spatial size. When omitted the map is
                upsampled by a factor of 2 (the original behavior).

        Returns:
            The refined, upsampled feature map.
        """
        if size is None:
            x_up = F.interpolate(x, scale_factor=2, mode='nearest')
        else:
            x_up = F.interpolate(x, size=size, mode='nearest')
        # No lateral input at this stage, so a zero map is passed instead.
        return block(x_up, torch.zeros_like(x_up), style_vector)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, 1, H, W) - grayscale microscopy image

        Returns:
            Dictionary with keys:
                - 'mask': Shape (B, 1, H, W) - binary mask logits
                - 'flows': Shape (B, 2, H, W) - flow field vectors (dy, dx)

        Raises:
            ValueError: If ``x`` contains NaN or Inf values.
        """
        # Input validation
        if torch.isnan(x).any() or torch.isinf(x).any():
            raise ValueError("Input contains NaN or Inf values")

        # Normalize input (per image and channel; epsilon guards constant images)
        x_mean = x.mean(dim=[2, 3], keepdim=True)
        x_std = x.std(dim=[2, 3], keepdim=True) + 1e-8
        x_normalized = (x - x_mean) / x_std

        # Extract multi-scale features
        features = self.encoder(x_normalized)
        C1, C2, C3, C4 = features
        size1, size2, size3 = C1.shape[2:], C2.shape[2:], C3.shape[2:]

        # Compute global style vector
        style_vector = F.adaptive_avg_pool2d(C4, 1).squeeze(-1).squeeze(-1)
        style_vector = F.normalize(style_vector, p=2, dim=1)

        # Build feature pyramid (top-down)
        C4_reduced = self.lateral_conv4(C4)
        P4 = self.pyramid_block4(C4_reduced, C4_reduced, style_vector)

        P4_up = F.interpolate(P4, size=size3, mode='nearest')
        C3_reduced = self.lateral_conv3(C3)
        P3 = self.pyramid_block3(P4_up, C3_reduced, style_vector)

        P3_up = F.interpolate(P3, size=size2, mode='nearest')
        C2_reduced = self.lateral_conv2(C2)
        P2 = self.pyramid_block2(P3_up, C2_reduced, style_vector)

        P2_up = F.interpolate(P2, size=size1, mode='nearest')
        C1_reduced = self.lateral_conv1(C1)
        P1 = self.pyramid_block1(P2_up, C1_reduced, style_vector)

        # Multi-scale upsampling refinement. Each step targets the size of
        # the next-shallower level, as the top-down path above does, so the
        # sum below is valid even when feature maps are not exact doublings
        # of each other. For exact doublings this equals scale_factor=2.
        P4_refined = self.upsample_and_refine(P4, style_vector, self.upsample_blocks[0], size=size3)
        P4_refined = self.upsample_and_refine(P4_refined, style_vector, self.upsample_blocks[1], size=size2)
        P4_refined = self.upsample_and_refine(P4_refined, style_vector, self.upsample_blocks[2], size=size1)

        P3_refined = self.upsample_and_refine(P3, style_vector, self.upsample_blocks[0], size=size2)
        P3_refined = self.upsample_and_refine(P3_refined, style_vector, self.upsample_blocks[1], size=size1)

        P2_refined = self.upsample_and_refine(P2, style_vector, self.upsample_blocks[0], size=size1)

        # Combine all scales
        combined = P1 + P2_refined + P3_refined + P4_refined
        features_final = self.output_projection(combined)

        # Apply spatial attention
        if self.use_spatial_attention:
            features_final = self.spatial_attention(features_final)

        # Apply dropout
        features_final = self.dropout(features_final)

        # Generate outputs
        mask_logits = self.mask_head(features_final)
        flow_vectors = self.flow_head(features_final)

        # Upsample to match input size
        target_size = x.shape[2:]
        mask_logits = F.interpolate(mask_logits, size=target_size, mode='bilinear', align_corners=False)
        flow_vectors = F.interpolate(flow_vectors, size=target_size, mode='bilinear', align_corners=False)

        return {
            'mask': mask_logits,
            'flows': flow_vectors
        }
def load_condensatepose_model(
    checkpoint_path: str,
    device: str = 'cuda'
) -> CondensatePoseModel:
    """
    Load a trained CondensatePose model from checkpoint.

    The checkpoint may be either a dict holding ``'model_state_dict'`` (and
    optionally ``'model_config'``), or a bare state dict. Configuration
    entries that are absent fall back to the constructor defaults used here.

    Args:
        checkpoint_path: Path to model checkpoint (.pth file)
        device: Device to load model on ('cuda' or 'cpu')

    Returns:
        Loaded model in eval mode

    Raises:
        TypeError: If the checkpoint file does not contain a dict.

    Example:
        >>> model = load_condensatepose_model('model_weights.pth', device='cuda')
        >>>
        >>> # Run inference
        >>> with torch.no_grad():
        ...     outputs = model(image_tensor)
        ...     mask_logits = outputs['mask']
        ...     flows = outputs['flows']
    """
    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from sources you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if not isinstance(checkpoint, dict):
        raise TypeError(
            "Checkpoint must be a dict (with 'model_state_dict') or a state "
            f"dict, got {type(checkpoint).__name__}"
        )

    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        # `or {}` also covers a checkpoint that stored model_config=None.
        config = checkpoint.get('model_config') or {}
    else:
        # Bare state dict saved via torch.save(model.state_dict(), path).
        state_dict = checkpoint
        config = {}

    model = CondensatePoseModel(
        encoder_variant=config.get('encoder_variant', 'rw_s'),
        pyramid_channels=config.get('pyramid_channels', [24, 48, 64, 160]),
        use_spatial_attention=config.get('use_spatial_attention', True),
        spatial_kernel_size=config.get('spatial_kernel_size', 11),
        dropout_rate=config.get('dropout_rate', 0.15),
    )
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
# Alias for compatibility: CondensateSegmentationNet is the same class object
# as CondensatePoseModel (presumably an earlier name still used by callers —
# TODO confirm before removing).
CondensateSegmentationNet = CondensatePoseModel