# sybil/modeling_sybil.py
# (Hugging Face Hub viewer metadata removed: "Upload 20 files", commit cf14762, 21.1 kB)
"""PyTorch Sybil model for lung cancer risk prediction"""
import torch
import torch.nn as nn
import torchvision
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from typing import Optional, Dict, List, Tuple
import numpy as np
from dataclasses import dataclass
try:
from .configuration_sybil import SybilConfig
except ImportError:
from configuration_sybil import SybilConfig
@dataclass
class SybilOutput(BaseModelOutput):
    """
    Base class for Sybil model outputs.

    Args:
        risk_scores: (`torch.FloatTensor` of shape `(batch_size, max_followup)`):
            Predicted cumulative cancer-risk scores, one per follow-up year.
        image_attention: (`torch.FloatTensor` of shape `(batch_size, num_slices, height*width)`, *optional*):
            Log-softmax attention weights over the pixels of each slice.
        volume_attention: (`torch.FloatTensor` of shape `(batch_size, num_slices)`, *optional*):
            Log-softmax attention weights over the CT-scan slices.
        hidden_states: (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`, *optional*):
            Pooled hidden features fed to the risk-prediction head.
    """
    # All fields default to None so partially-populated outputs can be built.
    risk_scores: torch.FloatTensor = None
    image_attention: Optional[torch.FloatTensor] = None
    volume_attention: Optional[torch.FloatTensor] = None
    hidden_states: Optional[torch.FloatTensor] = None
class CumulativeProbabilityLayer(nn.Module):
    """Survival-style prediction head producing cumulative hazards.

    Mirrors the original Sybil head: a year-specific hazard branch
    (made non-negative by a ReLU), a base-hazard branch shared across
    years, and an upper-triangular mask so the score for year ``j``
    accumulates the hazards of years ``1..j``.
    """

    def __init__(self, hidden_dim: int, max_followup: int = 6):
        super().__init__()
        self.max_followup = max_followup
        # Year-specific hazards (one output per follow-up year).
        self.hazard_fc = nn.Linear(hidden_dim, max_followup)
        # Base hazard shared across all years.
        self.base_hazard_fc = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU(inplace=True)
        # mask[t, j] == 1 iff t <= j, so a masked sum over t is a cumulative sum.
        tri = torch.t(torch.tril(torch.ones([max_followup, max_followup]), diagonal=0))
        self.register_parameter(
            "upper_triangular_mask",
            torch.nn.Parameter(tri, requires_grad=False),
        )

    def hazards(self, x):
        """Return non-negative per-year hazards for features ``x`` [B, hidden_dim]."""
        return self.relu(self.hazard_fc(x))

    def forward(self, x):
        """
        Compute cumulative probabilities matching original Sybil.

        Args:
            x: Hidden features [B, hidden_dim]
        Returns:
            Cumulative probabilities [B, max_followup]
        """
        per_year = self.hazards(x)
        batch, horizon = per_year.size()
        # Broadcast each year's hazard across the horizon, zero out years
        # after j via the mask, then sum to get the cumulative hazard.
        broadcast = per_year.unsqueeze(-1).expand(batch, horizon, horizon)
        cumulative = (broadcast * self.upper_triangular_mask).sum(dim=1)
        # Add the shared base hazard to every year.
        return cumulative + self.base_hazard_fc(x)
class GlobalMaxPool(nn.Module):
    """Pool a 5-D volume down to the per-channel maximum."""

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H)
        Returns:
            dict where output['hidden'] is a (B, C) tensor of channel maxima.
        """
        # Flatten everything past (B, C), then take the max along that axis.
        flattened = x.view(x.size(0), x.size(1), -1)
        return {'hidden': flattened.max(dim=-1).values}
class PerFrameMaxPool(nn.Module):
    """Max-pool each slice (frame) of a 5-D volume over its spatial dims."""

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H)
        Returns:
            dict where output['multi_image_hidden'] is (B, C, T).
        """
        assert len(x.shape) == 5
        b, c, t = x.size(0), x.size(1), x.size(2)
        # Collapse W*H and take the spatial maximum per (batch, channel, slice).
        per_frame_max = x.view(b, c, t, -1).max(dim=-1).values
        return {'multi_image_hidden': per_frame_max}
class Simple_AttentionPool(nn.Module):
    """Learned attention pooling across the slice dimension."""

    def __init__(self, **kwargs):
        super().__init__()
        # One scalar attention score per position, computed from C channels.
        self.attention_fc = nn.Linear(kwargs['num_chan'], 1)
        self.softmax = nn.Softmax(dim=-1)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, N)
        Returns:
            dict with:
              'volume_attention': (B, N) log-softmax attention scores
              'hidden': (B, C) attention-weighted sum over N
        """
        batch = x.shape[0]
        x = x.view(x.size(0), x.size(1), -1)  # ensure (B, C, N)
        # Scores: (B, N, 1) -> (B, 1, N) so (log-)softmax runs over N.
        scores = self.attention_fc(x.transpose(1, 2)).transpose(1, 2)
        weights = self.softmax(scores)  # (B, 1, N), sums to 1 over N
        return {
            'volume_attention': self.logsoftmax(scores).view(batch, -1),
            'hidden': (x * weights).sum(dim=-1),
        }
class Simple_AttentionPool_MultiImg(nn.Module):
    """Learned attention pooling over pixels, applied slice-by-slice."""

    def __init__(self, **kwargs):
        super().__init__()
        # One scalar attention score per pixel, computed from C channels.
        self.attention_fc = nn.Linear(kwargs['num_chan'], 1)
        self.softmax = nn.Softmax(dim=-1)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H)
        Returns:
            dict with:
              'image_attention': (B, T, W*H) log-softmax pixel attention
              'multi_image_hidden': (B, C, T) per-slice pooled features
              'hidden': (B, T*C) flattened per-slice features
        """
        batch, chans, slices, width, height = x.size()
        # Fold slices into the batch so attention is computed per slice.
        x = x.permute([0, 2, 1, 3, 4]).contiguous().view(batch * slices, chans, width * height)
        # Scores: (BT, WH, 1) -> (BT, 1, WH) so (log-)softmax runs over pixels.
        scores = self.attention_fc(x.transpose(1, 2)).transpose(1, 2)
        pooled = (x * self.softmax(scores)).sum(dim=-1)  # (BT, C)
        return {
            'image_attention': self.logsoftmax(scores).view(batch, slices, -1),
            'multi_image_hidden': pooled.view(batch, slices, chans).permute([0, 2, 1]).contiguous(),
            'hidden': pooled.view(batch, slices * chans),
        }
class Conv1d_AttnPool(nn.Module):
    """Smooth slice features with a 1-D convolution, then attention-pool them."""

    def __init__(self, **kwargs):
        super().__init__()
        kernel = kwargs['conv_pool_kernel_size']
        # 'same'-style padding (for odd kernels) keeps the slice count unchanged.
        self.conv1d = nn.Conv1d(
            kwargs['num_chan'],
            kwargs['num_chan'],
            kernel_size=kernel,
            stride=kwargs['stride'],
            padding=kernel // 2,
            bias=False,
        )
        self.aggregate = Simple_AttentionPool(**kwargs)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T)
        Returns:
            dict from Simple_AttentionPool:
              'volume_attention': (B, T') log-softmax attention scores
              'hidden': (B, C) attention-weighted features
        """
        # Convolve along the slice axis, then delegate pooling.
        return self.aggregate(self.conv1d(x))
class MultiAttentionPool(nn.Module):
    """Multi-pathway attention pooling for CT-scan feature aggregation.

    Matches the original Sybil architecture: three parallel pathways
    (pixel+slice attention, per-frame max + conv attention, and global max
    pooling) whose volume-level outputs are fused by a linear layer.

    Args:
        channels: Channel count of the incoming feature maps (512 for the
            R3D-18 backbone). Previously this parameter was accepted but
            ignored in favor of a hardcoded 512; it is now honored, with the
            default preserving the original behavior.
    """

    def __init__(self, channels: int = 512):
        super().__init__()
        params = {
            'num_chan': channels,
            'conv_pool_kernel_size': 11,
            'stride': 1
        }
        # Pathway 1: pixel-level attention per slice, then slice-level attention.
        self.image_pool1 = Simple_AttentionPool_MultiImg(**params)
        self.volume_pool1 = Simple_AttentionPool(**params)
        # Pathway 2: per-slice max pooling, then conv-smoothed slice attention.
        self.image_pool2 = PerFrameMaxPool()
        self.volume_pool2 = Conv1d_AttnPool(**params)
        # Pathway 3: global max pooling over all slices and pixels.
        self.global_max_pool = GlobalMaxPool()
        # Linear layers fusing the pathway outputs back to `channels` dims.
        self.multi_img_hidden_fc = nn.Linear(2 * channels, channels)
        self.hidden_fc = nn.Linear(3 * channels, channels)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, C, T, W, H) where
                - B: batch size
                - C: channels (must equal `channels` from __init__)
                - T: temporal/depth dimension (slices)
                - W, H: spatial dimensions
        Returns:
            output: dict with keys:
                - 'hidden': (B, C) - final aggregated features
                - 'image_attention_1': (B, T, W*H) - pixel attention scores
                - 'volume_attention_1': (B, T) - slice attention scores
                - 'hidden_1': (B, C) - pathway-1 pooled features
                - 'multi_image_hidden_1' / 'multi_image_hidden_2': (B, C, T)
                - 'volume_attention_2': (B, T) - conv-pathway slice attention
                - 'hidden_2': (B, C) - conv-pathway pooled features
                - 'multi_image_hidden': (B, C, T) - fused per-slice features
                - 'maxpool_hidden': (B, C) - global max-pooled features
        """
        output = {}
        # Pathway 1: attention over pixels, then attention over slices.
        image_pool_out1 = self.image_pool1(x)
        # Keys: "multi_image_hidden" (B, C, T), "image_attention" (B, T, W*H), "hidden" (B, T*C)
        volume_pool_out1 = self.volume_pool1(image_pool_out1['multi_image_hidden'])
        # Keys: "hidden" (B, C), "volume_attention" (B, T)
        # Pathway 2: per-slice max pooling, then conv + slice attention.
        image_pool_out2 = self.image_pool2(x)
        # Keys: "multi_image_hidden" (B, C, T)
        volume_pool_out2 = self.volume_pool2(image_pool_out2['multi_image_hidden'])
        # Keys: "hidden" (B, C), "volume_attention" (B, T)
        # Collect all pooling outputs with pathway-numbered suffixes.
        # NOTE: image_pool_out1['hidden'] and volume_pool_out1['hidden'] both
        # map to 'hidden_1'; the volume-pool value (written last) wins,
        # matching the original iteration order.
        for pool_out, num in [(image_pool_out1, 1), (volume_pool_out1, 1),
                              (image_pool_out2, 2), (volume_pool_out2, 2)]:
            for key, val in pool_out.items():
                output['{}_{}'.format(key, num)] = val
        # Pathway 3: global max pooling.
        maxpool_out = self.global_max_pool(x)
        output['maxpool_hidden'] = maxpool_out['hidden']
        # Fuse per-slice features from pathways 1 and 2: concatenating along
        # the channel axis gives (B, 2C, T); a linear layer maps back to C.
        multi_image_hidden = torch.cat(
            [image_pool_out1['multi_image_hidden'], image_pool_out2['multi_image_hidden']],
            dim=-2
        )  # (B, 2C, T)
        output['multi_image_hidden'] = self.multi_img_hidden_fc(
            multi_image_hidden.permute([0, 2, 1]).contiguous()
        ).permute([0, 2, 1]).contiguous()  # (B, C, T)
        # Fuse volume-level features from all three pathways.
        hidden = torch.cat(
            [volume_pool_out1['hidden'], volume_pool_out2['hidden'], output['maxpool_hidden']],
            dim=-1
        )  # (B, 3*C)
        output['hidden'] = self.hidden_fc(hidden)  # (B, C)
        return output
class SybilPreTrainedModel(PreTrainedModel):
    """
    Abstract base class wiring Sybil models into the Hugging Face
    pretrained-model machinery: weight initialization and a simple
    interface for downloading and loading pretrained weights.
    """
    config_class = SybilConfig
    base_model_prefix = "sybil"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize weights: normal for Linear, Kaiming-normal for Conv3d."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv3d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
class SybilForRiskPrediction(SybilPreTrainedModel):
    """
    Sybil model for lung cancer risk prediction from CT scans.

    This model takes 3D CT scan volumes as input and predicts cancer risk scores
    for multiple future time points (typically 1-6 years). Pipeline:
    R3D-18 video backbone -> multi-attention pooling -> ReLU/dropout ->
    cumulative-hazard risk head -> sigmoid (+ optional calibration).
    """
    def __init__(self, config: SybilConfig):
        super().__init__(config)
        self.config = config
        # Use pretrained R3D-18 as backbone.
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # releases in favor of `weights=R3D_18_Weights.DEFAULT` — confirm the
        # pinned torchvision version before changing.
        encoder = torchvision.models.video.r3d_18(pretrained=True)
        # Drop the last two children (avgpool + fc) so the encoder emits
        # 5-D feature maps rather than classification logits.
        self.image_encoder = nn.Sequential(*list(encoder.children())[:-2])
        # Multi-attention pooling over slices and pixels.
        self.pool = MultiAttentionPool(channels=512)
        # Classification layers. The ReLU is out-of-place so it does not
        # mutate pool_output['hidden'] in forward().
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=config.dropout)
        # Risk prediction head: one cumulative score per follow-up year.
        self.prob_of_failure_layer = CumulativeProbabilityLayer(
            config.hidden_dim,
            max_followup=config.max_followup
        )
        # Optional calibrator for adjusting raw risk scores.
        self.calibrator = None
        if config.calibrator_data:
            self.set_calibrator(config.calibrator_data)
        # Initialize weights via the HF PreTrainedModel machinery.
        self.post_init()
    def set_calibrator(self, calibrator_data: Dict):
        """Set calibration data for risk score adjustment."""
        self.calibrator = calibrator_data
    def _calibrate_scores(self, scores: torch.Tensor) -> torch.Tensor:
        """Apply per-year calibration to raw risk scores.

        Args:
            scores: Raw sigmoid risk scores of shape (batch, max_followup).
        Returns:
            Calibrated scores on the same device as `scores`. The input is
            returned unchanged when no calibrator is set.
        """
        if self.calibrator is None:
            return scores
        # Calibration runs on CPU/NumPy, so detach and convert first.
        scores_np = scores.detach().cpu().numpy()
        calibrated = np.zeros_like(scores_np)
        # Apply calibration for each year independently; calibrator keys are
        # expected to look like "Year1", "Year2", ...
        for year in range(scores_np.shape[1]):
            year_key = f"Year{year + 1}"
            if year_key in self.calibrator:
                # Apply calibration transformation
                calibrated[:, year] = self._apply_calibration(
                    scores_np[:, year],
                    self.calibrator[year_key]
                )
            else:
                # No calibration data for this year: pass scores through.
                calibrated[:, year] = scores_np[:, year]
        return torch.from_numpy(calibrated).to(scores.device)
    def _apply_calibration(self, scores: np.ndarray, calibrator_params: Dict) -> np.ndarray:
        """Apply specific calibration transformation.

        WARNING: currently a no-op placeholder — `calibrator_params` is
        ignored and `scores` are returned unchanged. The full calibration
        model from the original Sybil implementation is not ported here.
        """
        # Simplified calibration - in practice, this would use the full calibration model
        # from the original Sybil implementation
        return scores  # Placeholder for now
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: bool = False,
        return_dict: bool = True,
    ) -> SybilOutput:
        """
        Forward pass of the Sybil model.

        Args:
            pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, depth, height, width)`):
                Pixel values of CT scan volumes.
            return_attentions: (`bool`, *optional*, defaults to `False`):
                Whether to return attention weights (and pooled hidden states).
            return_dict: (`bool`, *optional*, defaults to `True`):
                Whether to return a `SybilOutput` instead of a plain tuple.
        Returns:
            `SybilOutput` or tuple of `(risk_scores[, image_attention, volume_attention])`.
        """
        # Extract features using the 3D CNN backbone (5-D feature maps).
        features = self.image_encoder(pixel_values)
        # Aggregate slice/pixel features into one vector per scan.
        pool_output = self.pool(features)
        # Non-linearity + dropout on the pooled representation.
        hidden = self.relu(pool_output['hidden'])
        hidden = self.dropout(hidden)
        # Cumulative-hazard logits, squashed to (0, 1) risk scores per year.
        risk_logits = self.prob_of_failure_layer(hidden)
        risk_scores = torch.sigmoid(risk_logits)
        # Apply calibration if available (no-op when calibrator is None).
        risk_scores = self._calibrate_scores(risk_scores)
        if not return_dict:
            outputs = (risk_scores,)
            if return_attentions:
                outputs = outputs + (pool_output.get('image_attention_1'),
                                     pool_output.get('volume_attention_1'))
            return outputs
        return SybilOutput(
            risk_scores=risk_scores,
            image_attention=pool_output.get('image_attention_1') if return_attentions else None,
            volume_attention=pool_output.get('volume_attention_1') if return_attentions else None,
            hidden_states=hidden if return_attentions else None
        )
    @classmethod
    def from_pretrained_ensemble(
        cls,
        pretrained_model_name_or_path,
        checkpoint_paths: List[str],
        calibrator_path: Optional[str] = None,
        **kwargs
    ):
        """
        Load an ensemble of Sybil models from checkpoints.

        Args:
            pretrained_model_name_or_path: Path to the pretrained model or model identifier.
            checkpoint_paths: List of paths to individual model checkpoints.
            calibrator_path: Path to calibration data (JSON).
            **kwargs: Additional keyword arguments for model initialization.
        Returns:
            SybilEnsemble: An ensemble of Sybil models.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
        # Load calibrator if provided; stored on the config so each model
        # picks it up in __init__.
        calibrator_data = None
        if calibrator_path:
            import json
            with open(calibrator_path, 'r') as f:
                calibrator_data = json.load(f)
            config.calibrator_data = calibrator_data
        # Create one model per checkpoint.
        models = []
        for checkpoint_path in checkpoint_paths:
            model = cls(config)
            # Load checkpoint weights.
            # NOTE(review): torch.load on an untrusted checkpoint executes
            # pickled code; pass `weights_only=True` where supported.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            # Remove 'model.' prefix from state dict keys if present
            # (presumably added by a training wrapper — verify against the
            # original Sybil checkpoints).
            state_dict = {}
            for k, v in checkpoint['state_dict'].items():
                if k.startswith('model.'):
                    state_dict[k[6:]] = v  # 6 == len('model.')
                else:
                    state_dict[k] = v
            # Map to the new model structure.
            mapped_state_dict = model._map_checkpoint_weights(state_dict)
            # NOTE(review): strict=False silently ignores missing/unexpected
            # keys; consider logging the mismatch report it returns.
            model.load_state_dict(mapped_state_dict, strict=False)
            models.append(model)
        return SybilEnsemble(models, config)
    def _map_checkpoint_weights(self, state_dict: Dict) -> Dict:
        """Map original Sybil checkpoint weights to the new structure.

        Keeps only keys under the 'image_encoder', 'pool', and
        'prob_of_failure_layer' prefixes; everything else is dropped.
        """
        mapped = {}
        # Map encoder weights
        for k, v in state_dict.items():
            if k.startswith('image_encoder'):
                mapped[k] = v
            elif k.startswith('pool'):
                # Map pooling layer weights
                mapped[k] = v
            elif k.startswith('prob_of_failure_layer'):
                # Map final prediction layer
                mapped[k] = v
        return mapped
class SybilEnsemble:
    """Ensemble of Sybil models for improved predictions.

    Wraps several independently trained `SybilForRiskPrediction` models and
    averages their risk scores (and, optionally, attention maps) at inference.
    """
    def __init__(self, models: List[SybilForRiskPrediction], config: SybilConfig):
        """
        Args:
            models: Non-empty list of loaded Sybil models.
            config: Shared model configuration.
        Raises:
            ValueError: If `models` is empty (averaging would fail opaquely).
        """
        if not models:
            raise ValueError("SybilEnsemble requires at least one model.")
        self.models = models
        self.config = config
        self.device = None
    def to(self, device):
        """Move all models to `device`; returns self for chaining."""
        self.device = device
        for model in self.models:
            model.to(device)
        return self
    def eval(self):
        """Set all models to evaluation mode; returns self for chaining
        (consistent with `to`, mirroring `nn.Module.eval`)."""
        for model in self.models:
            model.eval()
        return self
    def __call__(
        self,
        pixel_values: torch.FloatTensor,
        return_attentions: bool = False,
    ) -> SybilOutput:
        """
        Run inference with ensemble averaging.

        Args:
            pixel_values: Input CT scan volumes.
            return_attentions: Whether to also average and return attention maps.
        Returns:
            SybilOutput with predictions averaged over all ensemble members.
        """
        all_risk_scores = []
        all_image_attentions = []
        all_volume_attentions = []
        # Inference only: no gradients needed for any ensemble member.
        with torch.no_grad():
            for model in self.models:
                output = model(
                    pixel_values=pixel_values,
                    return_attentions=return_attentions
                )
                all_risk_scores.append(output.risk_scores)
                if return_attentions:
                    all_image_attentions.append(output.image_attention)
                    all_volume_attentions.append(output.volume_attention)
        # Ensemble by simple mean over the model axis.
        risk_scores = torch.stack(all_risk_scores).mean(dim=0)
        image_attention = None
        volume_attention = None
        if return_attentions:
            image_attention = torch.stack(all_image_attentions).mean(dim=0)
            volume_attention = torch.stack(all_volume_attentions).mean(dim=0)
        return SybilOutput(
            risk_scores=risk_scores,
            image_attention=image_attention,
            volume_attention=volume_attention
        )