File size: 5,383 Bytes
26b1bda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
from transformers import PreTrainedModel, AutoModel, AutoConfig, PretrainedConfig
import floret, torch
import os, shutil
from configuration_stacked import ImpressoConfig
from transformers.modeling_utils import (
    get_parameter_device as original_get_parameter_device,
)


import torch

# Import Hugging Face dependencies
import transformers.modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_utils import (
    get_parameter_device as original_get_parameter_device,
)


# Custom get_parameter_device
def custom_get_parameter_device(module):
    """
    Replacement for transformers' get_parameter_device() that knows about
    floret-backed models.

    FloretModelWrapper instances hold no torch parameters and always live
    on CPU; any other module is delegated to Hugging Face's original
    implementation.
    """
    if not isinstance(module, FloretModelWrapper):
        # Not one of ours — defer to the stock Hugging Face logic.
        return original_get_parameter_device(module)

    print(
        "Custom get_parameter_device(): Detected FloretModelWrapper. Returning 'cpu'."
    )
    return torch.device("cpu")


# Custom device property
@property
def custom_device(self) -> torch.device:
    """
    Replacement for PreTrainedModel.device.

    Always reports torch.device('cpu'): floret models carry no torch
    parameters, and the stock property is monkey-patched away, so there is
    no original implementation to safely fall back to.
    """
    if isinstance(self, FloretModelWrapper):
        # Announce when the wrapper path is hit, mirroring the patched
        # get_parameter_device() behaviour.
        print(
            "Custom device(): Detected FloretModelWrapper. Returning torch.device('cpu')."
        )
    # Unconditionally CPU — see docstring for why the fallback is disabled.
    return torch.device("cpu")


# Preserve Hugging Face's original device property BEFORE patching, so the
# name actually refers to the stock implementation. (Previously this was
# captured after the patch and therefore pointed at custom_device.)
original_device = PreTrainedModel.device

# Monkey-patch get_parameter_device and device property
transformers.modeling_utils.get_parameter_device = custom_get_parameter_device
PreTrainedModel.device = custom_device

print("Monkey-patch applied: get_parameter_device and device property")

# logger = logging.getLogger(__name__)


def get_info(label_map):
    """Return, for each task in *label_map*, how many labels it defines."""
    counts = {}
    for task_name, task_labels in label_map.items():
        counts[task_name] = len(task_labels)
    return counts


class FloretModelWrapper:
    """
    Adapter that lets a floret model stand in for a torch module inside the
    Hugging Face pipeline machinery.

    Exposes a fake ``.device`` attribute (floret is CPU-only) and forwards
    ``predict()`` to the wrapped model unchanged.
    """

    def __init__(self, floret_model):
        # The wrapped floret model instance.
        self.floret_model = floret_model
        # Mock .device so Hugging Face's device introspection is satisfied;
        # floret has no torch parameters and always runs on CPU.
        self.device = torch.device("cpu")

    def predict(self, text, k=1):
        """Delegate to floret's predict(); returns its top-k output as-is."""
        return self.floret_model.predict(text, k=k)


class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """
    PreTrainedModel facade around a floret language-identification model.

    The underlying model is not a torch module: loading, prediction, and
    saving are delegated to floret, while this class supplies the Hugging
    Face interface (config handling, from_pretrained / save_pretrained)
    that pipelines expect.
    """

    config_class = ImpressoConfig

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
        super().__init__(config)
        self.config = config
        # Debug tracing kept while the HF integration is being validated.
        print("Doest is it even pass through here?")
        print(
            f"The config in ExtendedMultitaskModelForTokenClassification is: {self.config}"
        )
        # The floret model itself is attached later, in from_pretrained().

    def predict(self, text, k=1):
        """Return floret's top-k predictions for *text*.

        Requires that from_pretrained() has populated ``self.model``.
        """
        return self.model.predict(text, k)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Instantiate the model and load the floret binary from the config.

        NOTE(review): *pretrained_model_name_or_path* is ignored — the
        binary path comes from ``ImpressoConfig().filename``.
        """
        print("Calling from_pretrained...")

        # Initialize model with config
        model = cls(ImpressoConfig())

        # Load model using floret
        print(f"---Loading model from: {model.config.filename}")
        floret_model = floret.load_model(model.config.filename)

        # Wrap the model so pipelines see a .device attribute (CPU-only).
        model.model = FloretModelWrapper(floret_model)

        print(model.model, "device:", model.model.device)
        print(f"Model loaded and wrapped from: {model.config.filename}")

        return model

    def save_pretrained(self, save_directory, *args, **kwargs):
        """Copy the floret binary and the config into *save_directory*."""
        # Hugging Face passes sharding/serialization options that do not
        # apply to a floret binary — discard them.
        kwargs.pop("max_shard_size", None)
        kwargs.pop("safe_serialization", False)

        # Ensure directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model file under the canonical LID binary name.
        model_file = os.path.join(save_directory, "LID-40-3-2000000-1-4.bin")
        shutil.copy(self.config.filename, model_file)

        # Save the config (writes config.json into the directory).
        self.config.save_pretrained(save_directory)

        print(f"Model saved to: {save_directory}")

    @staticmethod
    def get_parameter_device(module):
        """
        Custom get_parameter_device() to handle floret models.

        Returns 'cpu' for FloretModelWrapper instances and falls back to
        Hugging Face's original implementation otherwise. Declared static
        because the previous instance-method form (single ``module``
        parameter, no ``self``) could never be called through an instance.
        """
        if isinstance(module, FloretModelWrapper):
            print(
                "Custom get_parameter_device(): Detected FloretModelWrapper. Returning 'cpu'."
            )
            return "cpu"

        # Otherwise, fall back to Hugging Face's original implementation
        return original_get_parameter_device(module)