Spaces:
Runtime error
Runtime error
| import collections | |
| import logging | |
| import random | |
| from typing import Tuple, List | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import Tensor as T | |
| from torch import nn | |
| import sys | |
| import os | |
| current_dir = os.path.dirname(__file__) | |
| data_utils_path = os.path.join(current_dir, '..') | |
| sys.path.append(data_utils_path) | |
| from Data_utils_inf import Tensorizer | |
| from Data_utils_inf import normalize_question | |
| logger = logging.getLogger(__name__) | |
# Batch of bi-encoder inputs as assembled by ``create_biencoder_input``.
# The namedtuple typename must equal the module-level name: pickle looks the
# class up by that name, so a mismatched typename ("BiENcoderInput") made
# batches unpicklable (e.g. when sent to another process).
BiEncoderBatch = collections.namedtuple(
    "BiEncoderBatch",
    [
        "question_ids",  # questions tensor, one row per question
        "question_segments",  # segment ids, same shape as question_ids
        "context_ids",  # contexts tensor, one row per context of every sample
        "ctx_segments",  # segment ids, same shape as context_ids
        "is_positive",  # per question: row index of its positive context
        "hard_negatives",  # per question: list of row indices of its hard negatives
    ],
)
def dot_product_scores(q_vectors: T, ctx_vectors: T) -> T:
    """
    Score every question vector against every context vector by dot product.

    :param q_vectors: question vectors, n1 x D
    :param ctx_vectors: context vectors, n2 x D
    :return: n1 x n2 tensor whose [i, j] entry is q_vectors[i] . ctx_vectors[j]
    """
    ctx_columns = ctx_vectors.transpose(0, 1)  # D x n2
    return torch.matmul(q_vectors, ctx_columns)
def cosine_scores(q_vector: T, ctx_vectors: T):
    """
    Cosine similarity between question and context vectors along dim 1.

    NOTE: unlike ``dot_product_scores`` this does NOT build an n1 x n2 score
    matrix. ``F.cosine_similarity`` compares rows pairwise after broadcasting,
    so ``q_vector`` must be either 1 x D (scored against every context row) or
    have the same number of rows as ``ctx_vectors`` (row i vs row i).

    :param q_vector: 1 x D, or n x D question vectors
    :param ctx_vectors: n x D context vectors
    :return: 1-D tensor of n similarities
    """
    return F.cosine_similarity(q_vector, ctx_vectors, dim=1)
class BiEncoder(nn.Module):
    """Bi-Encoder model component. Encapsulates query/question and context/passage encoders.

    :param question_model: encoder applied to question token ids
    :param ctx_model: encoder applied to context/passage token ids
    :param fix_q_encoder: if True, the question encoder runs under ``torch.no_grad``
    :param fix_ctx_encoder: if True, the context encoder runs under ``torch.no_grad``
    """

    def __init__(
        self,
        question_model: nn.Module,
        ctx_model: nn.Module,
        fix_q_encoder: bool = False,
        fix_ctx_encoder: bool = False,
    ):
        super(BiEncoder, self).__init__()
        self.question_model = question_model
        self.ctx_model = ctx_model
        self.fix_q_encoder = fix_q_encoder
        self.fix_ctx_encoder = fix_ctx_encoder

    @staticmethod
    def get_representation(
        sub_model: nn.Module,
        ids: T,
        segments: T,
        attn_mask: T,
        fix_encoder: bool = False,
    ) -> Tuple[T, T, T]:
        """
        Run ``sub_model(ids, segments, attn_mask)`` and return its three outputs.

        Must be a staticmethod: ``forward`` calls it as
        ``self.get_representation(sub_model, ...)``, and without the decorator
        ``self`` would bind to ``sub_model`` and the call would raise TypeError.

        :param sub_model: encoder returning (sequence_output, pooled_output, hidden_states)
        :param ids: token ids; when None the encoder is not run
        :param segments: segment ids passed to the encoder
        :param attn_mask: attention mask passed to the encoder
        :param fix_encoder: run the encoder without gradient tracking
        :return: (sequence_output, pooled_output, hidden_states); all None if ids is None
        """
        if ids is None:
            return None, None, None
        if fix_encoder:
            with torch.no_grad():
                sequence_output, pooled_output, hidden_states = sub_model(
                    ids, segments, attn_mask
                )
            if sub_model.training:
                # Outputs computed under no_grad do not require grad; mark them so
                # autograd tracks whatever is computed from them downstream.
                sequence_output.requires_grad_(requires_grad=True)
                pooled_output.requires_grad_(requires_grad=True)
        else:
            sequence_output, pooled_output, hidden_states = sub_model(
                ids, segments, attn_mask
            )
        return sequence_output, pooled_output, hidden_states

    def forward(
        self,
        question_ids: T,
        question_segments: T,
        question_attn_mask: T,
        context_ids: T,
        ctx_segments: T,
        ctx_attn_mask: T,
    ) -> Tuple[T, T]:
        """Encode questions and contexts; return their pooled outputs (q_pooled, ctx_pooled)."""
        _q_seq, q_pooled_out, _q_hidden = self.get_representation(
            self.question_model,
            question_ids,
            question_segments,
            question_attn_mask,
            self.fix_q_encoder,
        )
        _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation(
            self.ctx_model,
            context_ids,
            ctx_segments,
            ctx_attn_mask,
            self.fix_ctx_encoder,
        )
        return q_pooled_out, ctx_pooled_out

    @staticmethod
    def _lowercase_text(ctx: dict) -> dict:
        """Return a copy of ``ctx`` with its "text" lower-cased; ``ctx`` itself is not modified."""
        return {**ctx, "text": ctx["text"].lower()}

    @staticmethod
    def _hard_negative_indices(offset: int, num_other: int, num_hard: int) -> List[int]:
        """Row indices of one sample's hard negatives in the flattened context list.

        A sample's contexts are laid out as [positive] + other negatives + hard
        negatives starting at row ``offset``, so hard negatives begin after the
        positive AND after the ``num_other`` other negatives.
        """
        start = offset + 1 + num_other
        return list(range(start, start + num_hard))

    @classmethod
    def create_biencoder_input(
        cls,
        samples: List,
        tensorizer: "Tensorizer",
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        do_lower_fill: bool = False,
        desegment_valid_fill: bool = False,
    ) -> "BiEncoderBatch":
        """
        Creates a batch of the biencoder training tuple.
        :param samples: list of data items (from json) to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools
        :param shuffle_positives: shuffles positive passages pools
        :param do_lower_fill: lower-cases the text of the selected contexts (on copies;
            the samples themselves are not modified)
        :param desegment_valid_fill: accepted for interface compatibility; not used here
        :return: BiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []
        for sample in samples:
            # ctx+ & [ctx-] composition
            positive_ctxs = sample["positive_ctxs"]
            if shuffle and shuffle_positives:
                positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
            else:
                # take the first (gold) ctx+ only
                positive_ctx = positive_ctxs[0]
            # Copy the pools so shuffling / lower-casing never alters the caller's samples.
            neg_ctxs = list(sample["negative_ctxs"])
            hard_neg_ctxs = list(sample["hard_negative_ctxs"])
            if do_lower_fill:
                positive_ctx = cls._lowercase_text(positive_ctx)
                neg_ctxs = [cls._lowercase_text(ctx) for ctx in neg_ctxs]
                hard_neg_ctxs = [cls._lowercase_text(ctx) for ctx in hard_neg_ctxs]
            question = normalize_question(sample["question"])
            if shuffle:
                random.shuffle(neg_ctxs)
                random.shuffle(hard_neg_ctxs)
            neg_ctxs = neg_ctxs[:num_other_negatives]
            hard_neg_ctxs = hard_neg_ctxs[:num_hard_negatives]
            all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs

            current_ctxs_len = len(ctx_tensors)
            ctx_tensors.extend(
                tensorizer.text_to_tensor(
                    ctx["text"], title=ctx["title"] if insert_title else None
                )
                for ctx in all_ctxs
            )
            positive_ctx_indices.append(current_ctxs_len)
            hard_neg_ctx_indices.append(
                cls._hard_negative_indices(
                    current_ctxs_len, len(neg_ctxs), len(hard_neg_ctxs)
                )
            )
            question_tensors.append(tensorizer.text_to_tensor(question))

        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)
        ctx_segments = torch.zeros_like(ctxs_tensor)
        question_segments = torch.zeros_like(questions_tensor)
        return BiEncoderBatch(
            questions_tensor,
            question_segments,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
        )
class DistilBertBiEncoder(nn.Module):
    """Bi-Encoder model component. Encapsulates query/question and context/passage encoders.

    Same as the segment-aware bi-encoder, except that sub-encoders are called as
    ``sub_model(ids, attn_mask)``, i.e. without segment ids.

    :param question_model: encoder applied to question token ids
    :param ctx_model: encoder applied to context/passage token ids
    :param fix_q_encoder: if True, the question encoder runs under ``torch.no_grad``
    :param fix_ctx_encoder: if True, the context encoder runs under ``torch.no_grad``
    """

    def __init__(
        self,
        question_model: nn.Module,
        ctx_model: nn.Module,
        fix_q_encoder: bool = False,
        fix_ctx_encoder: bool = False,
    ):
        super(DistilBertBiEncoder, self).__init__()
        self.question_model = question_model
        self.ctx_model = ctx_model
        self.fix_q_encoder = fix_q_encoder
        self.fix_ctx_encoder = fix_ctx_encoder

    @staticmethod
    def get_representation(
        sub_model: nn.Module,
        ids: T,
        segments: T,
        attn_mask: T,
        fix_encoder: bool = False,
    ) -> Tuple[T, T, T]:
        """
        Run ``sub_model(ids, attn_mask)`` and return its three outputs.

        Must be a staticmethod: ``forward`` calls it as
        ``self.get_representation(sub_model, ...)``, and without the decorator
        ``self`` would bind to ``sub_model`` and the call would raise TypeError.

        :param sub_model: encoder returning (sequence_output, pooled_output, hidden_states)
        :param ids: token ids; when None the encoder is not run
        :param segments: accepted for signature compatibility but NOT passed to the
            encoder (presumably a DistilBERT-style model without segment ids — confirm)
        :param attn_mask: attention mask passed to the encoder
        :param fix_encoder: run the encoder without gradient tracking
        :return: (sequence_output, pooled_output, hidden_states); all None if ids is None
        """
        if ids is None:
            return None, None, None
        if fix_encoder:
            with torch.no_grad():
                sequence_output, pooled_output, hidden_states = sub_model(ids, attn_mask)
            if sub_model.training:
                # Outputs computed under no_grad do not require grad; mark them so
                # autograd tracks whatever is computed from them downstream.
                sequence_output.requires_grad_(requires_grad=True)
                pooled_output.requires_grad_(requires_grad=True)
        else:
            sequence_output, pooled_output, hidden_states = sub_model(ids, attn_mask)
        return sequence_output, pooled_output, hidden_states

    def forward(
        self,
        question_ids: T,
        question_segments: T,
        question_attn_mask: T,
        context_ids: T,
        ctx_segments: T,
        ctx_attn_mask: T,
    ) -> Tuple[T, T]:
        """Encode questions and contexts; return their pooled outputs (q_pooled, ctx_pooled)."""
        _q_seq, q_pooled_out, _q_hidden = self.get_representation(
            self.question_model,
            question_ids,
            question_segments,
            question_attn_mask,
            self.fix_q_encoder,
        )
        _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation(
            self.ctx_model,
            context_ids,
            ctx_segments,
            ctx_attn_mask,
            self.fix_ctx_encoder,
        )
        return q_pooled_out, ctx_pooled_out

    @staticmethod
    def _lowercase_text(ctx: dict) -> dict:
        """Return a copy of ``ctx`` with its "text" lower-cased; ``ctx`` itself is not modified."""
        return {**ctx, "text": ctx["text"].lower()}

    @staticmethod
    def _hard_negative_indices(offset: int, num_other: int, num_hard: int) -> List[int]:
        """Row indices of one sample's hard negatives in the flattened context list.

        A sample's contexts are laid out as [positive] + other negatives + hard
        negatives starting at row ``offset``, so hard negatives begin after the
        positive AND after the ``num_other`` other negatives.
        """
        start = offset + 1 + num_other
        return list(range(start, start + num_hard))

    @classmethod
    def create_biencoder_input(
        cls,
        samples: List,
        tensorizer: "Tensorizer",
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        do_lower_fill: bool = False,
        desegment_valid_fill: bool = False,
    ) -> "BiEncoderBatch":
        """
        Creates a batch of the biencoder training tuple.
        :param samples: list of data items (from json) to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools
        :param shuffle_positives: shuffles positive passages pools
        :param do_lower_fill: lower-cases the text of the selected contexts (on copies;
            the samples themselves are not modified)
        :param desegment_valid_fill: accepted for interface compatibility; not used here
        :return: BiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []
        for sample in samples:
            # ctx+ & [ctx-] composition
            positive_ctxs = sample["positive_ctxs"]
            if shuffle and shuffle_positives:
                positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
            else:
                # take the first (gold) ctx+ only
                positive_ctx = positive_ctxs[0]
            # Copy the pools so shuffling / lower-casing never alters the caller's samples.
            neg_ctxs = list(sample["negative_ctxs"])
            hard_neg_ctxs = list(sample["hard_negative_ctxs"])
            if do_lower_fill:
                positive_ctx = cls._lowercase_text(positive_ctx)
                neg_ctxs = [cls._lowercase_text(ctx) for ctx in neg_ctxs]
                hard_neg_ctxs = [cls._lowercase_text(ctx) for ctx in hard_neg_ctxs]
            question = normalize_question(sample["question"])
            if shuffle:
                random.shuffle(neg_ctxs)
                random.shuffle(hard_neg_ctxs)
            neg_ctxs = neg_ctxs[:num_other_negatives]
            hard_neg_ctxs = hard_neg_ctxs[:num_hard_negatives]
            all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs

            current_ctxs_len = len(ctx_tensors)
            ctx_tensors.extend(
                tensorizer.text_to_tensor(
                    ctx["text"], title=ctx["title"] if insert_title else None
                )
                for ctx in all_ctxs
            )
            positive_ctx_indices.append(current_ctxs_len)
            hard_neg_ctx_indices.append(
                cls._hard_negative_indices(
                    current_ctxs_len, len(neg_ctxs), len(hard_neg_ctxs)
                )
            )
            question_tensors.append(tensorizer.text_to_tensor(question))

        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)
        ctx_segments = torch.zeros_like(ctxs_tensor)
        question_segments = torch.zeros_like(questions_tensor)
        return BiEncoderBatch(
            questions_tensor,
            question_segments,
            ctxs_tensor,
            ctx_segments,
            positive_ctx_indices,
            hard_neg_ctx_indices,
        )
class BiEncoderNllLoss(object):
    """Negative log-likelihood loss over question-vs-context similarity scores."""

    def calc(
        self,
        q_vectors: T,
        ctx_vectors: T,
        positive_idx_per_question: list,
        hard_negatice_idx_per_question: list = None,
    ) -> Tuple[T, T]:
        """
        Computes nll loss for the given lists of question and ctx vectors.
        Note that although hard_negatice_idx_per_question is not currently in use, one can use it for the
        loss modifications. For example - weighted NLL with different factors for hard vs regular negatives.

        :param q_vectors: question vectors, one row per question
        :param ctx_vectors: context vectors, one row per context
        :param positive_idx_per_question: for each question, the row of its positive context
        :param hard_negatice_idx_per_question: unused; the (misspelled) name is kept for keyword callers
        :return: a tuple of (mean loss, amount of correct predictions per batch as a 0-dim tensor)
        """
        scores = self.get_scores(q_vectors, ctx_vectors)
        if q_vectors.dim() > 1:
            q_num = q_vectors.size(0)
            scores = scores.view(q_num, -1)
        softmax_scores = F.log_softmax(scores, dim=1)
        # Built once and reused for both the loss and the accuracy count.
        targets = torch.tensor(positive_idx_per_question, device=softmax_scores.device)
        loss = F.nll_loss(softmax_scores, targets, reduction="mean")
        _, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == targets).sum()
        return loss, correct_predictions_count

    @staticmethod
    def get_scores(q_vector: T, ctx_vectors: T) -> T:
        """Score questions against contexts with the configured similarity function.

        Must be a staticmethod: ``calc`` calls ``self.get_scores(q, ctx)``, which
        would otherwise pass three arguments to a two-parameter function.
        """
        f = BiEncoderNllLoss.get_similarity_function()
        return f(q_vector, ctx_vectors)

    @staticmethod
    def get_similarity_function():
        """Return the similarity function used for scoring (dot product)."""
        return dot_product_scores