|
|
""" |
|
|
Simplified Hugging Face wrapper for original Sybil model |
|
|
This ensures full compatibility with the original implementation |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from typing import Optional, List, Dict |
|
|
from transformers import PreTrainedModel |
|
|
from dataclasses import dataclass |
|
|
from transformers.modeling_outputs import BaseModelOutput |
|
|
|
|
|
|
|
|
sys.path.append('/mnt/f/Projects/hfsybil/Sybil') |
|
|
from sybil import Sybil as OriginalSybil |
|
|
from sybil import Serie |
|
|
|
|
|
try: |
|
|
from .configuration_sybil import SybilConfig |
|
|
except ImportError: |
|
|
from configuration_sybil import SybilConfig |
|
|
|
|
|
|
|
|
@dataclass
class SybilOutput(BaseModelOutput):
    """
    Output class for Sybil model.

    Returned by ``SybilHFWrapper.forward``; extends the standard HF
    ``BaseModelOutput`` with the fields produced by the original Sybil
    ensemble's ``predict`` call.
    """
    # Risk scores for a single serie, built from prediction.scores[0];
    # presumably one score per follow-up year — TODO confirm against Sybil.
    risk_scores: Optional[torch.FloatTensor] = None
    # Attention maps from prediction.attentions[0]; populated only when
    # forward(..., return_attentions=True), otherwise None.
    attentions: Optional[Dict] = None
|
|
|
|
|
|
|
|
class SybilHFWrapper(PreTrainedModel):
    """
    Hugging Face wrapper around the original Sybil model.

    This ensures complete compatibility while providing an HF interface:
    configuration, loading, and saving follow ``transformers`` conventions,
    while all inference is delegated to the original Sybil ensemble.
    """

    config_class = SybilConfig
    base_model_prefix = "sybil"

    # Default directory holding the local ensemble checkpoints. May be
    # overridden by setting a `checkpoint_dir` attribute on the config.
    DEFAULT_CHECKPOINT_DIR = "/mnt/f/Projects/hfsybil/checkpoints"

    # Files the original Sybil loader expects to find in its ~/.sybil cache,
    # mapped to the ensemble member / artifact each one represents.
    CHECKPOINT_FILES = {
        "28a7cd44f5bcd3e6cc760b65c7e0d54d.ckpt": "sybil_1",
        "56ce1a7d241dc342982f5466c4a9d7ef.ckpt": "sybil_2",
        "624407ef8e3a2a009f9fa51f9846fe9a.ckpt": "sybil_3",
        "64a91b25f84141d32852e75a3aec7305.ckpt": "sybil_4",
        "65fd1f04cb4c5847d86a9ed8ba31ac1a.ckpt": "sybil_5",
        "sybil_ensemble_simple_calibrator.json": "ensemble_calibrator",
    }

    def __init__(self, config: SybilConfig):
        """
        Build the wrapper and load the original Sybil ensemble.

        Stages any locally available checkpoint files into ``~/.sybil``
        (the cache directory the original Sybil loader reads from) before
        instantiating the ensemble.

        Args:
            config: HF-style configuration. An optional `checkpoint_dir`
                attribute overrides ``DEFAULT_CHECKPOINT_DIR``.
        """
        super().__init__(config)
        self.config = config

        # Hoisted out of the copy loop: importing once is enough, and the
        # dependency is only needed at construction time.
        import shutil

        # Allow the checkpoint location to be overridden via the config;
        # fall back to the historical hard-coded path.
        checkpoint_dir = getattr(config, "checkpoint_dir", None) or self.DEFAULT_CHECKPOINT_DIR

        # The original Sybil loader looks for its weights in ~/.sybil.
        cache_dir = os.path.expanduser("~/.sybil")
        os.makedirs(cache_dir, exist_ok=True)

        # Copy local checkpoints into the cache, never overwriting files
        # already present (e.g. ones Sybil downloaded itself).
        for filename in self.CHECKPOINT_FILES:
            src = os.path.join(checkpoint_dir, filename)
            dst = os.path.join(cache_dir, filename)
            if os.path.exists(src) and not os.path.exists(dst):
                shutil.copy2(src, dst)

        # All model logic lives in the original implementation.
        self.sybil_model = OriginalSybil("sybil_ensemble")

    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        dicom_paths: List[str] = None,
        return_attentions: bool = False,
        **kwargs,
    ) -> SybilOutput:
        """
        Forward pass using the original Sybil model.

        Args:
            pixel_values: Pre-processed tensor. Not used directly; kept so
                the signature stays compatible with HF calling conventions.
            dicom_paths: List of DICOM file paths making up a single serie.
            return_attentions: Whether to return attention maps.

        Returns:
            SybilOutput with risk scores and optional attentions.

        Raises:
            ValueError: If `dicom_paths` is not provided.
        """
        if dicom_paths is None:
            raise ValueError("dicom_paths must be provided")

        serie = Serie(dicom_paths)

        prediction = self.sybil_model.predict([serie], return_attentions=return_attentions)

        # predict() operates on a batch of series; we submitted exactly one,
        # so index 0 selects its results.
        risk_scores = torch.tensor(prediction.scores[0])

        return SybilOutput(
            risk_scores=risk_scores,
            attentions=prediction.attentions[0] if return_attentions else None,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the model.

        Since the original Sybil implementation manages its own weights,
        only the configuration needs to be resolved here; remaining kwargs
        are accepted for HF compatibility but otherwise ignored.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
        return cls(config)

    def save_pretrained(self, save_directory, **kwargs):
        """
        Save the model configuration.

        The actual model weights are handled by the original Sybil, so only
        the HF config plus a small provenance file are written.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)

        # Record where the checkpoints came from, honouring any override
        # on the config so the saved info matches what __init__ used.
        info = {
            "model_type": "sybil_wrapper",
            "checkpoint_dir": getattr(self.config, "checkpoint_dir", None)
            or self.DEFAULT_CHECKPOINT_DIR,
            "note": "This model uses the original Sybil implementation",
        }

        with open(os.path.join(save_directory, "model_info.json"), "w") as f:
            json.dump(info, f, indent=2)