File size: 736 Bytes
4333430 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
from transformers import AutoModel
def build_text_encoder(config):
if config.model_type == "mpnet":
model = AutoModel.from_pretrained(config.pretrained_name_or_path)
else:
raise NotImplementedError()
return model
# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[
0
] # First element of model_output contains all token embeddings
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
|