# File size: 4,267 Bytes | 42bd089
""" Official evaluation script for v1.1 of the SQuAD dataset. """
import argparse
import json
import re
import string
import sys
from collections import Counter
# CJK and typographic punctuation that ``string.punctuation`` does not cover.
# The ASCII marks come from ``string.punctuation`` below, which keeps this
# literal free of quote and backslash characters: an unescaped double quote
# here would end the string early and silently drop every later character.
_CJK_PUNCTUATION = (
    "。⦅⦆「」、〃》『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏"
)

# Built once at import time; a frozenset gives O(1) whole-token lookups.
_PUNCTUATION_TOKENS = frozenset(_CJK_PUNCTUATION + string.punctuation)


def remove_punc(tokens):
    """Return ``tokens`` without the entries that are a single punctuation mark.

    A token is dropped only when the whole token equals one excluded
    character; punctuation embedded in a longer token (``"a.b"``) is kept.

    Args:
        tokens: Iterable of string tokens.

    Returns:
        A new list with the remaining tokens in their original order.
    """
    return [tok for tok in tokens if tok not in _PUNCTUATION_TOKENS]
# Code-point ranges stripped in addition to the per-character ``emoji`` check.
# Compiled once at import time rather than on every call.
_EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F1E0-\U0001F1FF"  # flags (iOS)
    "\U00002702-\U000027B0"  # dingbats
    "]+",
    flags=re.UNICODE,
)


def normalize_answer(s):
    """Remove emoji, apply neologdn normalization and collapse whitespace.

    Despite being derived from the English SQuAD script, this function does
    not lowercase the text and does not remove punctuation or articles.

    Args:
        s: Answer text to normalize.

    Returns:
        The text with emoji removed, passed through ``neologdn.normalize``,
        with every whitespace run replaced by a single space and leading and
        trailing whitespace removed.
    """
    # Kept function-local as in the original, so importing this module does
    # not itself require the third-party packages.
    import emoji
    import neologdn

    # First drop single characters the ``emoji`` package classifies as emoji,
    # then sweep the explicit code-point ranges above.
    without_emoji = "".join(c for c in s if not emoji.is_emoji(c))
    without_emoji = _EMOJI_PATTERN.sub("", without_emoji)

    return " ".join(neologdn.normalize(without_emoji).split())
# Shared tokenizer instance, created on first use. Constructing a Tagger per
# call repeats its setup for every (prediction, ground truth) pair.
_WAKATI_TAGGER = None


def _get_wakati_tagger():
    """Return the shared fugashi tagger in wakati mode, creating it if needed."""
    global _WAKATI_TAGGER
    if _WAKATI_TAGGER is None:
        # Imported lazily so the module can be imported without fugashi.
        from fugashi import Tagger

        _WAKATI_TAGGER = Tagger("-Owakati")
    return _WAKATI_TAGGER


def f1_score(prediction, ground_truth):
    """Compute the token-level F1 between a prediction and one reference.

    Both strings are normalized with ``normalize_answer``, tokenized with the
    fugashi tagger (``-Owakati`` output split on whitespace) and filtered
    with ``remove_punc`` before the multiset overlap is counted.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        ``0`` when the two token multisets share no token (including when
        either side is empty), otherwise the harmonic mean of precision and
        recall as a float.
    """
    tagger = _get_wakati_tagger()
    prediction_tokens = remove_punc(
        tagger.parse(normalize_answer(prediction)).split()
    )
    ground_truth_tokens = remove_punc(
        tagger.parse(normalize_answer(ground_truth)).split()
    )

    # Counter intersection keeps, per token, the smaller of the two counts.
    overlap = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(overlap.values())
    if num_same == 0:
        # Also guards the divisions below against empty token lists.
        return 0

    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)
def exact_match_score(prediction, ground_truth):
    """Tell whether two answers are equal once both are normalized.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        ``True`` if ``normalize_answer`` yields the same string for both
        inputs, ``False`` otherwise.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_reference = normalize_answer(ground_truth)
    return normalized_prediction == normalized_reference
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score a prediction against every reference and keep the best score.

    Args:
        metric_fn: Callable invoked as ``metric_fn(prediction, ground_truth)``.
        prediction: Predicted answer, passed through unchanged.
        ground_truths: Iterable of reference answers.

    Returns:
        The largest value ``metric_fn`` returned.

    Raises:
        ValueError: If ``ground_truths`` is empty (raised by ``max``).
    """
    # All scores are computed before ``max`` compares any of them.
    per_reference = [metric_fn(prediction, reference) for reference in ground_truths]
    return max(per_reference)
def evaluate(dataset, predictions):
    """Compute exact-match and F1 percentages over a SQuAD-style dataset.

    Args:
        dataset: Iterable of articles. Each article has a ``"paragraphs"``
            list, each paragraph a ``"qas"`` list, and each qa an ``"id"``
            and an ``"answers"`` list whose items carry a ``"text"`` field.
        predictions: Mapping from question id to the predicted answer text.

    Returns:
        ``{"exact_match": float, "f1": float}``, both on a 0-100 scale and
        averaged over every question in ``dataset``. Questions without a
        prediction count toward the total and contribute a score of 0. A
        dataset with no questions yields 0.0 for both metrics.
    """
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                # Counted before the lookup so unanswered questions still
                # lower the averages.
                total += 1
                question_id = qa["id"]
                if question_id not in predictions:
                    print(
                        f"Unanswered question {question_id} will receive score 0.",
                        file=sys.stderr,
                    )
                    continue
                ground_truths = [answer["text"] for answer in qa["answers"]]
                prediction = predictions[question_id]
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths
                )
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths
                )

    if total == 0:
        # Avoid dividing by zero when there is nothing to score.
        print("Dataset contains no questions; reporting scores of 0.", file=sys.stderr)
        return {"exact_match": 0.0, "f1": 0.0}

    return {
        "exact_match": 100.0 * exact_match / total,
        "f1": 100.0 * f1 / total,
    }
def main():
    """Command-line entry point: score a prediction file against a dataset.

    Reads two JSON files named on the command line, warns on stderr when the
    dataset's ``"version"`` differs from the expected one, and prints the
    result of ``evaluate`` to stdout as a JSON object.
    """
    expected_version = "1.1"
    parser = argparse.ArgumentParser(
        description="Evaluation for Japanese SQuAD " + expected_version
    )
    parser.add_argument("dataset_file", help="Dataset file")
    parser.add_argument("prediction_file", help="Prediction File")
    args = parser.parse_args()

    # Decode explicitly as UTF-8: the locale default used by a bare open()
    # can mis-decode Japanese text on systems whose locale is not UTF-8.
    with open(args.dataset_file, encoding="utf-8") as dataset_file:
        dataset_json = json.load(dataset_file)
    if dataset_json["version"] != expected_version:
        # A version mismatch is reported but does not stop the evaluation.
        print(
            f"Evaluation expects v-{expected_version}"
            f", but got dataset with v-{dataset_json['version']}",
            file=sys.stderr,
        )
    dataset = dataset_json["data"]

    with open(args.prediction_file, encoding="utf-8") as prediction_file:
        predictions = json.load(prediction_file)

    print(json.dumps(evaluate(dataset, predictions)))


if __name__ == "__main__":
    main()
|