import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, BertConfig, BertModel, PretrainedConfig, PreTrainedModel
from typing import List, Optional


class ConcatModelConfig(PretrainedConfig):
    """Configuration for ConcatModel; adds no fields beyond PretrainedConfig."""

    model_type = "arctic-s-bge-small"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class ConcatModel(PreTrainedModel):
    """Two BERT-small encoders whose [CLS] embeddings are L2-normalized and concatenated."""

    config_class = ConcatModelConfig

    def __init__(self, config: ConcatModelConfig):
        super().__init__(config)

        # Both sub-models share the same BERT-small architecture
        # (12 layers, 384-dim hidden states, BERT-uncased vocab), so
        # build the identical config twice rather than repeating the literal.
        def make_bert_config() -> BertConfig:
            return BertConfig(
                vocab_size=30522,
                hidden_size=384,
                num_hidden_layers=12,
                num_attention_heads=12,
                intermediate_size=1536,
                hidden_act="gelu",
                hidden_dropout_prob=0.1,
                attention_probs_dropout_prob=0.1,
                max_position_embeddings=512,
                type_vocab_size=2,
                initializer_range=0.02,
                layer_norm_eps=1e-12,
            )

        self.model = nn.ModuleDict(
            {
                "model_0": BertModel(make_bert_config()),
                "model_1": BertModel(make_bert_config()),
            }
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        embeddings = []
        for _, model in self.model.items():
            model_output = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
            # CLS pooling: last hidden state of the [CLS] token, then
            # L2-normalize each sub-model's embedding independently.
            pooled_output = model_output[0][:, 0]
            pooled_output = F.normalize(pooled_output, p=2, dim=-1)
            embeddings.append(pooled_output)

        # Concatenate along the feature axis: (batch, 384 + 384) = (batch, 768).
        return torch.cat(embeddings, dim=-1)

    def load_weights_from_automodels(
        self, in_models: List[str], has_pooling_layer: List[bool]
    ):
        """Replace the randomly initialized sub-models with pretrained checkpoints.

        `in_models` lists checkpoint names in concatenation order;
        `has_pooling_layer[i]` keeps or drops the i-th model's BERT pooler head.
        """
        model_list = []
        for i, model_name in enumerate(in_models):
            model = AutoModel.from_pretrained(
                model_name,
                add_pooling_layer=has_pooling_layer[i],
                trust_remote_code=True,
            )
            model.eval()
            model_list.append(model)

        self.model = nn.ModuleDict(
            {f"model_{i}": model for i, model in enumerate(model_list)}
        )
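

# Usage sketch, not part of the original module. The checkpoint names and
# tokenizer below are assumptions inferred from the "arctic-s-bge-small"
# model_type; substitute the sub-models you actually intend to concatenate.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = ConcatModel(ConcatModelConfig())
    # Assumption: CLS pooling needs no BERT pooler head, so both
    # sub-models are loaded with add_pooling_layer=False.
    model.load_weights_from_automodels(
        ["Snowflake/snowflake-arctic-embed-s", "BAAI/bge-small-en-v1.5"],
        has_pooling_layer=[False, False],
    )
    model.eval()

    # Assumption: both checkpoints share the 30522-token BERT-uncased vocab.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    batch = tokenizer(["an example query"], return_tensors="pt")
    with torch.no_grad():
        embeddings = model(**batch)
    print(embeddings.shape)  # torch.Size([1, 768]): two 384-dim vectors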