"""
Simplified Hugging Face wrapper for original Sybil model
This ensures full compatibility with the original implementation
"""
import os
import sys
import json
import shutil
import torch
from dataclasses import dataclass
from typing import Optional, List, Dict

from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

# Add the original Sybil package to the import path
sys.path.append('/mnt/f/Projects/hfsybil/Sybil')
from sybil import Sybil as OriginalSybil
from sybil import Serie

try:
    from .configuration_sybil import SybilConfig
except ImportError:
    from configuration_sybil import SybilConfig
@dataclass
class SybilOutput(BaseModelOutput):
    """
    Output class for the Sybil model.
    """
    risk_scores: Optional[torch.FloatTensor] = None
    attentions: Optional[Dict] = None
class SybilHFWrapper(PreTrainedModel):
    """
    Hugging Face wrapper around the original Sybil model.
    This ensures complete compatibility while providing the HF interface.
    """
    config_class = SybilConfig
    base_model_prefix = "sybil"

    def __init__(self, config: SybilConfig):
        super().__init__(config)
        self.config = config

        # The original Sybil model loads its ensemble checkpoints from
        # ~/.sybil, so copy them there from the local checkpoint dir if needed
        checkpoint_dir = "/mnt/f/Projects/hfsybil/checkpoints"
        cache_dir = os.path.expanduser("~/.sybil")
        os.makedirs(cache_dir, exist_ok=True)

        # Map of checkpoint files to ensemble member names
        checkpoint_files = {
            "28a7cd44f5bcd3e6cc760b65c7e0d54d.ckpt": "sybil_1",
            "56ce1a7d241dc342982f5466c4a9d7ef.ckpt": "sybil_2",
            "624407ef8e3a2a009f9fa51f9846fe9a.ckpt": "sybil_3",
            "64a91b25f84141d32852e75a3aec7305.ckpt": "sybil_4",
            "65fd1f04cb4c5847d86a9ed8ba31ac1a.ckpt": "sybil_5",
            "sybil_ensemble_simple_calibrator.json": "ensemble_calibrator",
        }

        # Copy each checkpoint file into the cache unless it is already there
        for filename in checkpoint_files:
            src = os.path.join(checkpoint_dir, filename)
            dst = os.path.join(cache_dir, filename)
            if os.path.exists(src) and not os.path.exists(dst):
                shutil.copy2(src, dst)

        # Initialize the original ensemble model
        self.sybil_model = OriginalSybil("sybil_ensemble")
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        dicom_paths: Optional[List[str]] = None,
        return_attentions: bool = False,
        **kwargs,
    ) -> SybilOutput:
        """
        Forward pass using the original Sybil model.

        Args:
            pixel_values: Pre-processed tensor (unused; kept for HF compatibility)
            dicom_paths: List of DICOM file paths for a single series
            return_attentions: Whether to return attention maps

        Returns:
            SybilOutput with risk scores and optional attentions
        """
        if dicom_paths is None:
            raise ValueError("dicom_paths must be provided")

        # Wrap the DICOM paths in a Serie, the original model's input type
        serie = Serie(dicom_paths)

        # Run prediction through the original ensemble
        prediction = self.sybil_model.predict([serie], return_attentions=return_attentions)

        # Convert the scores for the single serie to a torch tensor
        risk_scores = torch.tensor(prediction.scores[0])

        return SybilOutput(
            risk_scores=risk_scores,
            attentions=prediction.attentions[0] if return_attentions else None,
        )
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the model. Since inference is delegated to the original Sybil,
        only the configuration needs loading; the checkpoints are resolved
        in __init__.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
        return cls(config)
    def save_pretrained(self, save_directory, **kwargs):
        """
        Save the model configuration.
        The actual model weights are managed by the original Sybil.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)

        # Record where the original checkpoints live
        info = {
            "model_type": "sybil_wrapper",
            "checkpoint_dir": "/mnt/f/Projects/hfsybil/checkpoints",
            "note": "This model uses the original Sybil implementation",
        }
        with open(os.path.join(save_directory, "model_info.json"), "w") as f:
            json.dump(info, f, indent=2)
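
# Minimal usage sketch, not part of the module's API. Assumptions: the DICOM
# paths below are hypothetical placeholders for a real chest CT series, and
# SybilConfig is instantiable with its defaults.
if __name__ == "__main__":
    model = SybilHFWrapper(SybilConfig())

    # Hypothetical slice paths for a single CT series
    dicom_paths = [f"/data/ct_series/slice_{i:03d}.dcm" for i in range(200)]

    output = model(dicom_paths=dicom_paths, return_attentions=False)
    print("Per-year risk scores:", output.risk_scores)

    # Saving writes only the config plus a model_info.json pointer;
    # the ensemble weights stay with the original Sybil checkpoints.
    model.save_pretrained("./sybil_hf_export")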