File size: 12,633 Bytes
cf14762 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 |
"""
Self-contained Hugging Face wrapper for Sybil lung cancer risk prediction model.
This version works directly from HF without requiring external Sybil package.
"""
import os
import json
import sys
import torch
import numpy as np
from typing import List, Dict, Optional
from dataclasses import dataclass
from transformers.modeling_outputs import BaseModelOutput
from safetensors.torch import load_file
# Add model path to sys.path for imports
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
sys.path.insert(0, current_dir)
try:
from .configuration_sybil import SybilConfig
from .modeling_sybil import SybilForRiskPrediction
from .image_processing_sybil import SybilImageProcessor
except ImportError:
from configuration_sybil import SybilConfig
from modeling_sybil import SybilForRiskPrediction
from image_processing_sybil import SybilImageProcessor
@dataclass
class SybilOutput(BaseModelOutput):
    """
    Output class for Sybil model predictions.

    Args:
        risk_scores: Risk scores for each year (1-6 years by default);
            presumably shape [batch_size, num_years] — confirm against
            SybilHFWrapper.predict, which fills this field.
        attentions: Optional attention maps if requested.
    """
    # Both fields default to None so the output can be constructed partially
    # filled (e.g. no attentions requested) — hence the Optional annotations.
    risk_scores: Optional[torch.FloatTensor] = None
    attentions: Optional[Dict] = None
class SybilHFWrapper:
    """
    Hugging Face wrapper for Sybil ensemble model.

    Provides a simple interface for lung cancer risk prediction from CT scans.
    Every ``*.ckpt`` checkpoint found under ``<model_dir>/checkpoints`` is
    loaded into the ensemble; raw risk scores are averaged across members and
    then passed through isotonic-regression calibration.
    """

    def __init__(self, config: Optional[SybilConfig] = None, model_dir: Optional[str] = None):
        """
        Initialize the Sybil model ensemble.

        Args:
            config: Model configuration (a default ``SybilConfig`` is used if
                not provided).
            model_dir: Directory containing model files (defaults to the
                directory this file lives in).

        Raises:
            ValueError: If no ensemble checkpoint can be loaded.
        """
        self.config = config if config is not None else SybilConfig()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Resolve the directory that holds checkpoints and calibrator files.
        if model_dir is not None:
            self.model_dir = model_dir
        else:
            # Default to where this file is located.
            self.model_dir = os.path.dirname(os.path.abspath(__file__))
        # Initialize image processor.
        self.image_processor = SybilImageProcessor()
        # Load calibrator data (empty dict disables calibration).
        self.calibrator = self._load_calibrator()
        # Load ensemble models.
        self.models = self._load_ensemble_models()

    def _load_calibrator(self) -> Dict:
        """Load ensemble calibrator data.

        Checks the primary and alternative locations in order and returns the
        first JSON file found; returns an empty dict (calibration becomes a
        no-op in ``_apply_calibration``) when neither exists.
        """
        candidate_paths = (
            os.path.join(self.model_dir, "checkpoints", "sybil_ensemble_simple_calibrator.json"),
            # Alternative location used by some distributions of the model.
            os.path.join(self.model_dir, "calibrator_data.json"),
        )
        for calibrator_path in candidate_paths:
            if os.path.exists(calibrator_path):
                with open(calibrator_path, 'r') as f:
                    return json.load(f)
        return {}

    def _load_ensemble_models(self) -> List[torch.nn.Module]:
        """
        Load all models in the ensemble from original checkpoints.

        Note: We load from .ckpt files instead of safetensors because the
        safetensors were created with the wrong CumulativeProbabilityLayer
        architecture.

        Returns:
            Successfully loaded models, moved to ``self.device`` and set to
            eval mode. Unreadable checkpoints are skipped with a warning.

        Raises:
            ValueError: If no checkpoint could be loaded at all.
        """
        import glob as glob_module
        models = []
        checkpoints_dir = os.path.join(self.model_dir, "checkpoints")
        # Sort so the ensemble order is deterministic across runs.
        checkpoint_files = sorted(glob_module.glob(os.path.join(checkpoints_dir, "*.ckpt")))
        print(f"Found {len(checkpoint_files)} checkpoint files")
        for checkpoint_path in checkpoint_files:
            try:
                model = SybilForRiskPrediction(self.config)
                # SECURITY: weights_only=False unpickles arbitrary objects from
                # the checkpoint — only load checkpoints from a trusted source.
                checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
                # Lightning-style checkpoints nest weights under 'state_dict'.
                if 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                else:
                    state_dict = checkpoint
                # Strip the 'model.' prefix so keys match our module names.
                cleaned_state_dict = {}
                for k, v in state_dict.items():
                    if k.startswith('model.'):
                        cleaned_state_dict[k[6:]] = v
                    else:
                        cleaned_state_dict[k] = v
                # strict=False: checkpoints may carry extra training-only keys.
                model.load_state_dict(cleaned_state_dict, strict=False)
                model.to(self.device)
                model.eval()
                models.append(model)
                print(f" Loaded model from {os.path.basename(checkpoint_path)}")
            except Exception as e:
                # Best effort: skip broken checkpoints, keep loading the rest.
                print(f" Warning: Could not load {os.path.basename(checkpoint_path)}: {e}")
                continue
        if not models:
            raise ValueError("No models could be loaded from the ensemble. Please ensure model files are present.")
        print(f"Loaded {len(models)} models in ensemble")
        return models

    def _apply_calibration(self, scores: np.ndarray) -> np.ndarray:
        """
        Apply complete isotonic regression calibration matching the original
        Sybil implementation.

        Mirrors the original SimpleClassifierGroup.predict_proba:
          1. For each year, apply each calibrator in the ensemble.
          2. Each calibrator applies: linear transform -> clip -> isotonic
             regression (np.interp).
          3. Average predictions from all calibrators.

        Args:
            scores: Raw risk scores from the model, shape [batch_size, num_years].

        Returns:
            Calibrated risk scores, shape [batch_size, num_years]. Years with
            a missing or malformed calibrator pass through uncalibrated.
        """
        if not self.calibrator:
            return scores
        calibrated_scores = []
        for year in range(scores.shape[1]):
            year_key = f"Year{year + 1}"
            cal_list = self.calibrator.get(year_key)
            if not isinstance(cal_list, list) or len(cal_list) == 0:
                # No (or invalid) calibrator for this year: use raw scores.
                calibrated_scores.append(scores[:, year])
                continue
            # Apply each calibrator and collect its predictions.
            year_predictions = []
            for cal_data in cal_list:
                if not isinstance(cal_data, dict):
                    continue
                # Linear-transform parameters are mandatory.
                if "coef" not in cal_data or "intercept" not in cal_data:
                    continue
                # Isotonic-regression knot points are mandatory.
                if "x0" not in cal_data or "y0" not in cal_data:
                    continue
                coef = np.array(cal_data["coef"])            # Shape: [[scalar]]
                intercept = np.array(cal_data["intercept"])  # Shape: [scalar]
                x0 = np.array(cal_data["x0"])
                y0 = np.array(cal_data["y0"])
                # Clipping bounds default to unbounded when absent.
                x_min = cal_data.get("x_min", -np.inf)
                x_max = cal_data.get("x_max", np.inf)
                # Step 1: linear transformation.
                probs = scores[:, year].reshape(-1, 1)       # [batch_size, 1]
                T = (probs @ coef + intercept).flatten()     # [batch_size]
                # Step 2: clip to the calibrator's valid input range.
                T = np.clip(T, x_min, x_max)
                # Step 3: isotonic regression via interpolation — this is the
                # critical step that was missing in earlier versions.
                year_predictions.append(np.interp(T, x0, y0))
            if year_predictions:
                # Average predictions from all calibrators (like SimpleClassifierGroup).
                calibrated_scores.append(np.mean(year_predictions, axis=0))
            else:
                # No valid calibrators: fall back to raw scores.
                calibrated_scores.append(scores[:, year])
        return np.stack(calibrated_scores, axis=1)

    def preprocess_dicom(self, dicom_paths: List[str]) -> torch.Tensor:
        """
        Preprocess DICOM files for model input.

        Args:
            dicom_paths: List of paths to DICOM files.

        Returns:
            Preprocessed 5D tensor (B, C, D, H, W) moved to ``self.device``.
        """
        # Delegate DICOM handling to the image processor.
        result = self.image_processor(dicom_paths, file_type="dicom", return_tensors="pt")
        pixel_values = result["pixel_values"]
        # Ensure a 5D tensor (B, C, D, H, W): add batch dimension if missing.
        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(0)
        return pixel_values.to(self.device)

    def predict(self, dicom_paths: List[str], return_attentions: bool = False) -> SybilOutput:
        """
        Run prediction on a CT scan series.

        Args:
            dicom_paths: List of paths to DICOM files for a single CT series.
            return_attentions: Whether to return attention maps.

        Returns:
            SybilOutput with calibrated risk scores (on CPU) and, if requested,
            ensemble-averaged attention maps.
        """
        pixel_values = self.preprocess_dicom(dicom_paths)
        all_predictions = []
        all_attentions = []
        with torch.no_grad():
            for model in self.models:
                output = model(
                    pixel_values=pixel_values,
                    return_attentions=return_attentions
                )
                # Member models may return a structured output, a tuple, or a
                # bare tensor — normalize to a tensor of risk scores.
                if hasattr(output, 'risk_scores'):
                    predictions = output.risk_scores
                else:
                    predictions = output[0] if isinstance(output, tuple) else output
                all_predictions.append(predictions.cpu().numpy())
                if return_attentions and hasattr(output, 'image_attention'):
                    all_attentions.append(output.image_attention)
        # Average raw scores across the ensemble, then calibrate.
        ensemble_pred = np.mean(all_predictions, axis=0)
        calibrated_pred = self._apply_calibration(ensemble_pred)
        risk_scores = torch.from_numpy(calibrated_pred).float()
        # Average attention maps across the ensemble if requested.
        attentions = None
        if return_attentions and all_attentions:
            attentions = {"image_attention": torch.stack(all_attentions).mean(dim=0)}
        return SybilOutput(risk_scores=risk_scores, attentions=attentions)

    def __call__(self, dicom_paths: List[str] = None, dicom_series: List[List[str]] = None, **kwargs) -> SybilOutput:
        """
        Convenience method for prediction.

        Args:
            dicom_paths: List of DICOM file paths for a single series.
            dicom_series: List of lists of DICOM paths for batch processing.
            **kwargs: Additional arguments passed to predict().

        Returns:
            SybilOutput with predictions. For batch input the per-series score
            tensors are stacked along a new leading dimension.

        Raises:
            ValueError: If neither ``dicom_paths`` nor ``dicom_series`` is given.
        """
        if dicom_series is not None:
            # Batch processing: one predict() call per series.
            per_series_scores = [self.predict(paths, **kwargs).risk_scores for paths in dicom_series]
            return SybilOutput(risk_scores=torch.stack(per_series_scores))
        if dicom_paths is not None:
            return self.predict(dicom_paths, **kwargs)
        raise ValueError("Either dicom_paths or dicom_series must be provided")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load model from Hugging Face hub or local path.

        Args:
            pretrained_model_name_or_path: HF model ID or local path.
            **kwargs: Additional configuration arguments; ``config`` is honored,
                any remaining kwargs are currently ignored.

        Returns:
            SybilHFWrapper instance.
        """
        config = kwargs.pop("config", None)
        if config is None:
            try:
                config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
            except Exception:
                # Fall back to defaults when no config can be fetched. Narrowed
                # from a bare ``except:``, which would also swallow
                # KeyboardInterrupt/SystemExit.
                config = SybilConfig()
        return cls(config=config)