"""
Simplified Hugging Face wrapper for original Sybil model
This ensures full compatibility with the original implementation
"""

import os
import shutil
import sys
import json
import torch
from typing import Optional, List, Dict
from dataclasses import dataclass
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

# Add the original Sybil checkout to the import path (hardcoded local path)
sys.path.append('/mnt/f/Projects/hfsybil/Sybil')
from sybil import Sybil as OriginalSybil
from sybil import Serie

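# Support both package-relative imports (when this file is part of a package)
# and direct execution as a standalone script.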
try:
    from .configuration_sybil import SybilConfig
except ImportError:
    from configuration_sybil import SybilConfig


@dataclass
class SybilOutput(BaseModelOutput):
    """
    Output class for Sybil model.
    """
    risk_scores: torch.FloatTensor = None
    attentions: Optional[Dict] = None


class SybilHFWrapper(PreTrainedModel):
    """
    Hugging Face wrapper around the original Sybil model.
    This ensures complete compatibility while providing HF interface.
    """
    config_class = SybilConfig
    base_model_prefix = "sybil"

    def __init__(self, config: SybilConfig):
        super().__init__(config)
        self.config = config

        # Load the original Sybil model with ensemble
        checkpoint_dir = "/mnt/f/Projects/hfsybil/checkpoints"

        # Copy checkpoints to ~/.sybil if needed
        cache_dir = os.path.expanduser("~/.sybil")
        os.makedirs(cache_dir, exist_ok=True)

        # Map of checkpoint files
        checkpoint_files = {
            "28a7cd44f5bcd3e6cc760b65c7e0d54d.ckpt": "sybil_1",
            "56ce1a7d241dc342982f5466c4a9d7ef.ckpt": "sybil_2",
            "624407ef8e3a2a009f9fa51f9846fe9a.ckpt": "sybil_3",
            "64a91b25f84141d32852e75a3aec7305.ckpt": "sybil_4",
            "65fd1f04cb4c5847d86a9ed8ba31ac1a.ckpt": "sybil_5",
            "sybil_ensemble_simple_calibrator.json": "ensemble_calibrator"
        }
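        # Only the keys (filenames) are used below; the values merely label
        # which ensemble member each checkpoint corresponds to.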

        # Copy checkpoint files
        for filename in checkpoint_files.keys():
            src = os.path.join(checkpoint_dir, filename)
            dst = os.path.join(cache_dir, filename)
            if os.path.exists(src) and not os.path.exists(dst):
                shutil.copy2(src, dst)

        # Initialize the original model; "sybil_ensemble" selects the full
        # five-model ensemble in the original package
        self.sybil_model = OriginalSybil("sybil_ensemble")

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        dicom_paths: Optional[List[str]] = None,
        return_attentions: bool = False,
        **kwargs
    ) -> SybilOutput:
        """
        Forward pass using original Sybil model.

        Args:
            pixel_values: Pre-processed tensor (not used directly, for compatibility)
            dicom_paths: List of DICOM file paths
            return_attentions: Whether to return attention maps

        Returns:
            SybilOutput with risk scores and optional attentions
        """

        if dicom_paths is None:
            raise ValueError("dicom_paths must be provided")

        # Create Serie object
        serie = Serie(dicom_paths)

        # Run prediction
        prediction = self.sybil_model.predict([serie], return_attentions=return_attentions)

        # Convert to torch tensors
        risk_scores = torch.tensor(prediction.scores[0])

        return SybilOutput(
            risk_scores=risk_scores,
            attentions=prediction.attentions[0] if return_attentions else None
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the model. Since we're using the original Sybil,
        we just need to ensure the checkpoints are available.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = SybilConfig.from_pretrained(pretrained_model_name_or_path)

        return cls(config)

    def save_pretrained(self, save_directory, **kwargs):
        """
        Save the model configuration.
        The actual model weights are handled by the original Sybil.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)

        # Save info about checkpoint locations
        info = {
            "model_type": "sybil_wrapper",
            "checkpoint_dir": "/mnt/f/Projects/hfsybil/checkpoints",
            "note": "This model uses the original Sybil implementation"
        }

        with open(os.path.join(save_directory, "model_info.json"), "w") as f:
            json.dump(info, f, indent=2)
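

if __name__ == "__main__":
    # Minimal usage sketch: load the wrapper and score one CT series.
    # The glob pattern below is a hypothetical placeholder; point it at a
    # real directory of DICOM slices. Assumes SybilConfig() can be
    # instantiated with its defaults.
    import glob

    model = SybilHFWrapper(SybilConfig())
    dicom_paths = sorted(glob.glob("/path/to/ct_series/*.dcm"))
    output = model(dicom_paths=dicom_paths, return_attentions=False)
    print("Risk scores:", output.risk_scores)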