---
license: mit
tags:
- medical
- cancer
- ct-scan
- risk-prediction
- healthcare
- pytorch
- vision
datasets:
- NLST
metrics:
- auc
- c-index
language:
- en
library_name: transformers
pipeline_tag: image-classification
---
# Sybil - Lung Cancer Risk Prediction
## 🎯 Model Description
Sybil is a validated deep learning model that predicts future lung cancer risk from a single low-dose chest CT (LDCT) scan. Described in a 2023 Journal of Clinical Oncology paper, it estimates cancer risk for each of the six years following the scan.
### Key Features
- **Single Scan Analysis**: Requires only one LDCT scan
- **Multi-Year Prediction**: Provides risk scores for years 1-6
- **Validated Performance**: Tested across multiple institutions globally
- **Ensemble Approach**: Uses 5 models for robust predictions
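
To illustrate what a five-model ensemble does to a prediction, the sketch below averages per-model risk curves year by year. It assumes the ensemble combines its members with a simple mean, which this card does not state; the function name `average_ensemble` and all numbers are made up for illustration.

```python
from statistics import fmean

def average_ensemble(per_model_scores):
    """Average risk curves from several models, year by year.

    per_model_scores: list of equal-length lists, one list of yearly
    risk scores per ensemble member (hypothetical values).
    """
    n_years = len(per_model_scores[0])
    if any(len(scores) != n_years for scores in per_model_scores):
        raise ValueError("every model must return the same number of years")
    # zip(*...) groups the scores of all models for the same year
    return [fmean(year_scores) for year_scores in zip(*per_model_scores)]

# Five made-up risk curves (years 1-6), one per ensemble member
per_model = [
    [0.010, 0.020, 0.030, 0.040, 0.050, 0.060],
    [0.012, 0.022, 0.031, 0.043, 0.052, 0.064],
    [0.008, 0.018, 0.028, 0.037, 0.049, 0.058],
    [0.011, 0.021, 0.033, 0.041, 0.051, 0.061],
    [0.009, 0.019, 0.029, 0.039, 0.048, 0.057],
]
ensemble = average_ensemble(per_model)
for year, score in enumerate(ensemble, start=1):
    print(f"Year {year}: {score:.4f}")
```

Averaging smooths out the disagreement between individual members, which is one common reason to ship an ensemble rather than a single model.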
## 🚀 Quick Start
### Installation
```bash
pip install huggingface-hub torch torchvision pydicom
```
### Basic Usage
```python
from huggingface_hub import snapshot_download
import sys
import os
# Download model
model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
sys.path.append(model_path)
# Import model
from modeling_sybil_hf import SybilHFWrapper
from configuration_sybil import SybilConfig
# Initialize
config = SybilConfig()
model = SybilHFWrapper(config)
dicom_dir = "path/to/volume"
dicom_paths = [os.path.join(dicom_dir, f) for f in os.listdir(dicom_dir) if f.endswith('.dcm')]
print(f"Found {len(dicom_paths)} DICOM files for prediction.")
# Get predictions
output = model(dicom_paths=dicom_paths)
risk_scores = output.risk_scores.numpy()
# Display results
print("\nLung Cancer Risk Predictions:")
print(f"Risk scores shape: {risk_scores.shape}")
# Handle both single and batch predictions
if risk_scores.ndim == 2:
    # Batch predictions - take first sample
    risk_scores = risk_scores[0]
for i, score in enumerate(risk_scores):
    print(f"Year {i+1}: {float(score)}")
```
## 📊 Example with Demo Data
```python
import requests
import zipfile
from io import BytesIO
import os
# Download demo DICOM files
def get_demo_data():
    cache_dir = os.path.expanduser("~/.sybil_demo")
    demo_dir = os.path.join(cache_dir, "sybil_demo_data")
    if not os.path.exists(demo_dir):
        print("Downloading demo data...")
        url = "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&dl=1"
        response = requests.get(url)
        response.raise_for_status()  # Fail early if the download did not succeed
        os.makedirs(cache_dir, exist_ok=True)
        with zipfile.ZipFile(BytesIO(response.content)) as zf:
            zf.extractall(cache_dir)
    # Find DICOM files
    dicom_files = []
    for root, dirs, files in os.walk(cache_dir):
        for file in files:
            if file.endswith('.dcm'):
                dicom_files.append(os.path.join(root, file))
    return sorted(dicom_files)
# Run demo
from huggingface_hub import snapshot_download
import sys
# Load model
model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
sys.path.append(model_path)
from modeling_sybil_hf import SybilHFWrapper
from configuration_sybil import SybilConfig
# Initialize and predict
config = SybilConfig()
model = SybilHFWrapper(config)
dicom_files = get_demo_data()
output = model(dicom_paths=dicom_files)
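# Show results (take the first sample if a batch dimension is present)
risk_scores = output.risk_scores.numpy()
if risk_scores.ndim == 2:
    risk_scores = risk_scores[0]
for i, score in enumerate(risk_scores):
    print(f"Year {i+1}: {float(score)}")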
```
## 🔬 Advanced Usage: Embedding Extraction
### Extract Embeddings Before Dropout Layer
You can extract a 512-dimensional embedding vector from the output of the ReLU layer, which sits immediately before the dropout layer. This captures the learned risk features before the final prediction layer.
```python
from huggingface_hub import snapshot_download
import sys
import os
import torch
import numpy as np
# Download and setup model
model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
sys.path.append(model_path)
from modeling_sybil_hf import SybilHFWrapper
from configuration_sybil import SybilConfig
def extract_embeddings(dicom_paths):
    """
    Extract embeddings from the layer after ReLU, before Dropout.
    Args:
        dicom_paths: List of DICOM file paths
    Returns:
        numpy array of shape (512,) - averaged embeddings across ensemble
    """
    # Initialize model
    config = SybilConfig()
    model = SybilHFWrapper(config)
    # Set each model in ensemble to eval mode
    for m in model.models:
        m.eval()
    # Storage for embeddings from each model in ensemble
    all_embeddings = []
    # Register hooks on each model in the ensemble
    for model_idx, ensemble_model in enumerate(model.models):
        embeddings_buffer = []
        def create_hook(buffer):
            def hook(module, input, output):
                # Capture the output of ReLU layer (before dropout)
                buffer.append(output.detach().cpu())
            return hook
        # Register hook on the ReLU layer
        hook_handle = ensemble_model.relu.register_forward_hook(create_hook(embeddings_buffer))
        # Run forward pass
        with torch.no_grad():
            _ = model(dicom_paths=dicom_paths)
        # Remove hook
        hook_handle.remove()
        # Get the embeddings (should be shape [1, 512])
        if embeddings_buffer:
            embedding = embeddings_buffer[0].numpy().squeeze()
            all_embeddings.append(embedding)
            print(f"Model {model_idx + 1}: Embedding shape = {embedding.shape}")
    # Average embeddings across ensemble
    averaged_embedding = np.mean(all_embeddings, axis=0)
    return averaged_embedding
# Usage
dicom_dir = "path/to/volume"
dicom_paths = [os.path.join(dicom_dir, f) for f in os.listdir(dicom_dir) if f.endswith('.dcm')]
embeddings = extract_embeddings(dicom_paths)
print(f"\nEmbedding vector shape: {embeddings.shape}")
print(f"Embedding statistics:")
print(f" Mean: {np.mean(embeddings):.6f}")
print(f" Std: {np.std(embeddings):.6f}")
print(f" Min: {np.min(embeddings):.6f}")
print(f" Max: {np.max(embeddings):.6f}")
```
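
One possible use of the averaged embedding is comparing two scans. The sketch below computes cosine similarity between two vectors in plain Python. The function name `cosine_similarity` and the toy four-dimensional vectors are illustrative, and whether cosine similarity is a meaningful measure for Sybil embeddings is an assumption that neither this card nor the paper establishes.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)

# Toy stand-ins for two 512-dimensional embeddings (made-up values)
scan_a = [0.0, 1.5, 0.2, 0.0]
scan_b = [0.0, 1.4, 0.3, 0.1]
print(f"Cosine similarity: {cosine_similarity(scan_a, scan_b):.4f}")
```

The function iterates over its inputs element by element, so it should also accept the 1-D NumPy array returned by `extract_embeddings`.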
## 🎯 Extracting Embeddings at Other Layers
### Available Extraction Points
The Sybil model has several key layers where you can extract intermediate representations:
```python
import torch
from huggingface_hub import snapshot_download
import sys
model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
sys.path.append(model_path)
from modeling_sybil_hf import SybilHFWrapper
from configuration_sybil import SybilConfig
config = SybilConfig()
model = SybilHFWrapper(config)
# Get first model from ensemble for demonstration
first_model = model.models[0]
# Model architecture flow:
# Input → image_encoder → pool → relu → dropout → prob_of_failure_layer → Output
def extract_layer_output(model, layer_name, dicom_paths):
    """
    Extract output from any layer in the model.
    Args:
        model: SybilHFWrapper model
        layer_name: Name of the layer to extract from
        dicom_paths: List of DICOM file paths
    Returns:
        List of extracted features, one entry per ensemble member
    """
    features = []
    def hook_fn(module, input, output):
        features.append(output.detach().cpu())
    # Register a hook on the specified layer of every ensemble member
    hook_handles = []
    for m in model.models:
        layer = dict(m.named_modules())[layer_name]
        hook_handles.append(layer.register_forward_hook(hook_fn))
    try:
        # Run a single forward pass through the ensemble
        with torch.no_grad():
            _ = model(dicom_paths=dicom_paths)
    finally:
        # Remove every hook, not only the last one registered
        for hook_handle in hook_handles:
            hook_handle.remove()
    return features
# The examples below expect `dicom_paths`, a list of DICOM file paths built as in Basic Usage
# Example 1: Extract from image encoder (3D feature maps)
# Shape: (batch, 512, time, height, width)
encoder_features = extract_layer_output(model, 'image_encoder', dicom_paths)
print(f"Image encoder output shape: {encoder_features[0].shape}")
# Example 2: Extract from pooling layer (before ReLU)
# Shape: (batch, 512)
pool_features = extract_layer_output(model, 'pool', dicom_paths)
print(f"Pool layer output shape: {pool_features[0].shape}")
# Example 3: Extract from ReLU layer (before dropout) - RECOMMENDED
# Shape: (batch, 512)
relu_features = extract_layer_output(model, 'relu', dicom_paths)
print(f"ReLU layer output shape: {relu_features[0].shape}")
# Example 4: Extract from dropout layer (before final prediction)
# Shape: (batch, 512)
dropout_features = extract_layer_output(model, 'dropout', dicom_paths)
print(f"Dropout layer output shape: {dropout_features[0].shape}")
```
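
The hook mechanism used above can be tried without downloading any weights. The sketch below builds a small stand-in network that only borrows the layer names from the flow described above; `ToyHead`, its layer sizes, and the random input are hypothetical and are not part of Sybil.

```python
import torch
import torch.nn as nn

class ToyHead(nn.Module):
    """Stand-in that reuses Sybil's layer names; it is not the Sybil model."""
    def __init__(self, in_features=8, hidden=4, n_years=6):
        super().__init__()
        self.pool = nn.Linear(in_features, hidden)  # placeholder for the real pooling
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        self.prob_of_failure_layer = nn.Linear(hidden, n_years)

    def forward(self, x):
        return self.prob_of_failure_layer(self.dropout(self.relu(self.pool(x))))

toy = ToyHead().eval()
captured = {}

def save_output(name):
    # Build a hook that stores the layer output under the given name
    def hook(module, inputs, output):
        captured[name] = output.detach().cpu()
    return hook

modules = dict(toy.named_modules())
handles = [
    modules[name].register_forward_hook(save_output(name))
    for name in ("pool", "relu", "dropout")
]
try:
    with torch.no_grad():
        _ = toy(torch.randn(2, 8))
finally:
    # Always remove hooks so later forward passes are unaffected
    for handle in handles:
        handle.remove()

for name, tensor in captured.items():
    print(f"{name}: {tuple(tensor.shape)}")
```

In eval mode a dropout layer passes its input through unchanged, so the `relu` and `dropout` captures hold the same values here.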
### Custom Layer Extraction Template
```python
def extract_custom_layer(dicom_paths, target_layer_name):
    """
    Template for extracting features from any layer.
    Args:
        dicom_paths: List of DICOM file paths
        target_layer_name: Name of target layer (e.g., 'relu', 'pool', 'image_encoder')
    Returns:
        Extracted features averaged across ensemble
    """
    from huggingface_hub import snapshot_download
    import sys
    import torch
    import numpy as np
    model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
    sys.path.append(model_path)
    from modeling_sybil_hf import SybilHFWrapper
    from configuration_sybil import SybilConfig
    config = SybilConfig()
    model = SybilHFWrapper(config)
    all_features = []
    for ensemble_model in model.models:
        ensemble_model.eval()
        features_buffer = []
        # Get the target layer
        target_layer = dict(ensemble_model.named_modules())[target_layer_name]
        # Register hook
        def hook(module, input, output):
            features_buffer.append(output.detach().cpu())
        hook_handle = target_layer.register_forward_hook(hook)
        # Forward pass
        with torch.no_grad():
            _ = model(dicom_paths=dicom_paths)
        hook_handle.remove()
        if features_buffer:
            all_features.append(features_buffer[0])
    # Average across ensemble
    averaged_features = torch.stack(all_features).mean(dim=0)
    return averaged_features.numpy()
```
## 🔍 Model Architecture Inspection
### Print Full Model Architecture
```python
from huggingface_hub import snapshot_download
import sys
model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
sys.path.append(model_path)
from modeling_sybil_hf import SybilHFWrapper
from configuration_sybil import SybilConfig
config = SybilConfig()
model = SybilHFWrapper(config)
# Print configuration
print("=" * 80)
print("MODEL CONFIGURATION:")
print("=" * 80)
print(config)
# Print ensemble information
print("\n" + "=" * 80)
print("ENSEMBLE INFORMATION:")
print("=" * 80)
print(f"Number of models in ensemble: {len(model.models)}")
print(f"Device: {model.device}")
# Print architecture of first model
print("\n" + "=" * 80)
print("MODEL ARCHITECTURE (First model in ensemble):")
print("=" * 80)
first_model = model.models[0]
print(first_model)
```
### Count Model Parameters
```python
from huggingface_hub import snapshot_download
import sys
model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
sys.path.append(model_path)
from modeling_sybil_hf import SybilHFWrapper
from configuration_sybil import SybilConfig
config = SybilConfig()
model = SybilHFWrapper(config)
print("=" * 80)
print("MODEL PARAMETERS:")
print("=" * 80)
# Parameters per model in ensemble
for i, ensemble_model in enumerate(model.models):
    total_params = sum(p.numel() for p in ensemble_model.parameters())
    trainable_params = sum(p.numel() for p in ensemble_model.parameters() if p.requires_grad)
    print(f"\nModel {i+1}:")
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")
    print(f" Non-trainable parameters: {total_params - trainable_params:,}")
# Total ensemble parameters
total_ensemble = sum(
sum(p.numel() for p in m.parameters())
for m in model.models
)
print(f"\nTotal ensemble parameters: {total_ensemble:,}")
```
### List Model Components
```python
from huggingface_hub import snapshot_download
import sys
model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
sys.path.append(model_path)
from modeling_sybil_hf import SybilHFWrapper
from configuration_sybil import SybilConfig
config = SybilConfig()
model = SybilHFWrapper(config)
first_model = model.models[0]
print("=" * 80)
print("MODEL COMPONENTS:")
print("=" * 80)
# Print each component with parameter count
for name, module in first_model.named_children():
    num_params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {module.__class__.__name__} ({num_params:,} parameters)")
print("\n" + "=" * 80)
print("DETAILED LAYER NAMES:")
print("=" * 80)
# Print all named modules (including nested layers)
for name, module in first_model.named_modules():
    if name:  # Skip the root module
        print(f" {name}: {module.__class__.__name__}")
```
### Model Architecture Overview
The Sybil model consists of the following key components:
```
Input (3D CT Volume)
↓
image_encoder (R3D-18 backbone)
- 3D convolutional neural network
- Pretrained on Kinetics-400
- Output: (batch, 512, time, height, width)
↓
pool (MultiAttentionPool)
- Attention-based pooling mechanisms
- Combines multiple pooling strategies
- Output: (batch, 512)
↓
relu (ReLU activation)
- Non-linear activation
- Output: (batch, 512) ← EMBEDDING EXTRACTION POINT
↓
dropout (Dropout layer)
- Regularization (p=0.0 in inference)
- Output: (batch, 512)
↓
prob_of_failure_layer (CumulativeProbabilityLayer)
- Hazard function prediction
- Output: (batch, 6) - one score per year
↓
sigmoid (applied post-forward)
↓
Risk Scores (final output)
```
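
The last two stages of the diagram can be sketched numerically. The example below assumes, based on the original Sybil repository, that the cumulative probability layer adds a base logit to a running sum of non-negative per-year hazard increments before the sigmoid is applied. That reading is an assumption rather than something this card specifies, and the function name `cumulative_risk` and the input numbers are made up.

```python
import math
from itertools import accumulate

def cumulative_risk(base_logit, raw_hazards):
    """Turn one base logit and per-year raw hazards into yearly risk scores."""
    # ReLU keeps each yearly hazard increment non-negative
    increments = [max(0.0, h) for h in raw_hazards]
    # Running sum: the logit for year t includes every increment up to year t
    logits = [base_logit + total for total in accumulate(increments)]
    # Sigmoid maps each logit to a score strictly between 0 and 1
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]

# Hypothetical outputs of the two linear heads for one scan
scores = cumulative_risk(base_logit=-4.0, raw_hazards=[0.5, -0.2, 0.3, 0.0, 0.4, 0.1])
for year, score in enumerate(scores, start=1):
    print(f"Year {year}: {score:.4f}")
```

Because the increments cannot be negative, the resulting scores never decrease from one year to the next, which matches the intuition that cumulative risk by year 6 is at least the risk by year 1.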
### Get Layer-by-Layer Summary
```python
def print_model_summary(model):
    """Print a layer-by-layer summary of the first model in the ensemble."""
    first_model = model.models[0]
    print(f"{'Layer Name':<40} {'Type':<30} {'Parameters':>15}")
    print("=" * 85)
    total_params = 0
    for name, module in first_model.named_modules():
        if name:  # Skip root
            # Count only parameters owned directly by this module, so that
            # nested layers are not counted again through their parents
            num_params = sum(p.numel() for p in module.parameters(recurse=False))
            if num_params > 0:
                print(f"{name:<40} {module.__class__.__name__:<30} {num_params:>15,}")
                total_params += num_params
    print("=" * 85)
    print(f"{'TOTAL':<40} {'':<30} {total_params:>15,}")
# Usage
from huggingface_hub import snapshot_download
import sys
model_path = snapshot_download(repo_id="Lab-Rasool/sybil")
sys.path.append(model_path)
from modeling_sybil_hf import SybilHFWrapper
from configuration_sybil import SybilConfig
config = SybilConfig()
model = SybilHFWrapper(config)
print_model_summary(model)
```
## 📈 Performance Metrics
| Dataset | 1-Year AUC | 6-Year AUC | Sample Size |
|---------|------------|------------|-------------|
| NLST Test | 0.94 | 0.86 | ~15,000 |
| MGH | 0.86 | 0.75 | ~12,000 |
| CGMH Taiwan | 0.94 | 0.80 | ~8,000 |
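
The AUC values above measure how often a patient who developed cancer receives a higher risk score than one who did not. The sketch below computes that quantity by comparing every positive/negative pair on made-up labels and scores. It ignores censoring and follow-up time, which a real evaluation of a multi-year risk model would need to handle, and the function name `roc_auc` is illustrative.

```python
def roc_auc(labels, scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count half."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))

# Made-up outcomes (1 = cancer within the horizon) and model scores
labels = [0, 0, 1, 0, 1, 1]
scores = [0.02, 0.10, 0.35, 0.40, 0.80, 0.05]
print(f"AUC: {roc_auc(labels, scores):.3f}")
```

An AUC of 0.5 corresponds to random ranking and 1.0 to perfect separation, which gives a frame of reference for the figures in the table.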
## 🏥 Intended Use
### Primary Use Cases
- Risk stratification in lung cancer screening programs
- Research on lung cancer prediction models
- Clinical decision support (with appropriate oversight)
### Users
- Healthcare providers
- Medical researchers
- Screening program coordinators
### Out of Scope
- ❌ Diagnosis of existing cancer
- ❌ Use with non-LDCT imaging (X-rays, MRI)
- ❌ Sole basis for clinical decisions
- ❌ Use outside medical supervision
## 📋 Input Requirements
- **Format**: DICOM files from chest CT scan
- **Type**: Low-dose CT (LDCT)
- **Orientation**: Axial view
- **Order**: Anatomically ordered (abdomen → clavicles)
- **Number of slices**: Typically 100-300 slices
- **Resolution**: Automatically handled by model
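
If a folder of slices is not already in anatomical order, one way to sort it is by the z component of the DICOM `ImagePositionPatient` attribute, which increases toward the patient's head. The sketch below is a hedged illustration: the helper names `order_slices` and `read_z_position` are hypothetical, the file names and positions are made up, and the wrapper may already perform its own sorting internally.

```python
def order_slices(slices):
    """Sort (path, z_position) pairs from inferior (abdomen) to superior (clavicles)."""
    return [path for path, _ in sorted(slices, key=lambda item: item[1])]

def read_z_position(path):
    """Read the z coordinate of a slice from its DICOM header."""
    import pydicom  # imported lazily so order_slices works without pydicom
    header = pydicom.dcmread(path, stop_before_pixels=True)
    return float(header.ImagePositionPatient[2])

# Hypothetical file names with made-up z positions in millimetres
example = [
    ("slice_c.dcm", -120.0),
    ("slice_a.dcm", -250.0),
    ("slice_b.dcm", -185.0),
]
print(order_slices(example))
```

With real files, the pairs could be built as `[(p, read_z_position(p)) for p in dicom_paths]` before calling `order_slices`.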
## ⚠️ Important Considerations
### Medical AI Notice
This model should **supplement, not replace**, clinical judgment. Always consider:
- Complete patient medical history
- Additional risk factors (smoking, family history)
- Current clinical guidelines
- Need for professional medical oversight
### Limitations
- Optimized for screening population (ages 55-80)
- Best performance with LDCT scans
- Not validated for pediatric use
- Performance may vary with different scanner manufacturers
## 📚 Citation
If you use this model, please cite the original paper:
```bibtex
@article{mikhael2023sybil,
title={Sybil: a validated deep learning model to predict future lung cancer risk from a single low-dose chest computed tomography},
author={Mikhael, Peter G and Wohlwend, Jeremy and Yala, Adam and others},
journal={Journal of Clinical Oncology},
volume={41},
number={12},
pages={2191--2200},
year={2023},
publisher={American Society of Clinical Oncology}
}
```
## 🙏 Acknowledgments
This Hugging Face implementation is based on the original work by:
- **Original Authors**: Peter G. Mikhael & Jeremy Wohlwend
- **Institutions**: MIT CSAIL & Massachusetts General Hospital
- **Original Repository**: [GitHub](https://github.com/reginabarzilaygroup/Sybil)
- **Paper**: [Journal of Clinical Oncology](https://doi.org/10.1200/JCO.22.01345)
## 📄 License
MIT License - See [LICENSE](LICENSE) file
- Original Model © 2022 Peter Mikhael & Jeremy Wohlwend
- HF Adaptation with Embeddings © 2025 [Aakash Tripathi](https://github.com/Aakash-Tripathi)
## 🔧 Troubleshooting
### Common Issues
1. **Import Error**: Make sure the downloaded model path is appended to `sys.path` before importing
```python
sys.path.append(model_path)
```
2. **Missing Dependencies**: Install all requirements
```bash
pip install torch torchvision pydicom sybil huggingface-hub
```
3. **DICOM Loading Error**: Ensure DICOM files are valid CT scans
```python
import pydicom
dcm = pydicom.dcmread("your_file.dcm") # Test single file
```
4. **Memory Issues**: Model requires ~4GB GPU memory
```python
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
```
## 📬 Support
- **HF Model Issues**: Open issue on this repository
- **Original Model**: [GitHub Issues](https://github.com/reginabarzilaygroup/Sybil/issues)
- **Medical Questions**: Consult healthcare professionals
## 🔍 Additional Resources
- [Original GitHub Repository](https://github.com/reginabarzilaygroup/Sybil)
- [Paper (Open Access)](https://doi.org/10.1200/JCO.22.01345)
- [NLST Dataset Information](https://cdas.cancer.gov/nlst/)
- [Demo Data](https://github.com/reginabarzilaygroup/Sybil/releases)
---
**Note**: This is a research model. Always consult qualified healthcare professionals for medical decisions.