# modeling_sybil_wrapper.py
"""
Simplified Hugging Face wrapper for original Sybil model
This ensures full compatibility with the original implementation
"""
import os
import sys
import json
import torch
import torch.nn as nn
from typing import Optional, List, Dict
from transformers import PreTrainedModel
from dataclasses import dataclass
from transformers.modeling_outputs import BaseModelOutput
# Add original Sybil to path
# NOTE(review): hardcoded, machine-specific path — the `from sybil import ...`
# lines below will fail on any other machine unless the sybil package is
# already installed; consider making this configurable or packaging sybil.
sys.path.append('/mnt/f/Projects/hfsybil/Sybil')
from sybil import Sybil as OriginalSybil
from sybil import Serie
# Prefer the package-relative import; fall back to a flat import when this
# module is loaded as a standalone file (e.g. via transformers remote code).
try:
    from .configuration_sybil import SybilConfig
except ImportError:
    from configuration_sybil import SybilConfig
@dataclass
class SybilOutput(BaseModelOutput):
    """
    Output class for Sybil model.

    Produced by ``SybilHFWrapper.forward``; wraps the original Sybil
    prediction in a ``transformers``-style output object.
    """
    # Risk scores for the first (only) serie in the prediction, as a tensor.
    # Presumably one score per follow-up year — TODO confirm against the
    # original Sybil `predict` return value.
    risk_scores: Optional[torch.FloatTensor] = None
    # Attention maps for the first serie; populated only when forward() is
    # called with return_attentions=True, otherwise None.
    attentions: Optional[Dict] = None
class SybilHFWrapper(PreTrainedModel):
    """
    Hugging Face wrapper around the original Sybil model.

    This class delegates all inference to the original ``Sybil`` ensemble
    implementation, ensuring complete compatibility while providing the
    Hugging Face interface (``from_pretrained`` / ``save_pretrained`` /
    ``forward``).
    """
    config_class = SybilConfig
    base_model_prefix = "sybil"

    # Default location of the ensemble checkpoints on disk. Can be
    # overridden by setting `checkpoint_dir` on the config.
    DEFAULT_CHECKPOINT_DIR = "/mnt/f/Projects/hfsybil/checkpoints"

    # Checkpoint filename -> logical ensemble-member name. The original
    # Sybil implementation looks these files up by filename in ~/.sybil.
    CHECKPOINT_FILES = {
        "28a7cd44f5bcd3e6cc760b65c7e0d54d.ckpt": "sybil_1",
        "56ce1a7d241dc342982f5466c4a9d7ef.ckpt": "sybil_2",
        "624407ef8e3a2a009f9fa51f9846fe9a.ckpt": "sybil_3",
        "64a91b25f84141d32852e75a3aec7305.ckpt": "sybil_4",
        "65fd1f04cb4c5847d86a9ed8ba31ac1a.ckpt": "sybil_5",
        "sybil_ensemble_simple_calibrator.json": "ensemble_calibrator"
    }

    def __init__(self, config: SybilConfig):
        """
        Initialize the wrapper: stage checkpoint files into the cache
        directory the original Sybil reads from, then load the ensemble.

        Args:
            config: Sybil configuration. If it defines `checkpoint_dir`,
                that directory is used as the checkpoint source; otherwise
                DEFAULT_CHECKPOINT_DIR is used.
        """
        import shutil  # hoisted out of the copy loop below

        super().__init__(config)
        self.config = config

        # Source directory for the ensemble checkpoints (configurable).
        checkpoint_dir = getattr(config, "checkpoint_dir", self.DEFAULT_CHECKPOINT_DIR)

        # The original Sybil implementation loads its weights from ~/.sybil,
        # so copy any missing checkpoint files there.
        cache_dir = os.path.expanduser("~/.sybil")
        os.makedirs(cache_dir, exist_ok=True)
        for filename in self.CHECKPOINT_FILES:
            src = os.path.join(checkpoint_dir, filename)
            dst = os.path.join(cache_dir, filename)
            # Only copy when the source exists and the cache copy is absent,
            # so repeated construction is cheap and never clobbers the cache.
            if os.path.exists(src) and not os.path.exists(dst):
                shutil.copy2(src, dst)

        # Initialize the original ensemble model.
        self.sybil_model = OriginalSybil("sybil_ensemble")

    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        dicom_paths: List[str] = None,
        return_attentions: bool = False,
        **kwargs
    ) -> SybilOutput:
        """
        Forward pass using original Sybil model.

        Args:
            pixel_values: Pre-processed tensor (not used directly, accepted
                only for interface compatibility with HF pipelines)
            dicom_paths: List of DICOM file paths for one serie
            return_attentions: Whether to return attention maps

        Returns:
            SybilOutput with risk scores and optional attentions

        Raises:
            ValueError: If `dicom_paths` is None or empty.
        """
        if not dicom_paths:
            # Covers both None and an empty list — Serie cannot be built
            # from zero DICOM files.
            raise ValueError("dicom_paths must be provided")

        # Wrap the paths in the original Sybil's Serie container and run
        # the ensemble prediction on this single serie.
        serie = Serie(dicom_paths)
        prediction = self.sybil_model.predict([serie], return_attentions=return_attentions)

        # `predict` returns per-serie results; we submitted exactly one
        # serie, so index 0 everywhere.
        risk_scores = torch.tensor(prediction.scores[0])
        return SybilOutput(
            risk_scores=risk_scores,
            attentions=prediction.attentions[0] if return_attentions else None
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the model. Since we're using the original Sybil, we only need
        the configuration — the actual weights are staged by __init__.

        Note: remaining **kwargs (beyond `config`) are intentionally
        ignored; the original Sybil manages its own weights.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
        return cls(config)

    def save_pretrained(self, save_directory, **kwargs):
        """
        Save the model configuration.

        The actual model weights are handled by the original Sybil; only
        the config and a small provenance note are written.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)

        # Record where the checkpoints live so a reader of the saved
        # directory knows the weights are external to it.
        info = {
            "model_type": "sybil_wrapper",
            "checkpoint_dir": getattr(self.config, "checkpoint_dir", self.DEFAULT_CHECKPOINT_DIR),
            "note": "This model uses the original Sybil implementation"
        }
        with open(os.path.join(save_directory, "model_info.json"), "w") as f:
            json.dump(info, f, indent=2)