File size: 12,633 Bytes
cf14762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""
Self-contained Hugging Face wrapper for Sybil lung cancer risk prediction model.
This version works directly from HF without requiring external Sybil package.
"""

import os
import json
import sys
import torch
import numpy as np
from typing import List, Dict, Optional
from dataclasses import dataclass
from transformers.modeling_outputs import BaseModelOutput
from safetensors.torch import load_file

# Add model path to sys.path for imports
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

try:
    from .configuration_sybil import SybilConfig
    from .modeling_sybil import SybilForRiskPrediction
    from .image_processing_sybil import SybilImageProcessor
except ImportError:
    from configuration_sybil import SybilConfig
    from modeling_sybil import SybilForRiskPrediction
    from image_processing_sybil import SybilImageProcessor


@dataclass
class SybilOutput(BaseModelOutput):
    """
    Output class for Sybil model predictions.

    Args:
        risk_scores: Calibrated risk scores, one column per prediction year
            (1-6 years by default). Shape: [batch_size, num_years].
        attentions: Optional dict of attention maps, populated only when the
            caller requested them.
    """
    # NOTE: default is None, so the annotation must be Optional — the previous
    # bare `torch.FloatTensor` annotation was inaccurate.
    risk_scores: Optional[torch.FloatTensor] = None
    attentions: Optional[Dict] = None


class SybilHFWrapper:
    """
    Hugging Face wrapper for the Sybil ensemble model.

    Provides a simple interface for lung cancer risk prediction from CT scans:
    DICOM preprocessing, ensemble inference over several checkpoints, and
    isotonic-regression calibration of the averaged risk scores.
    """

    def __init__(self, config: SybilConfig = None, model_dir: str = None):
        """
        Initialize the Sybil model ensemble.

        Args:
            config: Model configuration (a default ``SybilConfig`` is used if
                not provided).
            model_dir: Directory containing model files (defaults to the
                directory this file lives in).
        """
        self.config = config if config is not None else SybilConfig()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Resolve where the checkpoint/calibrator files live.
        if model_dir is not None:
            self.model_dir = model_dir
        else:
            # Default to where this file is located.
            self.model_dir = os.path.dirname(os.path.abspath(__file__))

        # Image processor handles DICOM -> tensor conversion.
        self.image_processor = SybilImageProcessor()

        # Calibrator data (may be {} when no calibrator file is present).
        self.calibrator = self._load_calibrator()

        # Ensemble of per-checkpoint models; raises if none can be loaded.
        self.models = self._load_ensemble_models()

    def _load_calibrator(self) -> Dict:
        """Load ensemble calibrator data, returning ``{}`` if no file is found.

        Checks the canonical location first, then a legacy fallback path.
        """
        candidate_paths = [
            os.path.join(self.model_dir, "checkpoints", "sybil_ensemble_simple_calibrator.json"),
            # Alternative/legacy location.
            os.path.join(self.model_dir, "calibrator_data.json"),
        ]
        for calibrator_path in candidate_paths:
            if os.path.exists(calibrator_path):
                with open(calibrator_path, 'r') as f:
                    return json.load(f)
        return {}

    def _load_ensemble_models(self) -> List[torch.nn.Module]:
        """
        Load all models in the ensemble from original checkpoints.

        Note: We load from .ckpt files instead of safetensors because the
        safetensors were created with the wrong CumulativeProbabilityLayer
        architecture.

        Returns:
            List of models in eval mode on ``self.device``.

        Raises:
            ValueError: If no checkpoint could be loaded.
        """
        import glob as glob_module
        models = []

        # Find all .ckpt files in the checkpoints directory (sorted for a
        # deterministic ensemble order).
        checkpoints_dir = os.path.join(self.model_dir, "checkpoints")
        checkpoint_files = sorted(glob_module.glob(os.path.join(checkpoints_dir, "*.ckpt")))

        print(f"Found {len(checkpoint_files)} checkpoint files")

        for checkpoint_path in checkpoint_files:
            try:
                model = SybilForRiskPrediction(self.config)
                # SECURITY: weights_only=False unpickles arbitrary objects from
                # the checkpoint — only load checkpoints from a trusted source.
                checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

                # Lightning checkpoints nest weights under 'state_dict'.
                state_dict = checkpoint.get('state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint

                # Strip the 'model.' prefix that Lightning adds to every key.
                cleaned_state_dict = {}
                for k, v in state_dict.items():
                    if k.startswith('model.'):
                        cleaned_state_dict[k[len('model.'):]] = v
                    else:
                        cleaned_state_dict[k] = v

                # strict=False: checkpoints may carry extra keys (e.g. loss
                # buffers) that the inference model does not define.
                model.load_state_dict(cleaned_state_dict, strict=False)
                model.to(self.device)
                model.eval()
                models.append(model)
                print(f"  Loaded model from {os.path.basename(checkpoint_path)}")
            except Exception as e:
                # Best-effort: skip unreadable checkpoints but keep going so a
                # partial ensemble can still serve predictions.
                print(f"  Warning: Could not load {os.path.basename(checkpoint_path)}: {e}")
                continue

        if not models:
            raise ValueError("No models could be loaded from the ensemble. Please ensure model files are present.")

        print(f"Loaded {len(models)} models in ensemble")
        return models

    def _apply_calibration(self, scores: np.ndarray) -> np.ndarray:
        """
        Apply complete isotonic regression calibration matching the original
        Sybil implementation.

        Mirrors the original ``SimpleClassifierGroup.predict_proba``:
        1. For each year, apply each calibrator in the ensemble.
        2. Each calibrator applies: linear transform -> clip -> isotonic
           regression (``np.interp``).
        3. Average predictions from all calibrators.

        Any year with missing/invalid calibrator data falls back to the raw
        scores for that year.

        Args:
            scores: Raw risk scores from the model, shape [batch_size, num_years].

        Returns:
            Calibrated risk scores, shape [batch_size, num_years].
        """
        if not self.calibrator:
            return scores

        calibrated_scores = []

        for year in range(scores.shape[1]):
            year_key = f"Year{year + 1}"
            cal_list = self.calibrator.get(year_key)

            # No (or malformed) calibrator for this year -> raw scores.
            if not isinstance(cal_list, list) or len(cal_list) == 0:
                calibrated_scores.append(scores[:, year])
                continue

            # Collect one calibrated prediction per valid calibrator.
            year_predictions = []

            for cal_data in cal_list:
                if not isinstance(cal_data, dict):
                    continue

                # Required: linear-transform parameters.
                if "coef" not in cal_data or "intercept" not in cal_data:
                    continue

                coef = np.array(cal_data["coef"])  # Shape: [[scalar]]
                intercept = np.array(cal_data["intercept"])  # Shape: [scalar]

                # Required: isotonic regression interpolation points.
                if "x0" not in cal_data or "y0" not in cal_data:
                    continue

                x0 = np.array(cal_data["x0"])
                y0 = np.array(cal_data["y0"])

                # Optional clipping bounds (default: unbounded).
                x_min = cal_data.get("x_min", -np.inf)
                x_max = cal_data.get("x_max", np.inf)

                # Step 1: linear transformation of the raw probability.
                probs = scores[:, year].reshape(-1, 1)  # [batch_size, 1]
                T = (probs @ coef + intercept).flatten()  # [batch_size]

                # Step 2: clip to the calibrator's valid input range.
                T = np.clip(T, x_min, x_max)

                # Step 3: isotonic regression via piecewise-linear interpolation.
                year_predictions.append(np.interp(T, x0, y0))

            if len(year_predictions) == 0:
                # All calibrators for this year were invalid -> raw scores.
                calibrated_scores.append(scores[:, year])
            else:
                # Average predictions from all calibrators
                # (like SimpleClassifierGroup).
                calibrated_scores.append(np.mean(year_predictions, axis=0))

        return np.stack(calibrated_scores, axis=1)

    def preprocess_dicom(self, dicom_paths: List[str]) -> torch.Tensor:
        """
        Preprocess DICOM files for model input.

        Args:
            dicom_paths: List of paths to DICOM files.

        Returns:
            Preprocessed 5D tensor (B, C, D, H, W) on ``self.device``.
        """
        result = self.image_processor(dicom_paths, file_type="dicom", return_tensors="pt")
        pixel_values = result["pixel_values"]

        # Processor may omit the batch dimension; the model expects 5D input.
        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(0)

        return pixel_values.to(self.device)

    def predict(self, dicom_paths: List[str], return_attentions: bool = False) -> SybilOutput:
        """
        Run prediction on a CT scan series.

        Args:
            dicom_paths: List of paths to DICOM files for a single CT series.
            return_attentions: Whether to return attention maps.

        Returns:
            SybilOutput with calibrated risk scores and optional attention maps.
        """
        pixel_values = self.preprocess_dicom(dicom_paths)

        all_predictions = []
        all_attentions = []

        # Inference only — no gradients needed.
        with torch.no_grad():
            for model in self.models:
                output = model(
                    pixel_values=pixel_values,
                    return_attentions=return_attentions
                )

                # Models may return a structured output, a tuple, or a bare tensor.
                if hasattr(output, 'risk_scores'):
                    predictions = output.risk_scores
                else:
                    predictions = output[0] if isinstance(output, tuple) else output

                all_predictions.append(predictions.cpu().numpy())

                if return_attentions and hasattr(output, 'image_attention'):
                    all_attentions.append(output.image_attention)

        # Average raw predictions over the ensemble, then calibrate.
        ensemble_pred = np.mean(all_predictions, axis=0)
        calibrated_pred = self._apply_calibration(ensemble_pred)
        risk_scores = torch.from_numpy(calibrated_pred).float()

        # Average attention maps across models if requested and available.
        attentions = None
        if return_attentions and all_attentions:
            attentions = {"image_attention": torch.stack(all_attentions).mean(dim=0)}

        return SybilOutput(risk_scores=risk_scores, attentions=attentions)

    def __call__(self, dicom_paths: List[str] = None, dicom_series: List[List[str]] = None, **kwargs) -> SybilOutput:
        """
        Convenience method for prediction.

        Args:
            dicom_paths: List of DICOM file paths for a single series.
            dicom_series: List of lists of DICOM paths for batch processing.
                NOTE: in batch mode only risk scores are returned; attention
                maps are not aggregated.
            **kwargs: Additional arguments passed to ``predict()``.

        Returns:
            SybilOutput with predictions.

        Raises:
            ValueError: If neither ``dicom_paths`` nor ``dicom_series`` is given.
        """
        if dicom_series is not None:
            # Batch processing: one predict() call per series.
            all_outputs = []
            for paths in dicom_series:
                output = self.predict(paths, **kwargs)
                all_outputs.append(output.risk_scores)

            risk_scores = torch.stack(all_outputs)
            return SybilOutput(risk_scores=risk_scores)
        elif dicom_paths is not None:
            return self.predict(dicom_paths, **kwargs)
        else:
            raise ValueError("Either dicom_paths or dicom_series must be provided")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load model from Hugging Face hub or local path.

        Args:
            pretrained_model_name_or_path: HF model ID or local path.
            **kwargs: May include ``config`` (a SybilConfig) and ``model_dir``
                (overrides checkpoint location).

        Returns:
            SybilHFWrapper instance.
        """
        config = kwargs.pop("config", None)
        if config is None:
            try:
                config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
            except Exception:
                # Narrowed from a bare `except:` so Ctrl-C/SystemExit still
                # propagate; fall back to the default config.
                config = SybilConfig()

        # Bug fix: the given path used to be ignored, so checkpoints were
        # always loaded from this file's directory. If the caller points at a
        # local directory, load model files from there; hub IDs keep the old
        # default behavior.
        model_dir = kwargs.pop("model_dir", None)
        if model_dir is None and os.path.isdir(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path

        return cls(config=config, model_dir=model_dir)