Spaces:
Build error
Build error
| # Copyright (c) 2022, Lawrence Livermore National Security, LLC. | |
| # All rights reserved. | |
| # See the top-level LICENSE and NOTICE files for details. | |
| # LLNL-CODE-838964 | |
| # SPDX-License-Identifier: Apache-2.0-with-LLVM-exception | |
| from sentence_transformers.cross_encoder import CrossEncoder as CE | |
| import numpy as np | |
| from typing import List, Dict, Tuple | |
| import json | |
| from collections import defaultdict | |
| import os | |
class CrossEncoder:
    """Thin wrapper around sentence_transformers' cross-encoder model.

    Exposes a minimal ``predict`` interface so the rest of this module does
    not depend directly on the sentence_transformers API.
    """

    def __init__(self,
                 model_path: str = None,
                 max_length: int = None,
                 **kwargs):
        """Load the cross-encoder model.

        Args:
            model_path: Path or hub name of the cross-encoder model.
            max_length: Optional maximum input sequence length; when given it
                is forwarded to the underlying model.
            **kwargs: Passed through to ``sentence_transformers.CrossEncoder``
                (e.g. ``tokenizer_args``).
        """
        # BUG FIX: the original assigned self.model unconditionally after the
        # if-branch, so a max_length-configured model was immediately
        # overwritten by one built without max_length. The second call now
        # only runs when max_length is absent.
        if max_length is not None:
            self.model = CE(model_path, max_length=max_length, **kwargs)
        else:
            self.model = CE(model_path, **kwargs)

    def predict(self,
                sentences: List[Tuple[str, str]],
                batch_size: int = 32,
                show_progress_bar: bool = False) -> List[float]:
        """Score (question, passage) pairs.

        Args:
            sentences: List of (question, passage) text pairs.
            batch_size: Inference batch size.
            show_progress_bar: Whether the underlying model shows progress.

        Returns:
            One relevance score per input pair.
        """
        return self.model.predict(sentences=sentences,
                                  batch_size=batch_size,
                                  show_progress_bar=show_progress_bar)
class CERank:
    """Rank candidate contexts for a question with a cross-encoder model."""

    def __init__(self, model, batch_size: int = 128, **kwargs):
        """Store the scoring model and the inference batch size."""
        self.cross_encoder = model
        self.batch_size = batch_size

    def flatten_examples(self, contexts: Dict[str, Dict], question: str):
        """Expand a contexts mapping into parallel pair lists.

        Returns:
            A tuple ``(text_pairs, pair_ids)`` where each text pair is
            ``[question, context_text]`` and each id pair is
            ``['question_0', context_id]``.
        """
        pairs = [([question, ctx['text']], ['question_0', cid])
                 for cid, ctx in contexts.items()]
        if not pairs:
            return [], []
        texts, ids = zip(*pairs)
        return list(texts), list(ids)

    def group_questionrank(self, pair_ids, rank_scores):
        """Bucket (context_id, score) tuples by their question id."""
        grouped = defaultdict(list)
        for (qid, pid), score in zip(pair_ids, rank_scores):
            grouped[qid].append((pid, score))
        return grouped

    def get_rankings(self, pair_ids, rank_scores, text_pairs):
        """Sort each question's contexts by descending score.

        Returns:
            Mapping of question id to a dict with the question ``text``,
            sorted ``scores``, and the context ids as ``ranks``.
        """
        grouped = self.group_questionrank(pair_ids, rank_scores)
        rankings = defaultdict(dict)
        for position, (qid, entries) in enumerate(grouped.items()):
            ordered = sorted(entries, key=lambda e: e[1], reverse=True)
            ids, scores = zip(*ordered)
            entry = rankings[qid]
            entry['text'] = text_pairs[position][0]
            entry['scores'] = list(scores)
            entry['ranks'] = list(ids)
        return rankings

    def rank(self,
             contexts: Dict[str, Dict],
             question: str):
        """Score every context against the question and return rankings."""
        text_pairs, pair_ids = self.flatten_examples(contexts, question)
        raw = self.cross_encoder.predict(text_pairs, batch_size=self.batch_size)
        scores = [float(s) for s in raw]
        return self.get_rankings(pair_ids, scores, text_pairs)
def get_ranked_contexts(context_json, question):
    """Rank contexts from a JSON file and write results to disk.

    Reads ``context_json`` (a mapping of context_id -> {'text': ...}),
    ranks each context against ``question`` with a cross-encoder, and writes
    the rankings to ``ranked_<name>.json``.

    Args:
        context_json: Path to the input contexts JSON file. NOTE(review):
            the output name is built via ``context_json[:-5]``, which assumes
            the path ends in '.json' and contains no directory components —
            verify against callers.
        question: The question string to rank contexts for.
    """
    model_path = 'ms-marco-electra-base'
    max_length = 512
    # Can't use use_fast (fast tokenizers) while gradio is running, causes
    # conflict with tokenizer multiprocessing/parallelism.
    cross_encoder = CrossEncoder(model_path, max_length, tokenizer_args={'use_fast': False})
    ranker = CERank(cross_encoder)
    with open(context_json, 'r') as fin:
        contexts = json.load(fin)
    rankings = ranker.rank(contexts, question)
    with open('ranked_{0}.json'.format(context_json[:-5]), 'w') as fout:
        json.dump(rankings, fout)
def get_ranked_contexts_in_memory(contexts, question):
    """Rank already-loaded contexts against a question.

    Same as ``get_ranked_contexts`` but takes the contexts mapping directly
    and returns the rankings instead of writing them to disk.

    Args:
        contexts: Mapping of context_id -> {'text': ...}.
        question: The question string to rank contexts for.

    Returns:
        Mapping of question id to {'text', 'scores', 'ranks'} as produced
        by ``CERank.get_rankings``.
    """
    model_path = 'ms-marco-electra-base'
    max_length = 512
    # Can't use use_fast (fast tokenizers) while gradio is running, causes
    # conflict with tokenizer multiprocessing/parallelism.
    cross_encoder = CrossEncoder(model_path, max_length, tokenizer_args={'use_fast': False})
    ranker = CERank(cross_encoder)
    return ranker.rank(contexts, question)