| from minicheck_web.minicheck import MiniCheck |
| from web_retrieval import * |
| from nltk.tokenize import sent_tokenize |
| import evaluate |
|
|
|
|
def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
    '''
    Sort the chunks in a single document based on the probability of
    "supported" in descending order.
    This function is used when a user document is provided.

    used_chunk: nested list of chunk texts (one inner list per document).
    support_prob_per_chunk: matching nested list of support probabilities.

    Returns (ranked_docs, scores): two parallel tuples, highest score
    first.  Both are empty tuples when no chunks are given (the original
    `zip(*[])` would raise ValueError on that input).
    '''
    # Flatten the per-document nesting into two parallel flat lists.
    flattened_docs = [doc for chunk in used_chunk for doc in chunk]
    flattened_scores = [score for chunk in support_prob_per_chunk for score in chunk]

    ranked_doc_score = sorted(
        zip(flattened_docs, flattened_scores),
        key=lambda pair: pair[1],
        reverse=True,
    )

    if not ranked_doc_score:
        # Keep the (docs, scores) pair shape for callers on empty input.
        return (), ()

    ranked_docs, scores = zip(*ranked_doc_score)
    return ranked_docs, scores
| |
|
|
class EndpointHandler():
    '''
    Inference endpoint around MiniCheck: scores document chunks against a
    claim and, for supported chunks, picks a sentence to highlight.
    '''

    def __init__(self, path="./"):
        # MiniCheck entailment scorer loaded from `path` (local dir or model id).
        self.scorer = MiniCheck(path=path)
        # ROUGE metric used to pick the best-matching sentence per chunk.
        self.rouge = evaluate.load('rouge')

    def __call__(self, data):
        '''
        Score docs against the first claim and build the response payload.

        data: {'inputs': {'docs': list[str], 'claims': list[str]}}.
        Two modes:
          * exactly one non-empty user doc -> score that doc's chunks;
          * otherwise -> web retrieval for the (single) claim.

        Returns a dict with 'ranked_docs', 'scores', 'span_to_highlight'
        and, in web-retrieval mode, 'ranked_urls'.
        '''
        claim = data['inputs']['claims'][0]

        if len(data['inputs']['docs']) == 1 and data['inputs']['docs'][0] != '':
            # User supplied a single document: rank its chunks directly.
            _, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=data)
            ranked_docs, scores = sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk)

            outputs = {
                'ranked_docs': ranked_docs,
                'scores': scores,
                'span_to_highlight': self._highlight_spans(ranked_docs, scores, claim)
            }
        else:
            assert len(data['inputs']['claims']) == 1, "Only one claim is allowed for web retrieval for the current version."

            ranked_docs, scores, ranked_urls = self.search_relevant_docs(claim)

            outputs = {
                'ranked_docs': ranked_docs,
                'scores': scores,
                'ranked_urls': ranked_urls,
                'span_to_highlight': self._highlight_spans(ranked_docs, scores, claim)
            }

        return outputs

    def _highlight_spans(self, ranked_docs, scores, claim):
        '''
        For each chunk judged supported (score > 0.5) return the sentence
        with the highest ROUGE-1 overlap with the claim, else "".
        One entry per chunk, same order as `ranked_docs`.
        '''
        span_to_highlight = []
        for doc_chunk, score in zip(ranked_docs, scores):
            if score > 0.5:
                highest_score_sent, _ = self.chunk_and_highest_rouge_score(doc_chunk, claim)
                span_to_highlight.append(highest_score_sent)
            else:
                span_to_highlight.append("")
        return span_to_highlight

    def search_relevant_docs(self, claim, timeout=10, max_search_results_per_query=5, allow_duplicated_urls=False):
        '''
        Google-search the claim, scrape the hits concurrently, and rank
        the scraped pages by MiniCheck support probability.

        timeout: per-request timeout (seconds) for search and scraping.
        max_search_results_per_query: cap on scraped pages that get scored.
        allow_duplicated_urls: forwarded to order_doc_score_url.

        Returns (ranked_docs, scores, ranked_urls); all empty when nothing
        usable was scraped.
        '''
        search_results = search_google(claim, timeout=timeout)

        print('Searching webpages...')
        start = time()
        # scrape_url is called as scrape_url(url, timeout) for every hit.
        with concurrent.futures.ThreadPoolExecutor() as e:
            scraped_results = e.map(scrape_url, search_results, itertools.repeat(timeout))
        end = time()
        print(f"Finished searching in {round((end - start), 1)} seconds.\n")
        # Keep pages with text, drop mojibake pages and PDFs; truncate each
        # page to 20k chars to bound scoring cost.
        scraped_results = [(r[0][:20000], r[1]) for r in scraped_results if r[0] and '��' not in r[0] and ".pdf" not in r[1]]

        if not scraped_results:
            # Nothing usable scraped (all hits empty/PDF/garbled); the
            # original zip(*[]) would raise ValueError here.
            return (), (), ()

        retrieved_docs, urls = zip(*scraped_results[:max_search_results_per_query])

        print('Scoring webpages...')
        start = time()
        retrieved_data = {
            'inputs': {
                'docs': list(retrieved_docs),
                'claims': [claim]*len(retrieved_docs)
            }
        }
        _, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=retrieved_data)
        end = time()
        num_chunks = len([item for items in used_chunk for item in items])
        print(f'Finished {num_chunks} entailment checks in {round((end - start), 1)} seconds ({round(num_chunks / (end - start) * 60)} Doc./min).')

        ranked_docs, scores, ranked_urls = order_doc_score_url(used_chunk, support_prob_per_chunk, urls, allow_duplicated_urls=allow_duplicated_urls)

        return ranked_docs, scores, ranked_urls

    def chunk_and_highest_rouge_score(self, doc, claim):
        '''
        Given a document and a claim, return the sentence with the highest
        ROUGE-1 score and that score.  Returns ("", 0) when no sentence
        scores above zero, matching the original loop's tie behavior.
        '''
        doc_sentences = sent_tokenize(doc)
        claims = [claim] * len(doc_sentences)

        results = self.rouge.compute(
            predictions=doc_sentences,
            references=claims,
            use_aggregator=False)

        highest_score = 0
        highest_score_sent = ""
        # Strict `>` keeps the FIRST sentence among equal top scores and
        # leaves the "" / 0 defaults when every score is zero.
        for sentence, score in zip(doc_sentences, results['rouge1']):
            if score > highest_score:
                highest_score = score
                highest_score_sent = sentence

        return highest_score_sent, highest_score