"""Text-encoder construction and masked mean-pooling utilities."""

import torch
from transformers import AutoModel


def build_text_encoder(config):
    """Build and return a pretrained text encoder.

    Args:
        config: object exposing ``model_type`` (currently only
            ``"mpnet"`` is supported) and ``pretrained_name_or_path``
            (forwarded to ``AutoModel.from_pretrained``).

    Returns:
        The loaded ``transformers`` model.

    Raises:
        NotImplementedError: if ``config.model_type`` is not supported.
    """
    # Guard clause: fail early with an informative message instead of a
    # bare NotImplementedError().
    if config.model_type != "mpnet":
        raise NotImplementedError(
            f"Unsupported model_type: {config.model_type!r}"
        )
    return AutoModel.from_pretrained(config.pretrained_name_or_path)


# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over real (non-padding) tokens.

    Args:
        model_output: model forward output; element 0 is the token
            embeddings tensor — assumed shape (batch, seq_len, hidden),
            per the standard sentence-transformers pooling recipe.
        attention_mask: (batch, seq_len) tensor with 1 at real token
            positions and 0 at padding positions.

    Returns:
        (batch, hidden) tensor: the masked mean of the token embeddings.
    """
    # First element of model_output contains all token embeddings.
    token_embeddings = model_output[0]
    # Broadcast the mask over the hidden dimension so padded positions
    # contribute zero to the sum.
    mask = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    summed = torch.sum(token_embeddings * mask, dim=1)
    # Clamp keeps the denominator strictly positive, so an all-padding
    # row yields zeros instead of NaN from division by zero.
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts