"""PyTorch Sybil model for lung cancer risk prediction"""
import json
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torchvision
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

try:
    from .configuration_sybil import SybilConfig
except ImportError:
    from configuration_sybil import SybilConfig


@dataclass
class SybilOutput(BaseModelOutput):
    """
    Base class for Sybil model outputs.

    Args:
        risk_scores (`torch.FloatTensor` of shape `(batch_size, max_followup)`):
            Sigmoid of the cumulative hazard, i.e. predicted risk for each
            year up to max_followup.
        image_attention (`torch.FloatTensor` of shape `(batch_size, T, W * H)`, *optional*):
            Log-softmax attention scores over the spatial positions of each
            depth position of the encoder feature map (not raw pixels).
            T is the encoder's depth dimension, which may be smaller than the
            number of input slices.
        volume_attention (`torch.FloatTensor` of shape `(batch_size, T)`, *optional*):
            Log-softmax attention scores over the encoder's depth positions.
        hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`, *optional*):
            Pooled features after ReLU and dropout (the input to the risk layer).
    """

    risk_scores: Optional[torch.FloatTensor] = None
    image_attention: Optional[torch.FloatTensor] = None
    volume_attention: Optional[torch.FloatTensor] = None
    hidden_states: Optional[torch.FloatTensor] = None


class CumulativeProbabilityLayer(nn.Module):
    """
    Cumulative probability layer for survival prediction.

    Matches the original Sybil implementation with:
      - hazard_fc: year-specific hazards (can be zero after ReLU)
      - base_hazard_fc: base hazard shared across all years
      - triangular masking so output year j sums the hazards of years 0..j
    """

    def __init__(self, hidden_dim: int, max_followup: int = 6):
        super().__init__()
        self.max_followup = max_followup

        # Year-specific hazards
        self.hazard_fc = nn.Linear(hidden_dim, max_followup)
        # Base hazard (shared across years)
        self.base_hazard_fc = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU(inplace=True)

        # mask[i, j] == 1 iff i <= j (transpose of a lower-triangular matrix).
        # Kept as a frozen Parameter (not a buffer) so checkpoint keys match
        # the original implementation.
        mask = torch.ones([max_followup, max_followup])
        mask = torch.tril(mask, diagonal=0)
        mask = torch.nn.Parameter(torch.t(mask), requires_grad=False)
        self.register_parameter("upper_triangular_mask", mask)

    def hazards(self, x):
        """Compute non-negative per-year hazards using ReLU."""
        raw_hazard = self.hazard_fc(x)
        pos_hazard = self.relu(raw_hazard)
        return pos_hazard

    def forward(self, x):
        """
        Compute cumulative hazards matching original Sybil.

        Args:
            x: Hidden features [B, hidden_dim]

        Returns:
            Cumulative hazard logits [B, max_followup]
        """
        hazards = self.hazards(x)
        B, T = hazards.size()

        # [B, T] -> [B, T, T]; entry [b, i, j] = hazards[b, i]
        expanded_hazards = hazards.unsqueeze(-1).expand(B, T, T)
        # Zero out hazards of years i > j
        masked_hazards = expanded_hazards * self.upper_triangular_mask

        # Base hazard (shared across years), broadcast from [B, 1]
        base_hazard = self.base_hazard_fc(x)

        # Summing over i gives cum[b, j] = sum_{i <= j} hazards[b, i]
        cum_prob = torch.sum(masked_hazards, dim=1) + base_hazard
        return cum_prob


class GlobalMaxPool(nn.Module):
    """Pool to obtain the maximum value for each channel"""

    def __init__(self):
        super(GlobalMaxPool, self).__init__()

    def forward(self, x):
        """
        Args:
            - x: tensor of shape (B, C, T, W, H)
        Returns:
            - output: dict. output['hidden'] is (B, C)
        """
        spatially_flat_size = (*x.size()[:2], -1)
        x = x.view(spatially_flat_size)
        hidden, _ = torch.max(x, dim=-1)
        return {'hidden': hidden}


class PerFrameMaxPool(nn.Module):
    """Pool to obtain the maximum value for each slice in 3D input"""

    def __init__(self):
        super(PerFrameMaxPool, self).__init__()

    def forward(self, x):
        """
        Args:
            - x: tensor of shape (B, C, T, W, H)
        Returns:
            - output: dict.
                + output['multi_image_hidden'] is (B, C, T)
        """
        assert len(x.shape) == 5
        output = {}
        spatially_flat_size = (*x.size()[:3], -1)
        x = x.view(spatially_flat_size)
        output['multi_image_hidden'], _ = torch.max(x, dim=-1)
        return output


class Simple_AttentionPool(nn.Module):
    """Pool to learn an attention over the slices"""

    def __init__(self, **kwargs):
        super(Simple_AttentionPool, self).__init__()
        self.attention_fc = nn.Linear(kwargs['num_chan'], 1)
        self.softmax = nn.Softmax(dim=-1)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        Args:
            - x: tensor of shape (B, C, N)
        Returns:
            - output: dict
                + output['volume_attention']: tensor (B, N), log-softmax scores
                + output['hidden']: tensor (B, C)
        """
        output = {}
        B = x.shape[0]
        spatially_flat_size = (*x.size()[:2], -1)  # B, C, N
        x = x.view(spatially_flat_size)
        attention_scores = self.attention_fc(x.transpose(1, 2))  # B, N, 1

        output['volume_attention'] = self.logsoftmax(attention_scores.transpose(1, 2)).view(B, -1)
        attention_scores = self.softmax(attention_scores.transpose(1, 2))  # B, 1, N

        x = x * attention_scores  # B, C, N
        output['hidden'] = torch.sum(x, dim=-1)
        return output


class Simple_AttentionPool_MultiImg(nn.Module):
    """Pool to learn an attention over the spatial positions of each slice"""

    def __init__(self, **kwargs):
        super(Simple_AttentionPool_MultiImg, self).__init__()
        self.attention_fc = nn.Linear(kwargs['num_chan'], 1)
        self.softmax = nn.Softmax(dim=-1)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        Args:
            - x: tensor of shape (B, C, T, W, H)
        Returns:
            - output: dict
                + output['image_attention']: tensor (B, T, W*H), log-softmax scores
                + output['multi_image_hidden']: tensor (B, C, T)
                + output['hidden']: tensor (B, T*C)
        """
        output = {}
        B, C, T, W, H = x.size()
        # Treat every depth position as an independent image: (B*T, C, W*H)
        x = x.permute([0, 2, 1, 3, 4])
        x = x.contiguous().view(B * T, C, W * H)
        attention_scores = self.attention_fc(x.transpose(1, 2))  # BT, WH, 1

        output['image_attention'] = self.logsoftmax(attention_scores.transpose(1, 2)).view(B, T, -1)
        attention_scores = self.softmax(attention_scores.transpose(1, 2))  # BT, 1, WH

        x = x * attention_scores  # BT, C, WH
        x = torch.sum(x, dim=-1)  # BT, C
        output['multi_image_hidden'] = x.view(B, T, C).permute([0, 2, 1]).contiguous()
        output['hidden'] = x.view(B, T * C)
        return output


class Conv1d_AttnPool(nn.Module):
    """Pool to learn an attention over the slices after convolution"""

    def __init__(self, **kwargs):
        super(Conv1d_AttnPool, self).__init__()
        self.conv1d = nn.Conv1d(
            kwargs['num_chan'],
            kwargs['num_chan'],
            kernel_size=kwargs['conv_pool_kernel_size'],
            stride=kwargs['stride'],
            padding=kwargs['conv_pool_kernel_size'] // 2,
            bias=False,
        )
        self.aggregate = Simple_AttentionPool(**kwargs)

    def forward(self, x):
        """
        Args:
            - x: tensor of shape (B, C, N)
        Returns:
            - output: dict (from Simple_AttentionPool)
                + output['volume_attention']: tensor (B, N'), log-softmax scores
                + output['hidden']: tensor (B, C)
            N' == N for stride 1 and an odd kernel size.
        """
        x = self.conv1d(x)  # B, C, N'
        return self.aggregate(x)


class MultiAttentionPool(nn.Module):
    """Multi-attention pooling layer for CT scan aggregation - matches original Sybil architecture"""

    def __init__(self, channels: int = 512):
        """
        Args:
            channels: Number of feature channels C produced by the encoder.
                The default (512) matches the R3D-18 backbone and the
                original Sybil checkpoints.
        """
        super().__init__()

        params = {
            'num_chan': channels,
            'conv_pool_kernel_size': 11,
            'stride': 1,
        }

        # Define all pooling sub-modules matching original Sybil
        self.image_pool1 = Simple_AttentionPool_MultiImg(**params)
        self.volume_pool1 = Simple_AttentionPool(**params)

        self.image_pool2 = PerFrameMaxPool()
        self.volume_pool2 = Conv1d_AttnPool(**params)

        self.global_max_pool = GlobalMaxPool()

        # Final linear layers to combine features
        self.multi_img_hidden_fc = nn.Linear(2 * channels, channels)
        self.hidden_fc = nn.Linear(3 * channels, channels)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H) where
                - B: batch size
                - C: channels
                - T: depth dimension of the encoder feature map
                - W, H: spatial dimensions

        Returns:
            output: dict with keys:
                - 'hidden': (B, C) - final aggregated features
                - 'image_attention_1': (B, T, W*H) - log-softmax image attention
                - 'volume_attention_1': (B, T) - log-softmax volume attention
                - 'volume_attention_2': (B, T) - log-softmax volume attention
                  after the conv pooling path
                - 'multi_image_hidden_1', 'multi_image_hidden_2': (B, C, T)
                - 'hidden_1', 'hidden_2': (B, C) - volume pool features
                  ('hidden_1' from image_pool1 is overwritten by volume_pool1)
                - 'multi_image_hidden': (B, C, T) - fused per-slice features
                - 'maxpool_hidden': (B, C) - global max pooled features
            The max-pool path produces no image attention, so there is no
            'image_attention_2' key.
        """
        output = {}

        # First attention pooling pathway
        image_pool_out1 = self.image_pool1(x)
        # Keys: "multi_image_hidden" (B, C, T), "image_attention" (B, T, W*H), "hidden" (B, T*C)
        volume_pool_out1 = self.volume_pool1(image_pool_out1['multi_image_hidden'])
        # Keys: "hidden" (B, C), "volume_attention" (B, T)

        # Second max pooling pathway
        image_pool_out2 = self.image_pool2(x)
        # Keys: "multi_image_hidden" (B, C, T)
        volume_pool_out2 = self.volume_pool2(image_pool_out2['multi_image_hidden'])
        # Keys: "hidden" (B, C), "volume_attention" (B, T)

        # Collect all pooling outputs with numbered suffixes. Later entries win
        # on key collisions (volume_pool_out1['hidden'] replaces image_pool_out1's).
        for pool_out, num in [(image_pool_out1, 1), (volume_pool_out1, 1),
                              (image_pool_out2, 2), (volume_pool_out2, 2)]:
            for key, val in pool_out.items():
                output['{}_{}'.format(key, num)] = val

        # Global max pooling
        maxpool_out = self.global_max_pool(x)
        output['maxpool_hidden'] = maxpool_out['hidden']

        # Combine multi-image features from both pathways along the channel axis
        multi_image_hidden = torch.cat(
            [image_pool_out1['multi_image_hidden'], image_pool_out2['multi_image_hidden']],
            dim=-2,
        )  # (B, 2*C, T)
        output['multi_image_hidden'] = self.multi_img_hidden_fc(
            multi_image_hidden.permute([0, 2, 1]).contiguous()
        ).permute([0, 2, 1]).contiguous()  # (B, C, T)

        # Combine all volume-level features
        hidden = torch.cat(
            [volume_pool_out1['hidden'], volume_pool_out2['hidden'], output['maxpool_hidden']],
            dim=-1,
        )  # (B, 3*C)
        output['hidden'] = self.hidden_fc(hidden)  # (B, C)

        return output


class SybilPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading
    and loading pretrained models.
    """

    config_class = SybilConfig
    base_model_prefix = "sybil"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize Linear (normal) and Conv3d (Kaiming) weights; zero their biases."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Conv3d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                module.bias.data.zero_()


class SybilForRiskPrediction(SybilPreTrainedModel):
    """
    Sybil model for lung cancer risk prediction from CT scans.

    This model takes 3D CT scan volumes as input and predicts cancer risk scores
    for each year up to `config.max_followup`.
    """

    def __init__(self, config: SybilConfig):
        super().__init__(config)
        self.config = config

        # Use pretrained R3D-18 as backbone (drop its final pool and fc layers)
        encoder = torchvision.models.video.r3d_18(pretrained=True)
        self.image_encoder = nn.Sequential(*list(encoder.children())[:-2])

        # Multi-attention pooling
        self.pool = MultiAttentionPool(channels=512)

        # Classification layers
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=config.dropout)

        # Risk prediction layer; assumes config.hidden_dim matches the pool
        # output width (512) — TODO confirm against SybilConfig defaults.
        self.prob_of_failure_layer = CumulativeProbabilityLayer(
            config.hidden_dim,
            max_followup=config.max_followup,
        )

        # Calibrator for ensemble predictions
        self.calibrator = None
        if config.calibrator_data:
            self.set_calibrator(config.calibrator_data)

        # NOTE(review): post_init() applies _init_weights, which re-initializes
        # every nn.Conv3d — presumably including the pretrained R3D-18 backbone
        # loaded above. Confirm whether the pretrained weights are meant to survive.
        self.post_init()

    def set_calibrator(self, calibrator_data: Dict):
        """Set calibration data for risk score adjustment"""
        self.calibrator = calibrator_data

    def _calibrate_scores(self, scores: torch.Tensor) -> torch.Tensor:
        """Apply per-year calibration to raw risk scores (identity without a calibrator)."""
        if self.calibrator is None:
            return scores

        # Calibration runs in numpy, so the result is detached from autograd.
        scores_np = scores.detach().cpu().numpy()
        calibrated = np.zeros_like(scores_np)

        # Apply calibration for each year; years without an entry pass through.
        for year in range(scores_np.shape[1]):
            year_key = f"Year{year + 1}"
            if year_key in self.calibrator:
                calibrated[:, year] = self._apply_calibration(
                    scores_np[:, year],
                    self.calibrator[year_key],
                )
            else:
                calibrated[:, year] = scores_np[:, year]

        return torch.from_numpy(calibrated).to(scores.device)

    def _apply_calibration(self, scores: np.ndarray, calibrator_params: Dict) -> np.ndarray:
        """Apply specific calibration transformation"""
        # Simplified calibration - in practice, this would use the full calibration model
        # from the original Sybil implementation
        return scores  # Placeholder for now

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: bool = False,
        return_dict: bool = True,
    ) -> SybilOutput:
        """
        Forward pass of the Sybil model.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, depth, height, width)`):
                Pixel values of CT scan volumes.
            return_attentions (`bool`, *optional*, defaults to `False`):
                Whether to return attention scores (and the pooled hidden features).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `SybilOutput` instead of a plain tuple.

        Returns:
            `SybilOutput`, or a tuple `(risk_scores,)` extended with
            `(image_attention, volume_attention)` when `return_attentions` is set.
        """
        # Extract features using 3D CNN backbone
        features = self.image_encoder(pixel_values)

        # Apply multi-attention pooling
        pool_output = self.pool(features)

        # Apply ReLU and dropout
        hidden = self.relu(pool_output['hidden'])
        hidden = self.dropout(hidden)

        # Predict risk scores
        risk_logits = self.prob_of_failure_layer(hidden)
        risk_scores = torch.sigmoid(risk_logits)

        # Apply calibration if available
        risk_scores = self._calibrate_scores(risk_scores)

        if not return_dict:
            outputs = (risk_scores,)
            if return_attentions:
                outputs = outputs + (pool_output.get('image_attention_1'),
                                     pool_output.get('volume_attention_1'))
            return outputs

        return SybilOutput(
            risk_scores=risk_scores,
            image_attention=pool_output.get('image_attention_1') if return_attentions else None,
            volume_attention=pool_output.get('volume_attention_1') if return_attentions else None,
            hidden_states=hidden if return_attentions else None,
        )

    @classmethod
    def from_pretrained_ensemble(
        cls,
        pretrained_model_name_or_path,
        checkpoint_paths: List[str],
        calibrator_path: Optional[str] = None,
        **kwargs
    ) -> "SybilEnsemble":
        """
        Load an ensemble of Sybil models from checkpoints.

        Args:
            pretrained_model_name_or_path: Path to the pretrained model or model identifier.
            checkpoint_paths: List of paths to individual model checkpoints. Each
                checkpoint may be a dict holding a 'state_dict' entry or a bare
                state dict; a leading 'model.' key prefix is stripped.
            calibrator_path: Path to calibration data (JSON).
            **kwargs: Additional keyword arguments; only `config` is used.

        Returns:
            SybilEnsemble: An ensemble of Sybil models.

        Raises:
            ValueError: If `checkpoint_paths` is empty.
        """
        if not checkpoint_paths:
            raise ValueError("checkpoint_paths must contain at least one checkpoint")

        config = kwargs.pop("config", None)
        if config is None:
            config = SybilConfig.from_pretrained(pretrained_model_name_or_path)

        # Load calibrator if provided
        calibrator_data = None
        if calibrator_path:
            with open(calibrator_path, 'r', encoding='utf-8') as f:
                calibrator_data = json.load(f)
            config.calibrator_data = calibrator_data

        prefix = 'model.'
        models = []
        for checkpoint_path in checkpoint_paths:
            model = cls(config)

            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')

            # Lightning-style checkpoints nest weights under 'state_dict';
            # otherwise treat the loaded object as the state dict itself.
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                raw_state_dict = checkpoint['state_dict']
            else:
                raw_state_dict = checkpoint

            # Remove 'model.' prefix from state dict keys if present
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in raw_state_dict.items()
            }

            # Map to new model structure
            mapped_state_dict = model._map_checkpoint_weights(state_dict)
            load_result = model.load_state_dict(mapped_state_dict, strict=False)
            # strict=False would otherwise leave missing weights at their
            # initial values without any signal.
            if load_result.missing_keys:
                warnings.warn(
                    f"Checkpoint {checkpoint_path} is missing "
                    f"{len(load_result.missing_keys)} weight(s), e.g. "
                    f"{load_result.missing_keys[:5]}",
                    stacklevel=2,
                )

            models.append(model)

        return SybilEnsemble(models, config)

    def _map_checkpoint_weights(self, state_dict: Dict) -> Dict:
        """Keep only encoder, pooling and risk-layer weights from an original Sybil checkpoint."""
        kept_prefixes = ('image_encoder', 'pool', 'prob_of_failure_layer')
        return {k: v for k, v in state_dict.items() if k.startswith(kept_prefixes)}


class SybilEnsemble:
    """Ensemble of Sybil models for improved predictions"""

    def __init__(self, models: List[SybilForRiskPrediction], config: SybilConfig):
        self.models = models
        self.config = config
        self.device = None

    def to(self, device):
        """Move all models to device; returns self for chaining."""
        self.device = device
        for model in self.models:
            model.to(device)
        return self

    def eval(self):
        """Set all models to evaluation mode; returns self for chaining."""
        for model in self.models:
            model.eval()
        return self

    def __call__(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: bool = False,
    ) -> SybilOutput:
        """
        Run inference with ensemble averaging (gradients disabled).

        Args:
            pixel_values: Input CT scan volumes.
            return_attentions: Whether to return attention scores.

        Returns:
            SybilOutput with predictions (and attentions) averaged over all models.
        """
        all_risk_scores = []
        all_image_attentions = []
        all_volume_attentions = []

        with torch.no_grad():
            for model in self.models:
                output = model(
                    pixel_values=pixel_values,
                    return_attentions=return_attentions,
                )
                all_risk_scores.append(output.risk_scores)

                if return_attentions:
                    all_image_attentions.append(output.image_attention)
                    all_volume_attentions.append(output.volume_attention)

        # Average predictions
        risk_scores = torch.stack(all_risk_scores).mean(dim=0)

        # Average attentions if requested
        image_attention = None
        volume_attention = None
        if return_attentions:
            image_attention = torch.stack(all_image_attentions).mean(dim=0)
            volume_attention = torch.stack(all_volume_attentions).mean(dim=0)

        return SybilOutput(
            risk_scores=risk_scores,
            image_attention=image_attention,
            volume_attention=volume_attention,
        )