File size: 13,738 Bytes
90038de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
"""
CondensatePose Model Architecture
=================================

EfficientNetV2 encoder with Feature Pyramid Network and Style Modulation
for detecting biomolecular condensates in fluorescence microscopy images.

Architecture Components:
- Encoder: EfficientNetV2 (pretrained, adapted for grayscale)
- Decoder: Multi-scale FPN with style-based feature modulation
- Outputs: Binary mask + flow fields for instance segmentation

Paper: [Add your paper link here]
GitHub: [Add your GitHub link here]
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
from typing import Dict

class SparseAttention(nn.Module):
    """Spatial gating layer driven by channel-pooled statistics.

    The input is summarised along the channel axis by its mean and its maximum;
    the two resulting single-channel maps are stacked, passed through one
    bias-free convolution followed by a sigmoid, and the resulting one-channel
    map is multiplied back onto the input (broadcast over channels).

    Args:
        kernel_size: Side length of the square convolution kernel. Padding is
            ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=11):
        super().__init__()
        half_width = kernel_size // 2
        gate_conv = nn.Conv2d(
            2, 1, kernel_size=kernel_size, padding=half_width, bias=False
        )
        self.attention = nn.Sequential(gate_conv, nn.Sigmoid())

    def forward(self, x):
        """Return ``x`` scaled element-wise by its spatial attention map."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_peak = x.max(dim=1, keepdim=True).values
        pooled = torch.cat((channel_mean, channel_peak), dim=1)
        return x * self.attention(pooled)


class NormProjection(nn.Module):
    """Batch-normalise a feature map, then remix its channels with a 1x1 convolution.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the bias-free 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=False,
        )

    def forward(self, x):
        """Apply batch normalisation followed by the pointwise convolution."""
        normalised = self.bn(x)
        projected = self.conv(normalised)
        return projected


class ConvBlock(nn.Module):
    """Pre-activation 3x3 convolution unit: BatchNorm -> SiLU -> Conv.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the bias-free 3x3
            convolution (padding 1, so spatial size is preserved).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        # In-place activation: it overwrites the tensor produced by ``self.bn``.
        self.swish = nn.SiLU(inplace=True)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Normalise, activate and convolve ``x``."""
        activated = self.swish(self.bn(x))
        return self.conv(activated)


class StyleModulatedConv(nn.Module):
    """Convolution block whose output is shifted by a style-dependent channel bias.

    A linear layer maps the style vector to one value per channel; that value is
    broadcast over the two trailing (spatial) axes and added to the output of
    the convolution block.

    Args:
        channels: Channel count of both the input and the output feature map.
        style_dim: Length of the style vector fed to the linear projection.
    """

    def __init__(self, channels, style_dim):
        super().__init__()
        self.conv_block = ConvBlock(channels, channels)
        self.style_projection = nn.Linear(style_dim, channels)

    def forward(self, x, style_vector):
        """Return ``conv_block(x)`` plus the projected style bias."""
        convolved = self.conv_block(x)
        per_channel_shift = self.style_projection(style_vector)
        # Append two singleton axes so the shift broadcasts over height and width.
        return convolved + per_channel_shift[..., None, None]


class DualResidualBlock(nn.Module):
    """Style-conditioned fusion block built from two residual stages.

    Stage one convolves ``x``, adds the lateral input, runs a style-modulated
    convolution and adds a normalised 1x1 projection of ``x`` as the skip path.
    Stage two runs two further style-modulated convolutions and adds the
    stage-one result as the skip path.

    Args:
        channels: Channel count shared by every tensor entering or leaving the block.
        style_dim: Length of the style vector consumed by the modulated convolutions.
    """

    def __init__(self, channels, style_dim):
        super().__init__()
        # Registered as style_conv1..style_conv3, in that order, before the
        # projection and the initial convolution.
        for index in (1, 2, 3):
            setattr(self, f"style_conv{index}", StyleModulatedConv(channels, style_dim))
        self.projection = NormProjection(channels, channels)
        self.initial_conv = ConvBlock(channels, channels)

    def forward(self, x, lateral, style_vector):
        """Fuse ``x`` with ``lateral`` under the given style vector."""
        # Stage 1: merge the lateral input, modulate, add the projected skip.
        fused = self.initial_conv(x) + lateral
        stage_one = self.style_conv1(fused, style_vector) + self.projection(x)

        # Stage 2: two modulated convolutions with a skip from stage one.
        hidden = self.style_conv2(stage_one, style_vector)
        return self.style_conv3(hidden, style_vector) + stage_one


class MultiScaleEncoder(nn.Module):
    """EfficientNetV2-based encoder for multi-scale feature extraction.

    A timm ``efficientnetv2_<variant>`` backbone is created in ``features_only``
    mode for single-channel input, returning the feature maps at output indices
    0-3. Each feature map is then passed through its own adapter (bias-free 1x1
    convolution followed by batch normalisation) that maps the backbone channel
    count to the requested pyramid channel count.

    Args:
        variant: Suffix of the timm model name (``efficientnetv2_<variant>``).
        pyramid_channels: Target channel count for each of the four returned
            feature levels, ordered from the first to the last output index.
        pretrained: Forwarded to ``timm.create_model``. Pass ``False`` to skip
            loading pretrained backbone weights, e.g. when the weights are
            about to be overwritten by a checkpoint.
    """

    def __init__(self, variant='rw_s', pyramid_channels=(24, 48, 64, 160), pretrained=True):
        super().__init__()

        self.base_encoder = create_model(
            f"efficientnetv2_{variant}",
            features_only=True,
            pretrained=pretrained,
            in_chans=1,
            out_indices=[0, 1, 2, 3],
        )

        # Channel counts actually produced by the backbone at the selected indices.
        enc_channels = self.base_encoder.feature_info.channels()

        self.channel_adapters = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(enc_ch, target_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(target_ch),
            )
            for enc_ch, target_ch in zip(enc_channels[:4], pyramid_channels)
        ])

    def forward(self, x):
        """Return the list of channel-adapted backbone feature maps for ``x``."""
        features = self.base_encoder(x)[:4]
        return [
            adapter(feat)
            for feat, adapter in zip(features, self.channel_adapters)
        ]


class CondensatePoseModel(nn.Module):
    """
    CondensatePose: Multi-scale segmentation model for biomolecular condensates.

    Architecture:
        - Encoder: EfficientNetV2 with channel adaptation
        - Decoder: Feature Pyramid Network with:
            * Style-based feature modulation
            * Dual residual blocks for feature fusion
            * Multi-scale upsampling refinement
        - Attention: Spatial attention for sparse object detection
        - Outputs: Binary mask logits + flow field vectors

    Args:
        encoder_variant (str): EfficientNetV2 variant (default: 'rw_s')
        pyramid_channels (sequence of int): Channel dimensions for the four
            pyramid levels; the last entry is also the style-vector length
        use_spatial_attention (bool): Enable spatial attention module
        spatial_kernel_size (int): Kernel size for spatial attention
        dropout_rate (float): Dropout rate in decoder

    Notes:
        - Custom weight initialisation is applied to every layer except those
          inside the backbone (``encoder.base_encoder``), so backbone weights
          loaded by the encoder are not overwritten.
        - The three ``upsample_blocks`` are shared between the refinement
          branches of the different pyramid levels.
    """

    def __init__(
        self,
        encoder_variant='rw_s',
        pyramid_channels=(24, 48, 64, 160),
        use_spatial_attention=True,
        spatial_kernel_size=11,
        dropout_rate=0.15
    ):
        super().__init__()

        # Keep a private list copy so the instance never aliases the caller's
        # (or the default) sequence.
        pyramid_channels = list(pyramid_channels)

        self.use_spatial_attention = use_spatial_attention
        self.pyramid_channels = pyramid_channels

        # Multi-scale encoder
        self.encoder = MultiScaleEncoder(
            variant=encoder_variant,
            pyramid_channels=pyramid_channels
        )

        # The style vector is pooled from the deepest feature map, so its length
        # equals that level's channel count (160 with the default channels).
        style_dim = pyramid_channels[-1]
        pyramid_dim = 32

        # Pyramid processing blocks
        self.pyramid_block4 = DualResidualBlock(pyramid_dim, style_dim)
        self.pyramid_block3 = DualResidualBlock(pyramid_dim, style_dim)
        self.pyramid_block2 = DualResidualBlock(pyramid_dim, style_dim)
        self.pyramid_block1 = DualResidualBlock(pyramid_dim, style_dim)

        # Channel reduction for all levels
        self.lateral_conv4 = nn.Conv2d(pyramid_channels[3], pyramid_dim, kernel_size=1, bias=False)
        self.lateral_conv3 = nn.Conv2d(pyramid_channels[2], pyramid_dim, kernel_size=1, bias=False)
        self.lateral_conv2 = nn.Conv2d(pyramid_channels[1], pyramid_dim, kernel_size=1, bias=False)
        self.lateral_conv1 = nn.Conv2d(pyramid_channels[0], pyramid_dim, kernel_size=1, bias=False)

        # Upsampling refinement blocks
        self.upsample_blocks = nn.ModuleList([
            DualResidualBlock(pyramid_dim, style_dim) for _ in range(3)
        ])

        # Final feature projection
        self.output_projection = NormProjection(pyramid_dim, pyramid_dim)

        # Spatial attention
        if use_spatial_attention:
            self.spatial_attention = SparseAttention(spatial_kernel_size)

        # Dropout
        self.dropout = nn.Dropout2d(dropout_rate)

        # Output heads
        self.mask_head = nn.Conv2d(pyramid_dim, 1, kernel_size=1)  # Binary segmentation
        self.flow_head = nn.Conv2d(pyramid_dim, 2, kernel_size=1)  # Flow vectors

        self._init_weights()

    def _init_weights(self):
        """Initialize the non-backbone weights for sparse condensate segmentation.

        Conv2d layers get Kaiming-normal weights, BatchNorm2d layers are reset
        to weight 1 / bias 0 and Linear layers get Xavier-normal weights. The
        mask head is then re-initialised with a small gain and a negative bias,
        and the flow head is zeroed.

        Modules that belong to ``self.encoder.base_encoder`` are skipped:
        ``self.modules()`` recurses into the backbone, and re-initialising it
        here would overwrite whatever weights the encoder loaded.
        """
        backbone_module_ids = {
            id(module) for module in self.encoder.base_encoder.modules()
        }

        for m in self.modules():
            if id(m) in backbone_module_ids:
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Mask head initialization
        nn.init.xavier_uniform_(self.mask_head.weight, gain=0.1)
        if self.mask_head.bias is not None:
            nn.init.constant_(self.mask_head.bias, -2.0)  # Bias toward background

        # Flow head: initialize to zero
        nn.init.zeros_(self.flow_head.weight)
        if self.flow_head.bias is not None:
            nn.init.zeros_(self.flow_head.bias)

    def upsample_and_refine(self, x, style_vector, block, size=None):
        """Upsample features and apply a refinement block.

        Args:
            x: Feature map to upsample (nearest-neighbour interpolation).
            style_vector: Style vector forwarded to ``block``.
            block: Callable invoked as ``block(upsampled, zeros, style_vector)``;
                the lateral input is an all-zero tensor shaped like the
                upsampled features.
            size: Optional target spatial size. When ``None`` the features are
                upsampled by a fixed factor of 2.

        Returns:
            The output of ``block``.
        """
        if size is None:
            x_up = F.interpolate(x, scale_factor=2, mode='nearest')
        else:
            x_up = F.interpolate(x, size=size, mode='nearest')
        return block(x_up, torch.zeros_like(x_up), style_vector)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input tensor of shape (B, 1, H, W) - grayscale microscopy image

        Returns:
            Dictionary with keys:
                - 'mask': Shape (B, 1, H, W) - binary mask logits
                - 'flows': Shape (B, 2, H, W) - flow field vectors (dy, dx)

        Raises:
            ValueError: If ``x`` contains NaN or Inf values.
        """
        # Input validation (a single pass covers both NaN and +/-Inf)
        if not torch.isfinite(x).all():
            raise ValueError("Input contains NaN or Inf values")

        # Normalize each image/channel to zero mean and unit variance
        x_mean = x.mean(dim=[2, 3], keepdim=True)
        x_std = x.std(dim=[2, 3], keepdim=True) + 1e-8
        x_normalized = (x - x_mean) / x_std

        # Extract multi-scale features
        features = self.encoder(x_normalized)
        C1, C2, C3, C4 = features

        # Spatial sizes of the three finer pyramid levels. Using the real sizes
        # (instead of assuming exact doubling) keeps every branch aligned when
        # the feature sizes are not exact multiples of one another.
        size1, size2, size3 = C1.shape[2:], C2.shape[2:], C3.shape[2:]

        # Compute global style vector
        style_vector = F.adaptive_avg_pool2d(C4, 1).squeeze(-1).squeeze(-1)
        style_vector = F.normalize(style_vector, p=2, dim=1)

        # Build feature pyramid (top-down)
        C4_reduced = self.lateral_conv4(C4)
        P4 = self.pyramid_block4(C4_reduced, C4_reduced, style_vector)

        P4_up = F.interpolate(P4, size=size3, mode='nearest')
        C3_reduced = self.lateral_conv3(C3)
        P3 = self.pyramid_block3(P4_up, C3_reduced, style_vector)

        P3_up = F.interpolate(P3, size=size2, mode='nearest')
        C2_reduced = self.lateral_conv2(C2)
        P2 = self.pyramid_block2(P3_up, C2_reduced, style_vector)

        P2_up = F.interpolate(P2, size=size1, mode='nearest')
        C1_reduced = self.lateral_conv1(C1)
        P1 = self.pyramid_block1(P2_up, C1_reduced, style_vector)

        # Multi-scale upsampling refinement: bring every level to P1's size,
        # one pyramid step at a time.
        P4_refined = self.upsample_and_refine(P4, style_vector, self.upsample_blocks[0], size=size3)
        P4_refined = self.upsample_and_refine(P4_refined, style_vector, self.upsample_blocks[1], size=size2)
        P4_refined = self.upsample_and_refine(P4_refined, style_vector, self.upsample_blocks[2], size=size1)

        P3_refined = self.upsample_and_refine(P3, style_vector, self.upsample_blocks[0], size=size2)
        P3_refined = self.upsample_and_refine(P3_refined, style_vector, self.upsample_blocks[1], size=size1)

        P2_refined = self.upsample_and_refine(P2, style_vector, self.upsample_blocks[0], size=size1)

        # Combine all scales
        combined = P1 + P2_refined + P3_refined + P4_refined
        features_final = self.output_projection(combined)

        # Apply spatial attention
        if self.use_spatial_attention:
            features_final = self.spatial_attention(features_final)

        # Apply dropout
        features_final = self.dropout(features_final)

        # Generate outputs
        mask_logits = self.mask_head(features_final)
        flow_vectors = self.flow_head(features_final)

        # Upsample to match input size
        target_size = x.shape[2:]
        mask_logits = F.interpolate(mask_logits, size=target_size, mode='bilinear', align_corners=False)
        flow_vectors = F.interpolate(flow_vectors, size=target_size, mode='bilinear', align_corners=False)

        return {
            'mask': mask_logits,
            'flows': flow_vectors
        }


def load_condensatepose_model(
    checkpoint_path: str,
    device: str = 'cuda'
) -> CondensatePoseModel:
    """
    Load a trained CondensatePose model from checkpoint.

    The checkpoint may be either a dictionary holding the weights under
    ``'model_state_dict'`` (optionally with constructor arguments under
    ``'model_config'``), or a bare state dict. Missing configuration entries
    fall back to the constructor defaults used below.

    Args:
        checkpoint_path: Path to model checkpoint (.pth file)
        device: Device to load model on ('cuda' or 'cpu')

    Returns:
        Loaded model, moved to ``device`` and set to eval mode

    Example:
        >>> model = load_condensatepose_model('model_weights.pth', device='cuda')
        >>> # Run inference (the model is already in eval mode)
        >>> with torch.no_grad():
        ...     outputs = model(image_tensor)
        >>> mask_logits = outputs['mask']
        >>> flows = outputs['flows']
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # `or {}` also covers checkpoints that stored `model_config: None`.
    config = checkpoint.get('model_config') or {}

    # Accept both wrapped checkpoints and bare state dicts.
    state_dict = checkpoint.get('model_state_dict', checkpoint)

    model = CondensatePoseModel(
        encoder_variant=config.get('encoder_variant', 'rw_s'),
        pyramid_channels=config.get('pyramid_channels', [24, 48, 64, 160]),
        use_spatial_attention=config.get('use_spatial_attention', True),
        spatial_kernel_size=config.get('spatial_kernel_size', 11),
        dropout_rate=config.get('dropout_rate', 0.15),
    )

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    return model


# Compatibility alias: `CondensateSegmentationNet` is bound to the very same
# class object as `CondensatePoseModel`, so either name can be used to
# construct or type-check the model.
CondensateSegmentationNet = CondensatePoseModel