|
|
import torch |
|
|
from transformers import AutoModel |
|
|
|
|
|
|
|
|
def build_text_encoder(config):
    """Instantiate the pretrained text encoder selected by *config*.

    Args:
        config: object exposing ``model_type`` (str) and
            ``pretrained_name_or_path`` (str) attributes.

    Returns:
        The Hugging Face model loaded via ``AutoModel.from_pretrained``.

    Raises:
        NotImplementedError: if ``config.model_type`` is not supported.
    """
    if config.model_type == "mpnet":
        return AutoModel.from_pretrained(config.pretrained_name_or_path)
    # Include the offending value in the error so misconfigured runs are
    # easy to diagnose (the original raised a bare NotImplementedError).
    raise NotImplementedError(
        f"Unsupported model_type: {config.model_type!r} (expected 'mpnet')"
    )
|
|
|
|
|
|
|
|
|
|
|
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over non-padding positions.

    Args:
        model_output: model forward output; index 0 holds the token
            embeddings of shape (batch, seq_len, hidden).
        attention_mask: (batch, seq_len) mask; 1 for real tokens, 0 for
            padding.

    Returns:
        Tensor of shape (batch, hidden): the masked mean of the token
        embeddings per sequence.
    """
    embeddings = model_output[0]
    # Broadcast the mask across the hidden dimension so padded tokens
    # contribute nothing to the sum.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = (embeddings * mask).sum(dim=1)
    # Clamp the token count to avoid division by zero for all-padding rows.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
|
|
|