| | """ |
| | Set of utilities for Q&A results validation tasks - Retriver passage |
| | validation and Reader predicted answer validation |
| | """ |
| |
|
import collections
import logging
import string
import unicodedata
from functools import partial
from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict

import regex as re
from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer

logger = logging.getLogger(__name__)

QAMatchStats = collections.namedtuple(
    'QAMatchStats', ['top_k_hits', 'questions_doc_hits'])


def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
                      answers: List[List[str]],
                      closest_docs: List[Tuple[List[object], List[float]]],
                      workers_num: int, match_type: str) -> QAMatchStats:
    """
    Evaluates the presence of answers in the set of documents. This function
    is intended for a large collection of documents and results: it
    internally forks multiple sub-processes for evaluation and then merges
    their results.
    :param all_docs: dictionary of the entire document database,
        doc_id -> (doc_text, title)
    :param answers: list of answer lists, one list per question
    :param closest_docs: document ids of the top results along with their
        scores
    :param workers_num: number of parallel worker processes
    :param match_type: type of answer matching; refer to has_answer for the
        available options
    :return: matching information tuple.
        top_k_hits - a list where the index is the number of top documents
        retrieved and the value is the total number of valid matches across
        the entire dataset.
        questions_doc_hits - more detailed info with answer matches for
        every question and every retrieved document
    """
    # The documents database is shared via a module-level global so that
    # forked worker processes can read it without pickling it per task.
    global dpr_all_documents
    dpr_all_documents = all_docs

    tok_opts = {}
    tokenizer = SimpleTokenizer(**tok_opts)

    processes = ProcessPool(processes=workers_num)

    logger.info('Matching answers in top docs...')

    get_score_partial = partial(check_answer, match_type=match_type,
                                tokenizer=tokenizer)

    questions_answers_docs = zip(answers, closest_docs)

    scores = processes.map(get_score_partial, questions_answers_docs)
    processes.close()
    processes.join()

    logger.info('Per question validation results len=%d', len(scores))

    n_docs = len(closest_docs[0][0])
    top_k_hits = [0] * n_docs
    for question_hits in scores:
        # Rank of the first document that contains an answer; every top-k
        # cut-off at or beyond that rank counts this question as a hit.
        best_hit = next((i for i, x in enumerate(question_hits) if x), None)
        if best_hit is not None:
            top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]

    return QAMatchStats(top_k_hits, scores)


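# An illustrative call sketch (hypothetical toy corpus and variable names,
# shown only to document the expected shapes of the arguments):
#
#     all_docs = {0: ('Paris is the capital of France.', 'Paris')}
#     answers = [['Paris']]              # one answer list per question
#     closest_docs = [([0], [0.9])]      # (doc_ids, scores) per question
#     stats = calculate_matches(all_docs, answers, closest_docs,
#                               workers_num=1, match_type='string')
#     # stats.top_k_hits[k - 1] == questions answered within the top-k docs

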
def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
    """
    Search through all the top docs to see if they have any of the answers.
    """
    answers, (doc_ids, doc_scores) = questions_answers_docs

    global dpr_all_documents
    hits = []

    for i, doc_id in enumerate(doc_ids):
        doc = dpr_all_documents[doc_id]
        text = doc[0]

        answer_found = False
        if text is None:  # document could not be found in the database
            logger.warning("no doc in db")
            hits.append(False)
            continue

        if has_answer(answers, text, tokenizer, match_type):
            answer_found = True
        hits.append(answer_found)
    return hits


def has_answer(answers, text, tokenizer, match_type) -> bool:
    """
    Check if a document contains an answer string.
    If `match_type` is 'string', token matching is done between the text
    and the answer.
    If `match_type` is 'regex', we search the whole text with the regex.
    """
    text = _normalize(text)

    if match_type == 'string':
        # Each answer is a plain string
        text = tokenizer.tokenize(text).words(uncased=True)

        for single_answer in answers:
            single_answer = _normalize(single_answer)
            single_answer = tokenizer.tokenize(single_answer)
            single_answer = single_answer.words(uncased=True)

            # Slide a window of the answer's length over the document
            # tokens and test for an exact (uncased) sub-sequence match.
            for i in range(0, len(text) - len(single_answer) + 1):
                if single_answer == text[i: i + len(single_answer)]:
                    return True

    elif match_type == 'regex':
        # Each answer is a regex pattern
        for single_answer in answers:
            single_answer = _normalize(single_answer)
            if regex_match(text, single_answer):
                return True
    return False


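# For example (toy strings, illustrating the two matching modes):
#
#     tok = SimpleTokenizer()
#     has_answer(['Eiffel Tower'], 'the Eiffel Tower in Paris', tok,
#                'string')                      # True: uncased token match
#     has_answer([r'19\d\d'], 'opened in 1998', tok, 'regex')  # True

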
def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text."""
    try:
        pattern = re.compile(
            pattern,
            flags=re.IGNORECASE | re.UNICODE | re.MULTILINE,
        )
    except re.error:
        # An invalid pattern never matches rather than aborting evaluation.
        return False
    return pattern.search(text) is not None


def exact_match_score(prediction, ground_truth):
    return _normalize_answer(prediction) == _normalize_answer(ground_truth)


def _normalize_answer(s):
    # SQuAD-style answer normalization: lowercase, strip punctuation and
    # articles, and collapse whitespace before comparison.
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def _normalize(text):
    return unicodedata.normalize('NFD', text)
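

# A minimal, self-contained smoke test (illustrative only; the toy strings
# below are not project data). It assumes the module is run from a tree
# where tasks.orqa.unsupervised.tokenizers is importable.
if __name__ == '__main__':
    _tok = SimpleTokenizer()
    assert has_answer(['Paris'], 'Paris is the capital of France.',
                      _tok, 'string')
    assert regex_match('built in 1889', r'18\d\d')
    assert exact_match_score('The Eiffel Tower', 'eiffel tower!')
    print('qa_utils smoke test passed')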