|
|
from transformers import PreTrainedModel, AutoModel, AutoConfig, PretrainedConfig |
|
|
import floret, torch |
|
|
import os, shutil |
|
|
from configuration_stacked import ImpressoConfig |
|
|
from transformers.modeling_utils import ( |
|
|
get_parameter_device as original_get_parameter_device, |
|
|
) |
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
import transformers.modeling_utils |
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers.modeling_utils import ( |
|
|
get_parameter_device as original_get_parameter_device, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def custom_get_parameter_device(module):
    """
    Replacement for transformers' get_parameter_device().

    FloretModelWrapper instances hold no torch parameters, so they are
    reported as living on the CPU; every other module is delegated to the
    original transformers implementation.
    """
    # Guard clause: anything that is not a floret wrapper takes the normal path.
    if not isinstance(module, FloretModelWrapper):
        return original_get_parameter_device(module)

    print(
        "Custom get_parameter_device(): Detected FloretModelWrapper. Returning 'cpu'."
    )
    return torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
@property
def custom_device(self) -> torch.device:
    """
    Replacement for PreTrainedModel.device.

    FloretModelWrapper carries no torch parameters, so it is pinned to CPU.
    All other models delegate to the original transformers
    get_parameter_device() so a model moved to CUDA keeps reporting its real
    device.

    Fix: the previous fallback unconditionally returned torch.device('cpu'),
    which made the isinstance() branch dead code and mis-reported the device
    of any non-floret model.
    """
    if isinstance(self, FloretModelWrapper):
        print(
            "Custom device(): Detected FloretModelWrapper. Returning torch.device('cpu')."
        )
        return torch.device("cpu")

    # Delegate to the pristine transformers implementation instead of
    # always answering 'cpu'.
    return original_get_parameter_device(self)
|
|
|
|
|
|
|
|
|
|
|
# Keep a handle on the pristine device property BEFORE patching, so the
# original behaviour can still be consulted or restored later.
# Fix: previously this was captured AFTER the patch below, so it pointed at
# custom_device rather than the original PreTrainedModel.device.
original_device = PreTrainedModel.device

# Monkey-patch transformers so floret-backed wrappers are treated as CPU models.
transformers.modeling_utils.get_parameter_device = custom_get_parameter_device
PreTrainedModel.device = custom_device

print("Monkey-patch applied: get_parameter_device and device property")
|
|
|
|
|
|
|
|
def get_info(label_map):
    """
    Map each task name to the number of labels it defines.

    Parameters
    ----------
    label_map : dict
        Mapping from task name to its collection of labels.

    Returns
    -------
    dict
        Task name -> label count.
    """
    counts = {}
    for task_name, task_labels in label_map.items():
        counts[task_name] = len(task_labels)
    return counts
|
|
|
|
|
|
|
|
class FloretModelWrapper:
    """
    Adapter that makes a raw floret model usable inside a Hugging Face
    pipeline.

    The pipeline machinery inspects a ``.device`` attribute; we expose a
    fixed CPU device and forward ``predict`` calls verbatim to the wrapped
    floret model.
    """

    def __init__(self, floret_model):
        # The underlying floret model that performs the actual inference.
        self.floret_model = floret_model
        # Mocked device attribute expected by HF pipelines; floret runs on CPU.
        self.device = torch.device("cpu")

    def predict(self, text, k=1):
        """Pass-through to floret's predict(); returns its result unchanged."""
        return self.floret_model.predict(text, k=k)
|
|
|
|
|
|
|
|
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """
    Hugging Face-compatible shell around a floret model.

    PreTrainedModel is subclassed only for pipeline/serialization
    compatibility; inference is delegated to the FloretModelWrapper stored
    in ``self.model`` by :meth:`from_pretrained`.
    """

    config_class = ImpressoConfig

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
        super().__init__(config)
        self.config = config
        print("Doest is it even pass through here?")
        print(
            f"The config in ExtendedMultitaskModelForTokenClassification is: {self.config}"
        )

    def predict(self, text, k=1):
        """Run floret prediction on *text*, returning the top-*k* results."""
        predictions = self.model.predict(text, k)
        return predictions

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Build the model and load the floret binary named in the config.

        NOTE(review): ``pretrained_model_name_or_path`` is currently
        ignored — the floret file path always comes from the default
        ImpressoConfig().filename. Confirm this is intentional.
        """
        print("Calling from_pretrained...")
        model = cls(ImpressoConfig())

        print(f"---Loading model from: {model.config.filename}")
        floret_model = floret.load_model(model.config.filename)

        # Wrap the floret model so the HF pipeline sees a .device attribute.
        model.model = FloretModelWrapper(floret_model)

        print(model.model, "device:", model.model.device)
        print(f"Model loaded and wrapped from: {model.config.filename}")
        return model

    def save_pretrained(self, save_directory, *args, **kwargs):
        """
        Copy the floret binary and the config into *save_directory*.

        ``max_shard_size`` and ``safe_serialization`` are accepted (and
        discarded) only so callers using the standard HF signature do not
        crash — sharding/safetensors do not apply to a floret binary.
        """
        # Pop and discard HF-specific kwargs (previously bound to unused locals).
        kwargs.pop("max_shard_size", None)
        kwargs.pop("safe_serialization", False)

        os.makedirs(save_directory, exist_ok=True)

        # Copy the floret binary next to the config under a fixed name.
        model_file = os.path.join(save_directory, "LID-40-3-2000000-1-4.bin")
        shutil.copy(self.config.filename, model_file)

        # Removed unused `config_file` local; save_pretrained writes config.json.
        self.config.save_pretrained(save_directory)

        print(f"Model saved to: {save_directory}")
|
|
|
|
|
def get_parameter_device(module):
    """
    Custom get_parameter_device() to handle floret models.

    Returns torch.device('cpu') for FloretModelWrapper instances and falls
    back to the original transformers implementation otherwise.

    Fix: previously returned the bare string "cpu", while its twin
    custom_get_parameter_device() returns torch.device('cpu'); made the two
    consistent (torch.device("cpu") == "cpu" still holds for comparisons).

    NOTE(review): this duplicates custom_get_parameter_device() above —
    consider removing one of the two.
    """
    if isinstance(module, FloretModelWrapper):
        print(
            "Custom get_parameter_device(): Detected FloretModelWrapper. Returning 'cpu'."
        )
        return torch.device("cpu")

    return original_get_parameter_device(module)
|
|
|