# sybil / modeling_sybil.py
# Uploaded by Aakash-Tripathi ("Upload 20 files", commit cf14762, verified).
"""PyTorch Sybil model for lung cancer risk prediction"""
import torch
import torch.nn as nn
import torchvision
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Dict, List, Tuple
import numpy as np
from dataclasses import dataclass
try:
from .configuration_sybil import SybilConfig
except ImportError:
from configuration_sybil import SybilConfig
@dataclass
class SybilOutput(BaseModelOutput):
    """
    Output container for Sybil risk prediction (single model or ensemble).

    Args:
        risk_scores (`torch.FloatTensor` of shape `(batch_size, max_followup)`):
            Predicted risk for each follow-up year up to `max_followup`
            (sigmoid of the cumulative-hazard logits, after optional calibration).
        image_attention (`torch.FloatTensor` of shape `(batch_size, T, H' * W')`, *optional*):
            Log-softmax attention scores over the flattened spatial positions of
            each depth slice of the backbone feature map (which may be smaller than
            the input volume). Exponentiate to recover attention weights.
        volume_attention (`torch.FloatTensor` of shape `(batch_size, T)`, *optional*):
            Log-softmax attention scores over the depth slices of the backbone
            feature map.
        hidden_states (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`, *optional*):
            Pooled features (after ReLU and dropout) that feed the risk head.
    """

    risk_scores: Optional[torch.FloatTensor] = None
    image_attention: Optional[torch.FloatTensor] = None
    volume_attention: Optional[torch.FloatTensor] = None
    hidden_states: Optional[torch.FloatTensor] = None
class CumulativeProbabilityLayer(nn.Module):
    """
    Survival head turning pooled features into cumulative risk logits.

    For follow-up year ``t`` the output logit is

        logit[t] = base_hazard_fc(x) + sum_{s <= t} relu(hazard_fc(x))[s]

    so, because every per-year hazard is non-negative, logits never decrease
    with ``t``. The submodule / parameter names (``hazard_fc``,
    ``base_hazard_fc``, ``upper_triangular_mask``) are part of the state-dict
    contract used when loading checkpoints.
    """

    def __init__(self, hidden_dim: int, max_followup: int = 6):
        super().__init__()
        self.max_followup = max_followup
        # One hazard increment per follow-up year (clamped at zero in hazards()).
        self.hazard_fc = nn.Linear(hidden_dim, max_followup)
        # Single offset shared by every year.
        self.base_hazard_fc = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU(inplace=True)
        # mask[s, t] == 1 exactly when s <= t. Kept as a frozen Parameter so it is
        # saved in the state dict and moves with the module across devices.
        lower = torch.tril(torch.ones([max_followup, max_followup]), diagonal=0)
        frozen_mask = torch.nn.Parameter(torch.t(lower), requires_grad=False)
        self.register_parameter("upper_triangular_mask", frozen_mask)

    def hazards(self, x):
        """Return the per-year hazards of ``x``, clamped to be non-negative."""
        return self.relu(self.hazard_fc(x))

    def forward(self, x):
        """
        Args:
            x: pooled features of shape [B, hidden_dim].
        Returns:
            Cumulative risk logits of shape [B, max_followup].
        """
        per_year = self.hazards(x)
        batch, horizon = per_year.size()
        # Tile h[b, s] along a new trailing axis, zero the entries with s > t and
        # sum over s: this yields the running total of hazards up to year t.
        tiled = per_year.unsqueeze(-1).expand(batch, horizon, horizon)
        running_total = (tiled * self.upper_triangular_mask).sum(dim=1)
        return running_total + self.base_hazard_fc(x)
class GlobalMaxPool(nn.Module):
    """Global max pool: per-channel maximum over every non-batch, non-channel position."""

    def __init__(self):
        super(GlobalMaxPool, self).__init__()

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, ...), e.g. (B, C, T, W, H).
        Returns:
            dict with ``'hidden'``: tensor of shape (B, C) holding, for each
            channel, the maximum over all remaining positions.
        """
        # reshape (not view): view raises on non-contiguous inputs such as
        # permuted/transposed tensors, while reshape copies only when needed.
        flat = x.reshape(*x.shape[:2], -1)
        hidden, _ = torch.max(flat, dim=-1)
        return {'hidden': hidden}
class PerFrameMaxPool(nn.Module):
    """Max pool over the spatial positions of every slice of a 3D feature map."""

    def __init__(self):
        super(PerFrameMaxPool, self).__init__()

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H).
        Returns:
            dict with ``'multi_image_hidden'``: tensor of shape (B, C, T), the
            maximum over the W*H positions for each channel and slice.
        Raises:
            ValueError: if ``x`` is not 5-dimensional.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if x.dim() != 5:
            raise ValueError(
                "PerFrameMaxPool expects a 5D tensor (B, C, T, W, H), "
                f"got shape {tuple(x.shape)}"
            )
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat = x.reshape(*x.shape[:3], -1)
        hidden, _ = torch.max(flat, dim=-1)
        return {'multi_image_hidden': hidden}
class Simple_AttentionPool(nn.Module):
    """Learned attention pooling over the last (slice) axis of a (B, C, N) tensor."""

    def __init__(self, **kwargs):
        super(Simple_AttentionPool, self).__init__()
        # Scores each slice from its C-dimensional feature vector.
        self.attention_fc = nn.Linear(kwargs['num_chan'], 1)
        self.softmax = nn.Softmax(dim=-1)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, N); any trailing dims are flattened into N.
        Returns:
            dict with
              - 'volume_attention': (B, N) log-softmax attention scores over N
              - 'hidden': (B, C) attention-weighted sum of x over N
        """
        batch_size = x.shape[0]
        features = x.view(*x.size()[:2], -1)  # B, C, N
        # Linear acts on the channel axis, so move it last and back again.
        logits = self.attention_fc(features.transpose(1, 2)).transpose(1, 2)  # B, 1, N
        weights = self.softmax(logits)  # B, 1, N
        return {
            'volume_attention': self.logsoftmax(logits).view(batch_size, -1),
            'hidden': (features * weights).sum(dim=-1),
        }
class Simple_AttentionPool_MultiImg(nn.Module):
    """Learned spatial attention pooling applied independently to every slice of a volume."""

    def __init__(self, **kwargs):
        super(Simple_AttentionPool_MultiImg, self).__init__()
        # Scores each spatial position from its C-dimensional feature vector.
        self.attention_fc = nn.Linear(kwargs['num_chan'], 1)
        self.softmax = nn.Softmax(dim=-1)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H).
        Returns:
            dict with
              - 'image_attention': (B, T, W*H) log-softmax spatial attention per slice
              - 'multi_image_hidden': (B, C, T) attention-pooled feature per slice
              - 'hidden': (B, T*C) the same per-slice features, flattened slice-major
        """
        batch, chans, depth, width, height = x.size()
        # Fold slices into the batch so each is pooled on its own:
        # (B, C, T, W, H) -> (B, T, C, W, H) -> (B*T, C, W*H)
        per_slice = x.permute([0, 2, 1, 3, 4]).contiguous().view(batch * depth, chans, width * height)
        logits = self.attention_fc(per_slice.transpose(1, 2)).transpose(1, 2)  # BT, 1, WH
        weights = self.softmax(logits)  # BT, 1, WH
        pooled = (per_slice * weights).sum(dim=-1)  # BT, C
        return {
            'image_attention': self.logsoftmax(logits).view(batch, depth, -1),
            'multi_image_hidden': pooled.view(batch, depth, chans).permute([0, 2, 1]).contiguous(),
            'hidden': pooled.view(batch, depth * chans),
        }
class Conv1d_AttnPool(nn.Module):
    """Convolve along the slice axis with a 1D conv, then attention-pool over slices."""

    def __init__(self, **kwargs):
        super(Conv1d_AttnPool, self).__init__()
        # kwargs must provide 'num_chan', 'conv_pool_kernel_size' and 'stride'; the
        # same kwargs are forwarded to Simple_AttentionPool, which reads 'num_chan'.
        self.conv1d = nn.Conv1d(
            kwargs['num_chan'],
            kwargs['num_chan'],
            kernel_size=kwargs['conv_pool_kernel_size'],
            stride=kwargs['stride'],
            # kernel//2 padding keeps the length unchanged for an odd kernel at stride 1.
            padding=kwargs['conv_pool_kernel_size'] // 2,
            bias=False,
        )
        self.aggregate = Simple_AttentionPool(**kwargs)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, N), with C == num_chan.
        Returns:
            dict produced by ``Simple_AttentionPool`` on the convolved sequence:
              - 'volume_attention': (B, N') log-softmax attention scores over positions
              - 'hidden': (B, C) attention-weighted features
            where N' is the Conv1d output length (N' == N for an odd kernel and
            stride 1).
        """
        convolved = self.conv1d(x)  # B, C, N'
        return self.aggregate(convolved)
class MultiAttentionPool(nn.Module):
    """
    Multi-pathway pooling that aggregates a 3D CT feature volume into one vector.

    - Pathway 1: spatial attention per slice, then attention across slices.
    - Pathway 2: spatial max per slice, then Conv1d + attention across slices.
    - Plus a global max pool over the whole volume.

    The three (B, C) summaries are concatenated and projected back to C dims.

    Args:
        channels: number of feature channels C of the input volume. Defaults to
            512, the value previously hard-coded here; checkpoints saved with
            512-channel pooling layers require the default.
    """

    def __init__(self, channels: int = 512):
        super().__init__()
        self.channels = channels
        params = {
            'num_chan': channels,
            'conv_pool_kernel_size': 11,
            'stride': 1,
        }
        # Submodule names are part of the state-dict contract; keep them stable.
        self.image_pool1 = Simple_AttentionPool_MultiImg(**params)
        self.volume_pool1 = Simple_AttentionPool(**params)
        self.image_pool2 = PerFrameMaxPool()
        self.volume_pool2 = Conv1d_AttnPool(**params)
        self.global_max_pool = GlobalMaxPool()
        # Fuse the two per-slice feature maps (2C -> C) and the three volume
        # summaries (3C -> C).
        self.multi_img_hidden_fc = nn.Linear(2 * channels, channels)
        self.hidden_fc = nn.Linear(3 * channels, channels)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H) with C == ``channels``
               (T: depth/slices, W, H: spatial dims).
        Returns:
            dict with keys:
              - 'hidden': (B, C) final aggregated features
              - 'image_attention_1': (B, T, W*H) log-softmax spatial attention (pathway 1)
              - 'volume_attention_1': (B, T) log-softmax slice attention (pathway 1)
              - 'volume_attention_2': (B, T) log-softmax slice attention after Conv1d (pathway 2)
              - 'multi_image_hidden_1': (B, C, T) per-slice attention-pooled features
              - 'multi_image_hidden_2': (B, C, T) per-slice max-pooled features
              - 'hidden_1': (B, C) pathway-1 volume features
              - 'hidden_2': (B, C) pathway-2 volume features
              - 'maxpool_hidden': (B, C) global max-pooled features
              - 'multi_image_hidden': (B, C, T) fused per-slice features
            There is no 'image_attention_2' key: the max-pool pathway has no attention.
        """
        output = {}

        # Pathway 1: spatial attention per slice, then attention across slices.
        image_pool_out1 = self.image_pool1(x)
        volume_pool_out1 = self.volume_pool1(image_pool_out1['multi_image_hidden'])

        # Pathway 2: spatial max per slice, then Conv1d + attention across slices.
        image_pool_out2 = self.image_pool2(x)
        volume_pool_out2 = self.volume_pool2(image_pool_out2['multi_image_hidden'])

        # Expose every intermediate with a pathway suffix. Order matters:
        # volume_pool_out1 is written after image_pool_out1, so its (B, C) 'hidden'
        # replaces image_pool1's (B, T*C) 'hidden' under the key 'hidden_1'.
        for pool_out, num in [(image_pool_out1, 1), (volume_pool_out1, 1),
                              (image_pool_out2, 2), (volume_pool_out2, 2)]:
            for key, val in pool_out.items():
                output['{}_{}'.format(key, num)] = val

        maxpool_out = self.global_max_pool(x)
        output['maxpool_hidden'] = maxpool_out['hidden']

        # Stack the two per-slice maps along channels -> (B, 2C, T); project each
        # slice's 2C vector back to C: (B, T, 2C) -> (B, T, C) -> (B, C, T).
        multi_image_hidden = torch.cat(
            [image_pool_out1['multi_image_hidden'], image_pool_out2['multi_image_hidden']],
            dim=-2,
        )
        output['multi_image_hidden'] = self.multi_img_hidden_fc(
            multi_image_hidden.permute([0, 2, 1]).contiguous()
        ).permute([0, 2, 1]).contiguous()

        # Fuse the three volume-level summaries: (B, 3C) -> (B, C).
        hidden = torch.cat(
            [volume_pool_out1['hidden'], volume_pool_out2['hidden'], output['maxpool_hidden']],
            dim=-1,
        )
        output['hidden'] = self.hidden_fc(hidden)
        return output
class SybilPreTrainedModel(PreTrainedModel):
    """
    Abstract base that plugs Sybil models into the Hugging Face
    ``PreTrainedModel`` machinery: config class, weight initialisation,
    and the standard download / save / load helpers.
    """

    config_class = SybilConfig
    base_model_prefix = "sybil"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """
        Initialise a single submodule in place.

        - ``nn.Linear``: weight ~ N(0, config.initializer_range), bias set to 0.
        - ``nn.Conv3d``: Kaiming-normal weight (fan_out, ReLU gain), bias set to 0.

        Any other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Conv3d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        else:
            return
        if module.bias is not None:
            module.bias.data.zero_()
class SybilForRiskPrediction(SybilPreTrainedModel):
    """
    Sybil model for lung cancer risk prediction from CT scans.

    Takes 3D CT scan volumes as input and predicts a cancer risk score for each
    follow-up year up to ``config.max_followup``.
    """

    def __init__(self, config: SybilConfig):
        super().__init__(config)
        self.config = config
        # R3D-18 video backbone with its last two children dropped (for torchvision's
        # VideoResNet these are the avgpool and fc head), leaving a 3D feature map.
        # NOTE(review): `pretrained=True` fetches torchvision weights (may download),
        # and post_init() below applies _init_weights, which may re-initialise these
        # Conv3d weights — confirm this is intended.
        encoder = torchvision.models.video.r3d_18(pretrained=True)
        self.image_encoder = nn.Sequential(*list(encoder.children())[:-2])
        # Multi-attention pooling over the 512-channel feature map.
        self.pool = MultiAttentionPool(channels=512)
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=config.dropout)
        # Risk head. NOTE(review): the pool emits 512-d features, so
        # config.hidden_dim is presumably required to be 512.
        self.prob_of_failure_layer = CumulativeProbabilityLayer(
            config.hidden_dim,
            max_followup=config.max_followup,
        )
        # Optional calibration data applied to the sigmoid scores in forward().
        self.calibrator = None
        if config.calibrator_data:
            self.set_calibrator(config.calibrator_data)
        self.post_init()

    def set_calibrator(self, calibrator_data: Dict):
        """Store calibration data used by ``_calibrate_scores``."""
        self.calibrator = calibrator_data

    def _calibrate_scores(self, scores: torch.Tensor) -> torch.Tensor:
        """
        Apply per-year calibration to risk scores of shape (B, max_followup).

        Column ``i`` is calibrated with ``self.calibrator[f"Year{i + 1}"]`` when
        that key exists and is passed through otherwise. Returns ``scores``
        unchanged when no calibrator is set.

        NOTE: when a calibrator is set, the result is rebuilt from NumPy and is
        therefore detached from the autograd graph.
        """
        if self.calibrator is None:
            return scores
        scores_np = scores.detach().cpu().numpy()
        calibrated = np.zeros_like(scores_np)
        for year in range(scores_np.shape[1]):
            year_key = f"Year{year + 1}"
            if year_key in self.calibrator:
                calibrated[:, year] = self._apply_calibration(
                    scores_np[:, year],
                    self.calibrator[year_key],
                )
            else:
                calibrated[:, year] = scores_np[:, year]
        return torch.from_numpy(calibrated).to(scores.device)

    def _apply_calibration(self, scores: np.ndarray, calibrator_params: Dict) -> np.ndarray:
        """Calibrate one year's scores. Currently an identity placeholder."""
        # TODO: port the full calibration model from the original Sybil implementation.
        return scores

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: bool = False,
        return_dict: bool = True,
    ) -> SybilOutput:
        """
        Forward pass of the Sybil model.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, depth, height, width)`):
                CT scan volumes.
            return_attentions (`bool`, *optional*, defaults to `False`):
                Whether to also return the pathway-1 attention scores (and, with
                `return_dict=True`, the pooled hidden states).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `SybilOutput` instead of a plain tuple.

        Returns:
            `SybilOutput`, or a tuple `(risk_scores,)` /
            `(risk_scores, image_attention, volume_attention)` when
            `return_dict=False`.
        """
        features = self.image_encoder(pixel_values)
        pool_output = self.pool(features)

        hidden = self.relu(pool_output['hidden'])
        hidden = self.dropout(hidden)

        risk_logits = self.prob_of_failure_layer(hidden)
        risk_scores = torch.sigmoid(risk_logits)
        risk_scores = self._calibrate_scores(risk_scores)

        if not return_dict:
            outputs = (risk_scores,)
            if return_attentions:
                outputs = outputs + (pool_output.get('image_attention_1'),
                                     pool_output.get('volume_attention_1'))
            return outputs

        return SybilOutput(
            risk_scores=risk_scores,
            image_attention=pool_output.get('image_attention_1') if return_attentions else None,
            volume_attention=pool_output.get('volume_attention_1') if return_attentions else None,
            hidden_states=hidden if return_attentions else None,
        )

    @classmethod
    def from_pretrained_ensemble(
        cls,
        pretrained_model_name_or_path,
        checkpoint_paths: List[str],
        calibrator_path: Optional[str] = None,
        **kwargs
    ):
        """
        Load an ensemble of Sybil models from checkpoints.

        Args:
            pretrained_model_name_or_path: Path or identifier used to load the
                config when no ``config`` keyword is given.
            checkpoint_paths: Paths to individual model checkpoints. Each may be
                a dict with a ``'state_dict'`` entry or a bare state dict; a
                leading ``'model.'`` prefix on keys is stripped.
            calibrator_path: Optional path to a JSON file with calibration data;
                stored on the config as ``calibrator_data``.
            **kwargs: ``config`` may be supplied here; other keys are ignored.

        Returns:
            SybilEnsemble: an ensemble of the loaded models.

        Warns:
            UserWarning: if a checkpoint does not supply some of the model's
            parameters (they would otherwise silently keep their initial values).
        """
        import json
        import warnings

        config = kwargs.pop("config", None)
        if config is None:
            config = SybilConfig.from_pretrained(pretrained_model_name_or_path)

        if calibrator_path:
            with open(calibrator_path, 'r', encoding='utf-8') as f:
                config.calibrator_data = json.load(f)

        prefix = 'model.'
        # The triangular mask is a constant rebuilt in __init__, so a checkpoint
        # that does not provide it (or names it differently) is harmless.
        constant_keys = {'prob_of_failure_layer.upper_triangular_mask'}

        models = []
        for checkpoint_path in checkpoint_paths:
            model = cls(config)
            # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                raw_state = checkpoint['state_dict']
            else:
                raw_state = checkpoint
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in raw_state.items()
            }
            mapped_state_dict = model._map_checkpoint_weights(state_dict)
            load_result = model.load_state_dict(mapped_state_dict, strict=False)
            missing = [k for k in load_result.missing_keys if k not in constant_keys]
            if missing:
                warnings.warn(
                    f"Checkpoint {checkpoint_path!r} did not provide {len(missing)} "
                    f"parameter(s); they keep their initial values: {missing[:10]}"
                )
            models.append(model)
        return SybilEnsemble(models, config)

    def _map_checkpoint_weights(self, state_dict: Dict) -> Dict:
        """
        Keep only checkpoint entries for submodules this model defines
        (``image_encoder``, ``pool``, ``prob_of_failure_layer``). Keys are passed
        through unchanged; all other entries are dropped.
        """
        kept_prefixes = ('image_encoder', 'pool', 'prob_of_failure_layer')
        return {k: v for k, v in state_dict.items() if k.startswith(kept_prefixes)}
class SybilEnsemble:
    """
    Ensemble of Sybil models whose predictions are averaged.

    This is a plain container, not an ``nn.Module``: ``to`` and ``eval`` are
    forwarded to every member model.
    """

    def __init__(self, models: List[SybilForRiskPrediction], config: SybilConfig):
        self.models = models
        self.config = config
        # Device last passed to ``to()``; inputs are moved there before inference.
        self.device = None

    def to(self, device):
        """Move all models to ``device``, remember it for inputs, and return self."""
        self.device = device
        for model in self.models:
            model.to(device)
        return self

    def eval(self):
        """Put every member model in evaluation mode and return self for chaining."""
        for model in self.models:
            model.eval()
        return self

    def __call__(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: bool = False,
    ) -> SybilOutput:
        """
        Run inference with every member (under ``torch.no_grad``) and average.

        Args:
            pixel_values: Input CT scan volumes. Moved to the ensemble's device
                first if ``to()`` has been called.
            return_attentions: Whether to also return averaged attention scores.

        Returns:
            SybilOutput with the element-wise mean of the members' risk scores
            (and attentions when requested).

        Raises:
            ValueError: if the ensemble contains no models.
        """
        if not self.models:
            raise ValueError("SybilEnsemble contains no models; cannot run inference.")
        if self.device is not None:
            pixel_values = pixel_values.to(self.device)

        all_risk_scores = []
        all_image_attentions = []
        all_volume_attentions = []
        with torch.no_grad():
            for model in self.models:
                output = model(
                    pixel_values=pixel_values,
                    return_attentions=return_attentions,
                )
                all_risk_scores.append(output.risk_scores)
                if return_attentions:
                    all_image_attentions.append(output.image_attention)
                    all_volume_attentions.append(output.volume_attention)

        risk_scores = torch.stack(all_risk_scores).mean(dim=0)

        image_attention = None
        volume_attention = None
        if return_attentions:
            # NOTE: SybilForRiskPrediction members return log-softmax scores, so this
            # is the mean of log-attentions, not a renormalised distribution.
            image_attention = torch.stack(all_image_attentions).mean(dim=0)
            volume_attention = torch.stack(all_volume_attentions).mean(dim=0)

        return SybilOutput(
            risk_scores=risk_scores,
            image_attention=image_attention,
            volume_attention=volume_attention,
        )