correcting metric info
Browse files- anls.py +13 -35
- compute_score.py +6 -11
anls.py
CHANGED
|
@@ -41,21 +41,17 @@ _KWARGS_DESCRIPTION = """
|
|
| 41 |
Computes Average Normalized Levenshtein Similarity (ANLS).
|
| 42 |
Args:
|
| 43 |
predictions: List of question-answers dictionaries with the following key-values:
|
| 44 |
-
- '
|
| 45 |
- 'prediction_text': the text of the answer
|
| 46 |
references: List of question-answers dictionaries with the following key-values:
|
| 47 |
-
- '
|
| 48 |
-
- 'answers':
|
| 49 |
-
|
| 50 |
-
'text': list of possible texts for the answer, as a list of strings
|
| 51 |
-
'answer_start': list of start positions for the answer, as a list of ints
|
| 52 |
-
}
|
| 53 |
-
Note that answer_start values are not taken into account to compute the metric.
|
| 54 |
Returns:
|
| 55 |
'anls_score': The ANLS score of predicted tokens versus the gold answer
|
| 56 |
Examples:
|
| 57 |
>>> predictions = [{'prediction_text': 'Denver Broncos', 'question_id': '56e10a3be3433e1400422b22'}]
|
| 58 |
-
>>> references = [{'answers': ['Denver Broncos', 'Denver R. Broncos']
|
| 59 |
>>> anls_metric = evaluate.load("anls")
|
| 60 |
>>> results = anls_metric.compute(predictions=predictions, references=references)
|
| 61 |
>>> print(results)
|
|
@@ -72,36 +68,18 @@ class Anls(evaluate.Metric):
|
|
| 72 |
inputs_description=_KWARGS_DESCRIPTION,
|
| 73 |
features=datasets.Features(
|
| 74 |
{
|
| 75 |
-
"predictions": {"
|
|
|
|
| 76 |
"references": {
|
| 77 |
-
"
|
| 78 |
-
"answers": datasets.features.Sequence(
|
| 79 |
-
{
|
| 80 |
-
"text": datasets.Value("string"),
|
| 81 |
-
"answer_start": datasets.Value("int32"),
|
| 82 |
-
}
|
| 83 |
-
),
|
| 84 |
},
|
| 85 |
}
|
| 86 |
)
|
| 87 |
)
|
| 88 |
|
| 89 |
def _compute(self, predictions, references):
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
{
|
| 95 |
-
"qas": [
|
| 96 |
-
{
|
| 97 |
-
"answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]],
|
| 98 |
-
"id": ref["id"],
|
| 99 |
-
}
|
| 100 |
-
for ref in references
|
| 101 |
-
]
|
| 102 |
-
}
|
| 103 |
-
]
|
| 104 |
-
}
|
| 105 |
-
]
|
| 106 |
-
score = compute_score(dataset=dataset, predictions=prediction_dict)
|
| 107 |
-
return score
|
|
|
|
| 41 |
Computes Average Normalized Levenshtein Similarity (ANLS).
|
| 42 |
Args:
|
| 43 |
predictions: List of question-answers dictionaries with the following key-values:
|
| 44 |
+
- 'question_id': id of the question-answer pair as given in the references (see below)
|
| 45 |
- 'prediction_text': the text of the answer
|
| 46 |
references: List of question-answers dictionaries with the following key-values:
|
| 47 |
+
- 'question_id': id of the question-answer pair (see above)
|
| 48 |
+
- 'answers': list of possible texts for the answer, as a list of strings
|
| 49 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
Returns:
|
| 51 |
'anls_score': The ANLS score of predicted tokens versus the gold answer
|
| 52 |
Examples:
|
| 53 |
>>> predictions = [{'prediction_text': 'Denver Broncos', 'question_id': '56e10a3be3433e1400422b22'}]
|
| 54 |
+
>>> references = [{'answers': ['Denver Broncos', 'Denver R. Broncos'], 'question_id': '56e10a3be3433e1400422b22'}]
|
| 55 |
>>> anls_metric = evaluate.load("anls")
|
| 56 |
>>> results = anls_metric.compute(predictions=predictions, references=references)
|
| 57 |
>>> print(results)
|
|
|
|
| 68 |
inputs_description=_KWARGS_DESCRIPTION,
|
| 69 |
features=datasets.Features(
|
| 70 |
{
|
| 71 |
+
"predictions": {"question_id": datasets.Value("string"),
|
| 72 |
+
"prediction_text": datasets.Value("string")},
|
| 73 |
"references": {
|
| 74 |
+
"question_id": datasets.Value("string"),
|
| 75 |
+
"answers": datasets.features.Sequence(datasets.Value("string")),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
},
|
| 77 |
}
|
| 78 |
)
|
| 79 |
)
|
| 80 |
|
| 81 |
def _compute(self, predictions, references):
|
| 82 |
+
ground_truths = {x['question_id']: x['answers'] for x in references}
|
| 83 |
+
predictions = {x['question_id']: x['prediction_text'] for x in predictions}
|
| 84 |
+
anls_score = compute_score(predictions=predictions, ground_truths=ground_truths)
|
| 85 |
+
return {"anls_score": anls_score}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
compute_score.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from Levenshtein import ratio
|
| 2 |
|
| 3 |
|
| 4 |
-
def
|
| 5 |
theta = 0.5
|
| 6 |
anls_score = 0
|
| 7 |
for qid, prediction in predictions.items():
|
|
@@ -18,20 +18,15 @@ def anls_compute(predictions, ground_truths):
|
|
| 18 |
return anls_score
|
| 19 |
|
| 20 |
|
| 21 |
-
def compute_score(dataset, prediction):
|
| 22 |
-
ground_truths = {x['question_id']: x['answers'] for x in dataset}
|
| 23 |
-
predictions = {x['question_id']: x['prediction_text'] for x in prediction}
|
| 24 |
-
anls_score = anls_compute(predictions=predictions, ground_truths=ground_truths)
|
| 25 |
-
return {"anls_score": anls_score}
|
| 26 |
-
|
| 27 |
-
|
| 28 |
if __name__ == "__main__":
|
| 29 |
-
|
| 30 |
{'question_id': '18601', 'prediction_text': '12/15/89'},
|
| 31 |
{'question_id': '16734', 'prediction_text': 'Dear dr. Lobo'}]
|
| 32 |
|
| 33 |
-
|
| 34 |
{'answers': ['12/15/88'], 'question_id': '18601'},
|
| 35 |
{'answers': ['Dear Dr. Lobo', 'Dr. Lobo'], 'question_id': '16734'}]
|
| 36 |
-
|
|
|
|
|
|
|
| 37 |
print(anls_score)
|
|
|
|
| 1 |
from Levenshtein import ratio
|
| 2 |
|
| 3 |
|
| 4 |
+
def compute_score(predictions, ground_truths):
|
| 5 |
theta = 0.5
|
| 6 |
anls_score = 0
|
| 7 |
for qid, prediction in predictions.items():
|
|
|
|
| 18 |
return anls_score
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
if __name__ == "__main__":
|
| 22 |
+
predictions = [{'question_id': '10285', 'prediction_text': 'Denver Broncos'},
|
| 23 |
{'question_id': '18601', 'prediction_text': '12/15/89'},
|
| 24 |
{'question_id': '16734', 'prediction_text': 'Dear dr. Lobo'}]
|
| 25 |
|
| 26 |
+
references = [{"answers": ["Denver Broncos", "Denver R. Broncos"], 'question_id': '10285'},
|
| 27 |
{'answers': ['12/15/88'], 'question_id': '18601'},
|
| 28 |
{'answers': ['Dear Dr. Lobo', 'Dr. Lobo'], 'question_id': '16734'}]
|
| 29 |
+
ground_truths = {x['question_id']: x['answers'] for x in references}
|
| 30 |
+
predictions = {x['question_id']: x['prediction_text'] for x in predictions}
|
| 31 |
+
anls_score = compute_score(predictions=predictions, ground_truths=ground_truths)
|
| 32 |
print(anls_score)
|