""" Simplified Hugging Face wrapper for original Sybil model This ensures full compatibility with the original implementation """ import os import sys import json import torch import torch.nn as nn from typing import Optional, List, Dict from transformers import PreTrainedModel from dataclasses import dataclass from transformers.modeling_outputs import BaseModelOutput # Add original Sybil to path sys.path.append('/mnt/f/Projects/hfsybil/Sybil') from sybil import Sybil as OriginalSybil from sybil import Serie try: from .configuration_sybil import SybilConfig except ImportError: from configuration_sybil import SybilConfig @dataclass class SybilOutput(BaseModelOutput): """ Output class for Sybil model. """ risk_scores: torch.FloatTensor = None attentions: Optional[Dict] = None class SybilHFWrapper(PreTrainedModel): """ Hugging Face wrapper around the original Sybil model. This ensures complete compatibility while providing HF interface. """ config_class = SybilConfig base_model_prefix = "sybil" def __init__(self, config: SybilConfig): super().__init__(config) self.config = config # Load the original Sybil model with ensemble checkpoint_dir = "/mnt/f/Projects/hfsybil/checkpoints" # Copy checkpoints to ~/.sybil if needed cache_dir = os.path.expanduser("~/.sybil") os.makedirs(cache_dir, exist_ok=True) # Map of checkpoint files checkpoint_files = { "28a7cd44f5bcd3e6cc760b65c7e0d54d.ckpt": "sybil_1", "56ce1a7d241dc342982f5466c4a9d7ef.ckpt": "sybil_2", "624407ef8e3a2a009f9fa51f9846fe9a.ckpt": "sybil_3", "64a91b25f84141d32852e75a3aec7305.ckpt": "sybil_4", "65fd1f04cb4c5847d86a9ed8ba31ac1a.ckpt": "sybil_5", "sybil_ensemble_simple_calibrator.json": "ensemble_calibrator" } # Copy checkpoint files for filename in checkpoint_files.keys(): src = os.path.join(checkpoint_dir, filename) dst = os.path.join(cache_dir, filename) if os.path.exists(src) and not os.path.exists(dst): import shutil shutil.copy2(src, dst) # Initialize the original model self.sybil_model = OriginalSybil("sybil_ensemble") def forward( self, pixel_values: torch.FloatTensor = None, dicom_paths: List[str] = None, return_attentions: bool = False, **kwargs ) -> SybilOutput: """ Forward pass using original Sybil model. Args: pixel_values: Pre-processed tensor (not used directly, for compatibility) dicom_paths: List of DICOM file paths return_attentions: Whether to return attention maps Returns: SybilOutput with risk scores and optional attentions """ if dicom_paths is None: raise ValueError("dicom_paths must be provided") # Create Serie object serie = Serie(dicom_paths) # Run prediction prediction = self.sybil_model.predict([serie], return_attentions=return_attentions) # Convert to torch tensors risk_scores = torch.tensor(prediction.scores[0]) return SybilOutput( risk_scores=risk_scores, attentions=prediction.attentions[0] if return_attentions else None ) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): """ Load the model. Since we're using the original Sybil, we just need to ensure the checkpoints are available. """ config = kwargs.pop("config", None) if config is None: config = SybilConfig.from_pretrained(pretrained_model_name_or_path) return cls(config) def save_pretrained(self, save_directory, **kwargs): """ Save the model configuration. The actual model weights are handled by the original Sybil. """ os.makedirs(save_directory, exist_ok=True) self.config.save_pretrained(save_directory) # Save info about checkpoint locations info = { "model_type": "sybil_wrapper", "checkpoint_dir": "/mnt/f/Projects/hfsybil/checkpoints", "note": "This model uses the original Sybil implementation" } with open(os.path.join(save_directory, "model_info.json"), "w") as f: json.dump(info, f, indent=2)