# Safetensors
# custom_code
# RadZero / text_encoders.py
# Author: jonggwon-park
# Commit message: auto model bug fix
# Commit: 2ba7893
import torch
from transformers import AutoModel
def build_text_encoder(config):
    """Instantiate the text encoder described by ``config``.

    Only ``config.model_type == "mpnet"`` is supported. In that case the model
    is loaded with ``transformers.AutoModel.from_pretrained`` from
    ``config.pretrained_name_or_path``.

    Args:
        config: Object exposing a ``model_type`` attribute and, for supported
            types, a ``pretrained_name_or_path`` attribute.

    Returns:
        The model returned by ``AutoModel.from_pretrained``.

    Raises:
        NotImplementedError: If ``config.model_type`` is not supported.
    """
    if config.model_type == "mpnet":
        return AutoModel.from_pretrained(config.pretrained_name_or_path)
    # Include the offending value so a misconfigured run is easy to diagnose.
    raise NotImplementedError(
        f"Unsupported text encoder model_type: {config.model_type!r} "
        "(supported: 'mpnet')"
    )
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions selected by ``attention_mask``.

    Args:
        model_output: Indexable model output whose first element holds the
            per-token embeddings. Presumably shaped (batch, seq_len, hidden);
            TODO confirm against callers.
        attention_mask: Mask whose shape equals the embeddings' shape minus
            the last axis. Non-zero entries mark tokens to include.

    Returns:
        The masked sum of embeddings over axis 1, divided by the number of
        included tokens. A row whose mask is all zeros yields zeros rather
        than NaN, because the denominator is clamped to at least 1e-9.
    """
    hidden_states = model_output[0]

    # Broadcast the mask across the last (feature) axis. Cast it to float32 so
    # the weighted sum and the token count are not done in the integer/bool
    # mask dtype.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()

    summed = (hidden_states * weights).sum(1)
    # Guard against division by zero for fully-masked rows.
    token_counts = weights.sum(1).clamp(min=1e-9)
    return summed / token_counts