# Safetensors
# mgte-arctic-s / modeling_mgte_arctic_s.py
# michaeldinzinger — Initial commit (2c9b2dc)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, PreTrainedModel, BertConfig, PretrainedConfig, XLMRobertaTokenizerFast, \
AutoModel, PreTrainedTokenizerFast, BertTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, AutoTokenizer, XLMRobertaTokenizer
from typing import *
class ConcatModelConfig(PretrainedConfig):
    """Configuration for the mgte-arctic-s concatenated embedding model.

    Carries no model-specific options of its own; everything is forwarded
    unchanged to :class:`PretrainedConfig`.
    """

    model_type = "mgte-arctic-s"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
class ConcatModel(PreTrainedModel):
    """Concatenates L2-normalized CLS embeddings from several sub-models.

    The first sub-model reads ``input_ids`` / ``attention_mask`` /
    ``token_type_ids`` directly; sub-model ``i`` (``i >= 1``) reads
    ``input_ids_i`` / ``attention_mask_i`` / ``token_type_ids_i`` from
    ``**kwargs`` — the key layout produced by ``ConcatTokenizer``.
    """

    config_class = ConcatModelConfig

    def __init__(self, models):
        """
        Args:
            models: iterable of Hugging Face encoder models whose CLS
                embeddings are normalized and concatenated in order.
        """
        super().__init__(ConcatModelConfig())
        # nn.ModuleList (rather than a plain Python list) registers the
        # sub-models as submodules, so .to()/.eval()/.cuda()/.state_dict()
        # all recurse into them automatically.
        self.models = nn.ModuleList(models)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor = None,
        **kwargs
    ) -> torch.Tensor:
        """Return the concatenation of each sub-model's normalized CLS vector.

        Args:
            input_ids: token ids for sub-model 0.
            attention_mask: attention mask for sub-model 0.
            token_type_ids: optional segment ids for sub-model 0.
            **kwargs: ``input_ids_i`` / ``attention_mask_i`` /
                ``token_type_ids_i`` tensors for each sub-model ``i >= 1``.

        Returns:
            Tensor of shape ``(batch, sum_i hidden_i)``: per-model CLS
            embeddings, each L2-normalized, concatenated on the last dim.
        """
        embeddings = []
        for i, model in enumerate(self.models):
            if i == 0:
                model_output = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )
            else:
                model_output = model(
                    input_ids=kwargs["input_ids_" + str(i)],
                    attention_mask=kwargs["attention_mask_" + str(i)],
                    # .get(): a sub-tokenizer may not emit token_type_ids.
                    token_type_ids=kwargs.get("token_type_ids_" + str(i)),
                )
            # CLS pooling: first token of the last hidden state.
            pooled_output = model_output[0][:, 0]
            pooled_output = F.normalize(pooled_output, p=2, dim=-1)
            embeddings.append(pooled_output)
        return torch.cat(embeddings, dim=-1)

    def save_pretrained(self, save_directory):
        """Save each sub-model into its own ``model_{i}`` sub-folder."""
        for i, model in enumerate(self.models):
            path = os.path.join(save_directory, f"model_{i}")
            model.save_pretrained(path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, model_id="model_0", **kwargs):
        """Load every ``model_{i}`` sub-folder written by :meth:`save_pretrained`.

        The previous implementation called ``cls(config, **kwargs)`` although
        ``__init__`` expects a list of models, and loaded a single
        ``pytorch_model.bin`` that :meth:`save_pretrained` never writes; it
        could not succeed. This version mirrors the saved layout exactly.

        Args:
            pretrained_model_name_or_path: directory containing ``model_0``,
                ``model_1``, ... sub-folders.
            model_id: deprecated and ignored; kept so existing callers that
                pass it keep working.
            **kwargs: forwarded to ``AutoModel.from_pretrained`` for every
                sub-model.

        Raises:
            OSError: if no ``model_{i}`` sub-folder is found.
        """
        models = []
        index = 0
        while True:
            sub_dir = os.path.join(str(pretrained_model_name_or_path), f"model_{index}")
            if not os.path.isdir(sub_dir):
                break
            models.append(AutoModel.from_pretrained(sub_dir, **kwargs))
            index += 1
        if not models:
            raise OSError(
                f"No 'model_<i>' sub-folders found in {pretrained_model_name_or_path}"
            )
        return cls(models)

    def __repr__(self):
        s = "ConcatModel with models:"
        for i, model in enumerate(self.models):
            s += f"\nModel {i}: {model}"
        return s
class ConcatTokenizer(PreTrainedTokenizer):
    """
    A custom tokenizer that wraps one tokenizer per sub-model of a
    concatenated model. Every call is delegated to the wrapped tokenizers
    and their outputs are merged into one dict, with the keys of tokenizer
    ``i > 0`` suffixed by ``_{i}``.
    """

    def __init__(self, tokenizers, **kwargs):
        # NOTE(review): super().__init__ is deliberately not invoked here —
        # this wrapper only delegates; confirm none of the base-class
        # machinery is needed by callers.
        self.tokenizers = tokenizers

    def _merged(self, encodings):
        """Combine per-tokenizer encodings; keys of tokenizer i>0 get an _{i} suffix."""
        combined = {}
        for idx, encoded in enumerate(encodings):
            for key, value in encoded.items():
                combined[key if idx == 0 else f"{key}_{idx}"] = value
        return combined

    def tokenize(self, text: str, **kwargs):
        """Tokenize *text* with every wrapped tokenizer; one token list each."""
        all_tokens = []
        for tok in self.tokenizers:
            all_tokens.append(tok.tokenize(text, **kwargs))
        return all_tokens

    def __call__(self, text, **kwargs):
        """Encode *text* with every tokenizer and return the merged inputs."""
        return self._merged(tok(text, **kwargs) for tok in self.tokenizers)

    def batch_encode_plus(self, batch_text_or_text_pairs, **kwargs):
        """Batch-encode with every tokenizer and return the merged inputs."""
        return self._merged(
            tok.batch_encode_plus(batch_text_or_text_pairs, **kwargs)
            for tok in self.tokenizers
        )

    def decode(self, token_ids, **kwargs):
        """Decode with the primary (index-0) tokenizer."""
        primary = self.tokenizers[0]
        return primary.decode(token_ids, **kwargs)

    def save_pretrained(self, save_directory):
        """Save each wrapped tokenizer into its own ``model_{i}`` sub-folder."""
        for idx, tok in enumerate(self.tokenizers):
            tok.save_pretrained(os.path.join(save_directory, f"model_{idx}"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the fixed tokenizer pair from ``model_0`` and ``model_1``."""
        first = XLMRobertaTokenizerFast.from_pretrained(
            f"{pretrained_model_name_or_path}/model_0"
        )
        second = BertTokenizer.from_pretrained(
            f"{pretrained_model_name_or_path}/model_1"
        )
        return cls([first, second])

    def __repr__(self):
        lines = ["ConcatTokenizer with tokenizers:"]
        lines.extend(
            f"Tokenizer {idx}: {tok}" for idx, tok in enumerate(self.tokenizers)
        )
        return "\n".join(lines)