| """ | |
| Official evaluation script for ReCoRD v1.0. | |
| (Some functions are adopted from the SQuAD evaluation script.) | |
| """ | |
| from __future__ import print_function | |
| from collections import Counter | |
| import string | |
| import re | |
| import argparse | |
| import json | |
| import sys | |
| from ...utils import get_logger | |
| logger=get_logger() | |
| __all__=['normalize_answer', 'evaluate'] | |
| def normalize_answer(s): | |
| """Lower text and remove punctuation, articles and extra whitespace.""" | |
| def remove_articles(text): | |
| return re.sub(r'\b(a|an|the)\b', ' ', text) | |
| def white_space_fix(text): | |
| return ' '.join(text.split()) | |
| def remove_punc(text): | |
| exclude = set(string.punctuation) | |
| return ''.join(ch for ch in text if ch not in exclude) | |
| def lower(text): | |
| return text.lower() | |
| return white_space_fix(remove_articles(remove_punc(lower(s)))) | |
| def f1_score(prediction, ground_truth): | |
| prediction_tokens = normalize_answer(prediction).split() | |
| ground_truth_tokens = normalize_answer(ground_truth).split() | |
| common = Counter(prediction_tokens) & Counter(ground_truth_tokens) | |
| num_same = sum(common.values()) | |
| if num_same == 0: | |
| return 0 | |
| precision = 1.0 * num_same / len(prediction_tokens) | |
| recall = 1.0 * num_same / len(ground_truth_tokens) | |
| f1 = (2 * precision * recall) / (precision + recall) | |
| return f1 | |
| def exact_match_score(prediction, ground_truth): | |
| return normalize_answer(prediction) == normalize_answer(ground_truth) | |
| def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): | |
| scores_for_ground_truths = [] | |
| for ground_truth in ground_truths: | |
| score = metric_fn(prediction, ground_truth) | |
| scores_for_ground_truths.append(score) | |
| return max(scores_for_ground_truths) | |
| def evaluate(answers, predictions): | |
| f1 = exact_match = total = 0 | |
| correct_ids = [] | |
| for qid in answers: | |
| total += 1 | |
| if qid not in predictions: | |
| message = 'Unanswered question {} will receive score 0.'.format(qid) | |
| logger.warning(message) | |
| continue | |
| ground_truths = list(map(lambda x: x['text'], answers[qid])) | |
| prediction = predictions[qid] | |
| _exact_match = metric_max_over_ground_truths(exact_match_score, prediction, ground_truths) | |
| if int(_exact_match) == 1: | |
| correct_ids.append(qid) | |
| exact_match += _exact_match | |
| f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths) | |
| exact_match = 100.0 * exact_match / total | |
| f1 = 100.0 * f1 / total | |
| return {'exact_match': exact_match, 'f1': f1} | |