| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import BertModel, PreTrainedModel, BertConfig, PretrainedConfig, XLMRobertaTokenizerFast, \ |
| AutoModel, PreTrainedTokenizerFast, BertTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, AutoTokenizer, XLMRobertaTokenizer |
| from typing import * |
|
|
|
|
class ConcatModelConfig(PretrainedConfig):
    """Minimal HF config for :class:`ConcatModel`.

    Carries no fields of its own beyond what ``PretrainedConfig``
    provides; it exists so the concatenated model has a registered
    ``model_type`` for save/load round-trips.
    """

    # Identifier used by the transformers config registry.
    model_type = "mgte-arctic-s"

    def __init__(self, **kwargs):
        # No extra configuration — defer everything to the base class.
        super().__init__(**kwargs)
|
|
|
|
class ConcatModel(PreTrainedModel):
    """Wraps several transformer encoders and concatenates their embeddings.

    Each sub-model encodes its own tokenization of the input (see
    :class:`ConcatTokenizer`); the CLS-token embedding of every model is
    L2-normalized and the results are concatenated along the feature
    dimension.
    """

    config_class = ConcatModelConfig

    def __init__(self, models):
        """
        Args:
            models: Iterable of HF encoder models. Stored in an
                ``nn.ModuleList`` so their parameters are registered with
                this module (required for ``state_dict``/``to``/``eval``
                to recurse into them; a plain Python list would hide them).
        """
        super().__init__(ConcatModelConfig())
        self.models = nn.ModuleList(models)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Encode with every sub-model and concatenate CLS embeddings.

        Model 0 reads the standard argument names; model ``i > 0`` reads
        ``input_ids_{i}`` / ``attention_mask_{i}`` /
        ``token_type_ids_{i}`` from ``kwargs`` (as produced by
        :class:`ConcatTokenizer`).

        Returns:
            Tensor of shape ``(batch, sum of hidden sizes)``.

        Raises:
            KeyError: if ``input_ids_{i}``/``attention_mask_{i}`` is
                missing for a secondary model.
        """
        embeddings = []
        for i, model in enumerate(self.models):
            if i == 0:
                ids, mask, type_ids = input_ids, attention_mask, token_type_ids
            else:
                ids = kwargs[f"input_ids_{i}"]
                mask = kwargs[f"attention_mask_{i}"]
                # token_type_ids is optional for models that don't use it.
                type_ids = kwargs.get(f"token_type_ids_{i}")
            model_output = model(
                input_ids=ids,
                attention_mask=mask,
                token_type_ids=type_ids,
            )
            # CLS pooling: first token of the last hidden state, L2-normalized
            # so each sub-embedding contributes on the same scale.
            pooled_output = F.normalize(model_output[0][:, 0], p=2, dim=-1)
            embeddings.append(pooled_output)

        return torch.cat(embeddings, dim=-1)

    def save_pretrained(self, save_directory):
        """Save each sub-model into ``save_directory/model_{i}``."""
        for i, model in enumerate(self.models):
            path = os.path.join(save_directory, f"model_{i}")
            model.save_pretrained(path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, model_id="model_0", **kwargs):
        """Reload every ``model_{i}`` sub-directory written by :meth:`save_pretrained`.

        The previous implementation called ``cls(config, **kwargs)``, which
        passed a config object where ``__init__`` expects the list of
        models, and loaded a single ``pytorch_model.bin`` even though
        :meth:`save_pretrained` writes one directory per sub-model — the
        round-trip could never succeed. This version mirrors
        ``save_pretrained`` exactly.

        Args:
            pretrained_model_name_or_path: Directory containing
                ``model_0``, ``model_1``, ... sub-directories.
            model_id: Kept for backward compatibility with the old
                signature; no longer used.

        Raises:
            FileNotFoundError: if no ``model_0`` sub-directory exists.
        """
        models = []
        i = 0
        while True:
            sub_dir = os.path.join(pretrained_model_name_or_path, f"model_{i}")
            if not os.path.isdir(sub_dir):
                break
            models.append(AutoModel.from_pretrained(sub_dir, **kwargs))
            i += 1
        if not models:
            raise FileNotFoundError(
                f"No 'model_0' sub-directory found under {pretrained_model_name_or_path}"
            )
        return cls(models)

    def __repr__(self):
        # Custom repr kept from the original implementation: one line per
        # wrapped sub-model.
        s = "ConcatModel with models:"
        for i, model in enumerate(self.models):
            s += f"\nModel {i}: {model}"
        return s

    # NOTE: the former hand-written eval()/cuda() overrides are no longer
    # needed — nn.ModuleList registration makes nn.Module.eval()/cuda()
    # recurse into the sub-models automatically (both return self).
|
|
class ConcatTokenizer(PreTrainedTokenizer):
    """
    A custom tokenizer that delegates to several underlying tokenizers,
    one per sub-model of :class:`ConcatModel`.

    Encoded outputs are merged into a single dict: tokenizer 0 keeps the
    standard keys (``input_ids``, ``attention_mask``, ...); tokenizer
    ``i > 0`` gets its keys suffixed with ``_{i}`` — the naming scheme
    :meth:`ConcatModel.forward` expects.
    """

    def __init__(self, tokenizers, **kwargs):
        # NOTE(review): PreTrainedTokenizer.__init__ is intentionally NOT
        # called — the base init needs a vocabulary this wrapper doesn't
        # have, so only duck-typed delegation is supported. Confirm no
        # base-class machinery (e.g. padding helpers) is relied on by
        # callers.
        self.tokenizers = tokenizers

    def _merge_encoded(self, encode_one):
        """Run ``encode_one(tokenizer)`` for each tokenizer and merge the
        resulting dicts, suffixing keys of tokenizer i>0 with ``_{i}``.

        Shared helper for :meth:`__call__` and :meth:`batch_encode_plus`,
        which previously duplicated this logic.
        """
        combined_inputs = {}
        for i, tokenizer in enumerate(self.tokenizers):
            for key, value in encode_one(tokenizer).items():
                combined_inputs[key if i == 0 else f"{key}_{i}"] = value
        return combined_inputs

    def tokenize(self, text: str, **kwargs):
        """
        Tokenizes text using all tokenizers.

        Returns:
            A list with one token list per underlying tokenizer.
        """
        return [tokenizer.tokenize(text, **kwargs) for tokenizer in self.tokenizers]

    def __call__(self, text, **kwargs):
        """
        Tokenize and encode input text using all tokenizers.
        Returns combined inputs (keys of tokenizer i>0 suffixed ``_{i}``).
        """
        return self._merge_encoded(lambda tokenizer: tokenizer(text, **kwargs))

    def batch_encode_plus(self, batch_text_or_text_pairs, **kwargs):
        """
        Handles batch tokenization for all tokenizers; same key-merging
        scheme as :meth:`__call__`.
        """
        return self._merge_encoded(
            lambda tokenizer: tokenizer.batch_encode_plus(batch_text_or_text_pairs, **kwargs)
        )

    def decode(self, token_ids, **kwargs):
        """
        Decode tokens using the first tokenizer only (token ids from the
        other tokenizers use different vocabularies).
        """
        return self.tokenizers[0].decode(token_ids, **kwargs)

    def save_pretrained(self, save_directory):
        """
        Save each tokenizer into ``save_directory/model_{i}``, mirroring
        :meth:`ConcatModel.save_pretrained`.
        """
        for i, tokenizer in enumerate(self.tokenizers):
            path = os.path.join(save_directory, f"model_{i}")
            tokenizer.save_pretrained(path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the tokenizers from the specified directory.

        The concrete tokenizer classes are hard-coded to match the known
        sub-model layout (model_0: XLM-R fast, model_1: BERT).
        """
        tokenizers = [
            XLMRobertaTokenizerFast.from_pretrained(f"{pretrained_model_name_or_path}/model_0"),
            BertTokenizer.from_pretrained(f"{pretrained_model_name_or_path}/model_1"),
        ]
        return cls(tokenizers)

    def __repr__(self):
        # One line per wrapped tokenizer for easy inspection.
        s = "ConcatTokenizer with tokenizers:"
        for i, tokenizer in enumerate(self.tokenizers):
            s += f"\nTokenizer {i}: {tokenizer}"
        return s
|
|