# lang-detect/modeling_stacked.py
import os
import shutil

import floret
import torch
import transformers.modeling_utils
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_utils import (
    get_parameter_device as original_get_parameter_device,
)

from configuration_stacked import ImpressoConfig
# Custom get_parameter_device
def custom_get_parameter_device(module):
    """
    Custom get_parameter_device() to handle floret models.
    Returns 'cpu' for FloretModelWrapper, otherwise uses the original implementation.
    """
    # Check if the module is an instance of FloretModelWrapper
    if isinstance(module, FloretModelWrapper):
        print(
            "Custom get_parameter_device(): Detected FloretModelWrapper. Returning 'cpu'."
        )
        return torch.device("cpu")
    # Otherwise, fall back to Hugging Face's original implementation
    return original_get_parameter_device(module)
# Capture the original device property BEFORE patching, so the custom
# property can fall back to it for regular Hugging Face models.
original_device = PreTrainedModel.device


# Custom device property
@property
def custom_device(self) -> torch.device:
    """
    Custom device property to handle floret models.
    Always returns torch.device('cpu') for FloretModelWrapper, otherwise
    falls back to the original property.
    """
    # Check if the model is an instance of FloretModelWrapper
    if isinstance(self, FloretModelWrapper):
        print(
            "Custom device(): Detected FloretModelWrapper. Returning torch.device('cpu')."
        )
        return torch.device("cpu")
    # Otherwise, fall back to Hugging Face's original implementation
    return original_device.__get__(self, type(self))


# Monkey-patch get_parameter_device and the device property
transformers.modeling_utils.get_parameter_device = custom_get_parameter_device
PreTrainedModel.device = custom_device
print("Monkey-patch applied: get_parameter_device and device property")
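
# Quick sanity check for the patch (illustrative sketch; the wrapper holds no
# real floret model here, so only the device lookup is exercised):
#
#   dummy = FloretModelWrapper(floret_model=None)
#   assert transformers.modeling_utils.get_parameter_device(dummy) == torch.device("cpu")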
def get_info(label_map):
    """Return, per task, the number of labels in a {task: labels} mapping."""
    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
    return num_token_labels_dict
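
# Example (hypothetical label map, for illustration only):
#
#   get_info({"lang": ["de", "fr", "lb"], "script": ["latin"]})
#   # -> {"lang": 3, "script": 1}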
class FloretModelWrapper:
    """
    Wrapper around a floret model to make it compatible with the Hugging Face
    pipeline. Mocks the .device attribute and passes predict() through unchanged.
    """

    def __init__(self, floret_model):
        self.floret_model = floret_model
        # Mock the .device attribute to make Hugging Face happy;
        # floret always runs on CPU.
        self.device = torch.device("cpu")

    def predict(self, text, k=1):
        """
        Pass-through for floret's predict() method.
        """
        return self.floret_model.predict(text, k=k)
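
# Minimal usage sketch (assumes "LID-40-3-2000000-1-4.bin" is a valid floret
# language-identification model on disk; the path is illustrative):
#
#   wrapped = FloretModelWrapper(floret.load_model("LID-40-3-2000000-1-4.bin"))
#   labels, scores = wrapped.predict("Bonjour tout le monde", k=3)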
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    config_class = ImpressoConfig

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
        super().__init__(config)
        self.config = config
        print(
            f"The config in ExtendedMultitaskModelForTokenClassification is: {self.config}"
        )
        # The floret model itself is attached as self.model in from_pretrained().

    def predict(self, text, k=1):
        predictions = self.model.predict(text, k)
        return predictions
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        print("Calling from_pretrained...")
        # Initialize the model with its config
        model = cls(ImpressoConfig())
        # Load the underlying model with floret
        print(f"---Loading model from: {model.config.filename}")
        floret_model = floret.load_model(model.config.filename)
        # Wrap the model to fake the .device attribute
        model.model = FloretModelWrapper(floret_model)
        print(model.model, "device:", model.model.device)
        print(f"Model loaded and wrapped from: {model.config.filename}")
        return model
    def save_pretrained(self, save_directory, *args, **kwargs):
        # Ignore Hugging Face-specific serialization arguments;
        # the floret model is a plain binary file.
        kwargs.pop("max_shard_size", None)
        kwargs.pop("safe_serialization", False)
        # Ensure the target directory exists
        os.makedirs(save_directory, exist_ok=True)
        # Copy the floret model binary into the save directory
        model_file = os.path.join(save_directory, "LID-40-3-2000000-1-4.bin")
        shutil.copy(self.config.filename, model_file)
        # Save the config file (config.json)
        self.config.save_pretrained(save_directory)
        print(f"Model saved to: {save_directory}")
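
# End-to-end usage sketch (assumes ImpressoConfig.filename points at a valid
# floret .bin model; paths are illustrative):
#
#   model = ExtendedMultitaskModelForTokenClassification.from_pretrained(".")
#   labels, scores = model.predict("Ceci est une phrase en français.", k=2)
#   model.save_pretrained("exported_model")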