import collections
import logging
import os
import random
import sys
from typing import Tuple, List

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor as T
from torch import nn

current_dir = os.path.dirname(__file__)
data_utils_path = os.path.join(current_dir, '..')
sys.path.append(data_utils_path)

from Data_utils_inf import Tensorizer
from Data_utils_inf import normalize_question

logger = logging.getLogger(__name__)

BiEncoderBatch = collections.namedtuple(
    "BiEncoderInput",
    [
        "question_ids",
        "question_segments",
        "context_ids",
        "ctx_segments",
        "is_positive",
        "hard_negatives",
    ],
)


def dot_product_scores(q_vectors: T, ctx_vectors: T) -> T:
    """
    Calculates q->ctx scores for every row in ctx_vectors.
    :param q_vectors:
    :param ctx_vectors:
    :return:
    """
    # q_vectors: n1 x D, ctx_vectors: n2 x D, result n1 x n2
    r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1))
    return r


def cosine_scores(q_vector: T, ctx_vectors: T):
    # Cosine similarity along dim=1. Unlike dot_product_scores, this yields an
    # n1 x n2 matrix only when the inputs broadcast (e.g. q_vector is 1 x D);
    # for equal-sized batches it returns one score per row.
    return F.cosine_similarity(q_vector, ctx_vectors, dim=1)
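
# Illustrative usage sketch (not used by the module; hidden size 768 is an
# assumption) showing the shape contract of dot_product_scores:
def _example_score_shapes() -> None:
    q = torch.randn(2, 768)  # 2 question vectors
    c = torch.randn(6, 768)  # 6 context vectors (e.g. positives + negatives)
    scores = dot_product_scores(q, c)
    assert scores.shape == (2, 6)  # one score per (question, context) pair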

class BiEncoder(nn.Module):
    """Bi-Encoder model component. Encapsulates query/question and context/passage encoders."""

    def __init__(
        self,
        question_model: nn.Module,
        ctx_model: nn.Module,
        fix_q_encoder: bool = False,
        fix_ctx_encoder: bool = False,
    ):
        super(BiEncoder, self).__init__()
        self.question_model = question_model
        self.ctx_model = ctx_model
        self.fix_q_encoder = fix_q_encoder
        self.fix_ctx_encoder = fix_ctx_encoder

    @staticmethod
    def get_representation(
        sub_model: nn.Module,
        ids: T,
        segments: T,
        attn_mask: T,
        fix_encoder: bool = False,
    ) -> Tuple[T, T, T]:
        sequence_output = None
        pooled_output = None
        hidden_states = None
        if ids is not None:
            if fix_encoder:
                with torch.no_grad():
                    sequence_output, pooled_output, hidden_states = sub_model(
                        ids, segments, attn_mask
                    )
                if sub_model.training:
                    sequence_output.requires_grad_(requires_grad=True)
                    pooled_output.requires_grad_(requires_grad=True)
            else:
                sequence_output, pooled_output, hidden_states = sub_model(
                    ids, segments, attn_mask
                )
        return sequence_output, pooled_output, hidden_states

    def forward(
        self,
        question_ids: T,
        question_segments: T,
        question_attn_mask: T,
        context_ids: T,
        ctx_segments: T,
        ctx_attn_mask: T,
    ) -> Tuple[T, T]:
        _q_seq, q_pooled_out, _q_hidden = self.get_representation(
            self.question_model,
            question_ids,
            question_segments,
            question_attn_mask,
            self.fix_q_encoder,
        )
        _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation(
            self.ctx_model,
            context_ids,
            ctx_segments,
            ctx_attn_mask,
            self.fix_ctx_encoder,
        )
        return q_pooled_out, ctx_pooled_out

    @classmethod
    def create_biencoder_input(
        cls,
        samples: List,
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        do_lower_fill: bool = False,
        desegment_valid_fill: bool = False,
    ) -> BiEncoderBatch:
        """
        Creates a batch of the biencoder training tuple.
        :param samples: list of data items (from json) to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools
        :param shuffle_positives: shuffles positive passages pools
        :param do_lower_fill: lowercases the positive and negative context texts
        :param desegment_valid_fill: currently unused
        :return: BiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        for sample in samples:
            # ctx+ & [ctx-] composition
            # as of now, take the first (gold) ctx+ only
            if shuffle and shuffle_positives:
                positive_ctxs = sample["positive_ctxs"]
                positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
            else:
                positive_ctx = sample["positive_ctxs"][0]
            if do_lower_fill:
                positive_ctx["text"] = positive_ctx["text"].lower()

            neg_ctxs = sample["negative_ctxs"]
            hard_neg_ctxs = sample["hard_negative_ctxs"]
            if do_lower_fill:
                neg_ctxs = [
                    {"text": ctx["text"].lower(), "title": ctx["title"]} for ctx in neg_ctxs
                ]
                hard_neg_ctxs = [
                    {"text": ctx["text"].lower(), "title": ctx["title"]} for ctx in hard_neg_ctxs
                ]

            question = normalize_question(sample["question"])

            if shuffle:
                random.shuffle(neg_ctxs)
                random.shuffle(hard_neg_ctxs)

            neg_ctxs = neg_ctxs[0:num_other_negatives]
            hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

            all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
            # hard negatives follow the positive and the other negatives in all_ctxs
            hard_negatives_start_idx = 1 + len(neg_ctxs)
            hard_negatives_end_idx = hard_negatives_start_idx + len(hard_neg_ctxs)

            current_ctxs_len = len(ctx_tensors)

            sample_ctxs_tensors = [
                tensorizer.text_to_tensor(
                    ctx["text"], title=ctx["title"] if insert_title else None
                )
                for ctx in all_ctxs
            ]

            ctx_tensors.extend(sample_ctxs_tensors)
            positive_ctx_indices.append(current_ctxs_len)
            hard_neg_ctx_indices.append(
                [
                    i
                    for i in range(
                        current_ctxs_len + hard_negatives_start_idx,
                        current_ctxs_len + hard_negatives_end_idx,
                    )
                ]
            )
            question_tensors.append(tensorizer.text_to_tensor(question))

        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)

        ctx_segments = torch.zeros_like(ctxs_tensor)
        question_segments = torch.zeros_like(questions_tensor)

        return BiEncoderBatch(
            questions_tensor,
            question_segments,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
        )
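
# Illustrative sketch of the sample structure create_biencoder_input expects.
# The field names are taken from the dictionary accesses above; the values
# themselves are made up:
_EXAMPLE_SAMPLE = {
    "question": "who wrote hamlet?",
    "positive_ctxs": [
        {"text": "Hamlet is a tragedy written by William Shakespeare.", "title": "Hamlet"}
    ],
    "negative_ctxs": [
        {"text": "The Globe Theatre was an Elizabethan playhouse.", "title": "Globe Theatre"}
    ],
    "hard_negative_ctxs": [
        {"text": "Macbeth is a tragedy by William Shakespeare.", "title": "Macbeth"}
    ],
}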

class DistilBertBiEncoder(nn.Module):
    """Bi-Encoder model component.
    Encapsulates query/question and context/passage encoders."""

    def __init__(
        self,
        question_model: nn.Module,
        ctx_model: nn.Module,
        fix_q_encoder: bool = False,
        fix_ctx_encoder: bool = False,
    ):
        super(DistilBertBiEncoder, self).__init__()
        self.question_model = question_model
        self.ctx_model = ctx_model
        self.fix_q_encoder = fix_q_encoder
        self.fix_ctx_encoder = fix_ctx_encoder

    @staticmethod
    def get_representation(
        sub_model: nn.Module,
        ids: T,
        segments: T,
        attn_mask: T,
        fix_encoder: bool = False,
    ) -> Tuple[T, T, T]:
        sequence_output = None
        pooled_output = None
        hidden_states = None
        if ids is not None:
            if fix_encoder:
                with torch.no_grad():
                    # DistilBERT has no token_type_ids, so segments are not passed
                    sequence_output, pooled_output, hidden_states = sub_model(
                        ids, attn_mask
                    )
                if sub_model.training:
                    sequence_output.requires_grad_(requires_grad=True)
                    pooled_output.requires_grad_(requires_grad=True)
            else:
                # DistilBERT has no token_type_ids, so segments are not passed
                sequence_output, pooled_output, hidden_states = sub_model(
                    ids, attn_mask
                )
        return sequence_output, pooled_output, hidden_states

    def forward(
        self,
        question_ids: T,
        question_segments: T,
        question_attn_mask: T,
        context_ids: T,
        ctx_segments: T,
        ctx_attn_mask: T,
    ) -> Tuple[T, T]:
        _q_seq, q_pooled_out, _q_hidden = self.get_representation(
            self.question_model,
            question_ids,
            question_segments,
            question_attn_mask,
            self.fix_q_encoder,
        )
        _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation(
            self.ctx_model,
            context_ids,
            ctx_segments,
            ctx_attn_mask,
            self.fix_ctx_encoder,
        )
        return q_pooled_out, ctx_pooled_out

    @classmethod
    def create_biencoder_input(
        cls,
        samples: List,
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        do_lower_fill: bool = False,
        desegment_valid_fill: bool = False,
    ) -> BiEncoderBatch:
        """
        Creates a batch of the biencoder training tuple.
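
# Illustrative sketch (an assumption, not part of this module): adapting a
# HuggingFace DistilBertModel to the (sequence, pooled, hidden) triple that
# get_representation() expects. DistilBERT takes no token_type_ids, which is
# why DistilBertBiEncoder calls sub_model without the segments argument.
class _DistilBertEncoderSketch(nn.Module):
    def __init__(self, hf_model: nn.Module):
        super().__init__()
        self.hf_model = hf_model  # e.g. DistilBertModel.from_pretrained(...)

    def forward(self, ids: T, attn_mask: T) -> Tuple[T, T, T]:
        out = self.hf_model(input_ids=ids, attention_mask=attn_mask)
        sequence_output = out.last_hidden_state
        pooled_output = sequence_output[:, 0, :]  # CLS-token pooling (assumed)
        return sequence_output, pooled_output, None  # no hidden states exposed here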
        :param samples: list of data items (from json) to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools
        :param shuffle_positives: shuffles positive passages pools
        :param do_lower_fill: lowercases the positive and negative context texts
        :param desegment_valid_fill: currently unused
        :return: BiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        for sample in samples:
            # ctx+ & [ctx-] composition
            # as of now, take the first (gold) ctx+ only
            if shuffle and shuffle_positives:
                positive_ctxs = sample["positive_ctxs"]
                positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
            else:
                positive_ctx = sample["positive_ctxs"][0]
            if do_lower_fill:
                positive_ctx["text"] = positive_ctx["text"].lower()

            neg_ctxs = sample["negative_ctxs"]
            hard_neg_ctxs = sample["hard_negative_ctxs"]
            if do_lower_fill:
                neg_ctxs = [
                    {"text": ctx["text"].lower(), "title": ctx["title"]} for ctx in neg_ctxs
                ]
                hard_neg_ctxs = [
                    {"text": ctx["text"].lower(), "title": ctx["title"]} for ctx in hard_neg_ctxs
                ]

            question = normalize_question(sample["question"])

            if shuffle:
                random.shuffle(neg_ctxs)
                random.shuffle(hard_neg_ctxs)

            neg_ctxs = neg_ctxs[0:num_other_negatives]
            hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

            all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
            # hard negatives follow the positive and the other negatives in all_ctxs
            hard_negatives_start_idx = 1 + len(neg_ctxs)
            hard_negatives_end_idx = hard_negatives_start_idx + len(hard_neg_ctxs)

            current_ctxs_len = len(ctx_tensors)

            sample_ctxs_tensors = [
                tensorizer.text_to_tensor(
                    ctx["text"], title=ctx["title"] if insert_title else None
                )
                for ctx in all_ctxs
            ]

            ctx_tensors.extend(sample_ctxs_tensors)
            positive_ctx_indices.append(current_ctxs_len)
            hard_neg_ctx_indices.append(
                [
                    i
                    for i in range(
                        current_ctxs_len + hard_negatives_start_idx,
                        current_ctxs_len + hard_negatives_end_idx,
                    )
                ]
            )
            question_tensors.append(tensorizer.text_to_tensor(question))

        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)

        ctx_segments = torch.zeros_like(ctxs_tensor)
        question_segments = torch.zeros_like(questions_tensor)

        return BiEncoderBatch(
            questions_tensor,
            question_segments,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
        )


class BiEncoderNllLoss(object):
    def calc(
        self,
        q_vectors: T,
        ctx_vectors: T,
        positive_idx_per_question: list,
        hard_negative_idx_per_question: list = None,
    ) -> Tuple[T, int]:
        """
        Computes nll loss for the given lists of question and ctx vectors.
        Note that although hard_negative_idx_per_question is not currently in use,
        one can use it for loss modifications. For example - weighted NLL with
        different factors for hard vs regular negatives.
        :return: a tuple of loss value and amount of correct predictions per batch
        """
        scores = self.get_scores(q_vectors, ctx_vectors)

        if len(q_vectors.size()) > 1:
            q_num = q_vectors.size(0)
            scores = scores.view(q_num, -1)

        softmax_scores = F.log_softmax(scores, dim=1)

        loss = F.nll_loss(
            softmax_scores,
            torch.tensor(positive_idx_per_question).to(softmax_scores.device),
            reduction="mean",
        )

        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (
            max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)
        ).sum()
        return loss, correct_predictions_count

    @staticmethod
    def get_scores(q_vector: T, ctx_vectors: T) -> T:
        f = BiEncoderNllLoss.get_similarity_function()
        return f(q_vector, ctx_vectors)

    @staticmethod
    def get_similarity_function():
        return dot_product_scores
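
# Illustrative usage sketch (not used by the module; hidden size 768 is an
# assumption): in-batch NLL where each question's positive context sits at a
# known column of the question-by-context score matrix.
def _example_nll_loss() -> None:
    q_vectors = torch.randn(2, 768)    # 2 question vectors
    ctx_vectors = torch.randn(4, 768)  # 2 contexts per question, positive first
    positive_idx_per_question = [0, 2]  # column index of each question's positive
    loss, correct = BiEncoderNllLoss().calc(
        q_vectors, ctx_vectors, positive_idx_per_question
    )
    assert loss.dim() == 0               # scalar mean NLL over the batch
    assert 0 <= int(correct) <= 2        # correct top-1 predictions in the batch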