# Copyright 2021 The HuggingFace Evaluate Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""FEVER (Fact Extraction and VERification) metric."""

import datasets

import evaluate


# BibTeX entry for the FEVER paper (Thorne et al., NAACL 2018); surfaced via
# `MetricInfo.citation` so users can cite the underlying task.
_CITATION = """\
@inproceedings{thorne2018fever,
  title={FEVER: Fact Extraction and VERification},
  author={Thorne, James and Vlachos, Andreas and Christodoulopoulos, Christos and Mittal, Arpit},
  booktitle={Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies},
  pages={809--819},
  year={2018}
}
"""
# Human-readable summary of the metric; also used (via the class decorator
# below) as the first part of the FEVER class docstring.
_DESCRIPTION = """\
The FEVER (Fact Extraction and VERification) metric evaluates the performance of systems that verify factual claims against evidence retrieved from Wikipedia.

It consists of three main components:
- **Label accuracy**: measures how often the predicted claim label (SUPPORTED, REFUTED, or NOT ENOUGH INFO) matches the gold label.
- **FEVER score**: considers a prediction correct only if the label is correct *and* at least one complete gold evidence set is retrieved.
- **Evidence F1**: computes the micro-averaged precision, recall, and F1 between predicted and gold evidence sentences.

The FEVER score is the official leaderboard metric used in the FEVER shared tasks.
"""
# Usage documentation (args / returns / doctest example) attached to
# `compute` through `MetricInfo.inputs_description`.
_KWARGS_DESCRIPTION = """
Computes the FEVER evaluation metrics.

Args:
    predictions (list of dict): Each prediction should be a dictionary with:
        - "label" (str): the predicted claim label.
        - "evidence" (list of str): the predicted evidence sentences.
    references (list of dict): Each reference should be a dictionary with:
        - "label" (str): the gold claim label.
        - "evidence_sets" (list of list of str): all possible gold evidence sets.

Returns:
    A dictionary containing:
        - 'label_accuracy': proportion of claims with correctly predicted labels.
        - 'fever_score': proportion of claims where both the label and at least one full gold evidence set are correct.
        - 'evidence_precision': micro-averaged precision of evidence retrieval.
        - 'evidence_recall': micro-averaged recall of evidence retrieval.
        - 'evidence_f1': micro-averaged F1 of evidence retrieval.

Example:
    >>> predictions = [{"label": "SUPPORTED", "evidence": ["E1", "E2"]}]
    >>> references = [{"label": "SUPPORTED", "evidence_sets": [["E1", "E2"], ["E3", "E4"]]}]
    >>> fever = evaluate.load("fever")
    >>> results = fever.compute(predictions=predictions, references=references)
    >>> print(results["label_accuracy"])
    1.0
    >>> print(results["fever_score"])
    1.0
    >>> print(results["evidence_precision"])
    1.0
    >>> print(results["evidence_recall"])
    0.5
    >>> print(round(results["evidence_f1"], 3))
    0.667
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class FEVER(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": {
                        "label": datasets.Value("string"),
                        "evidence": datasets.Sequence(datasets.Value("string")),
                    },
                    "references": {
                        "label": datasets.Value("string"),
                        "evidence_sets": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
                    },
                }
            ),
            reference_urls=[
                "https://fever.ai/dataset/",
                "https://arxiv.org/abs/1803.05355",
            ],
        )

    def _compute(self, predictions, references):
        """
        Computes FEVER metrics:
        - Label accuracy
        - FEVER score (label + complete evidence set)
        - Evidence precision, recall, and F1 (micro-averaged)
        """
        total = len(predictions)
        label_correct, fever_correct = 0, 0
        total_overlap, total_pred, total_gold = 0, 0, 0

        for pred, ref in zip(predictions, references):
            pred_label = pred["label"]
            pred_evidence = set(e.strip().lower() for e in pred["evidence"])
            gold_label = ref["label"]
            gold_sets = []
            for s in ref["evidence_sets"]:
                gold_sets.append([e.strip().lower() for e in s])

            if pred_label == gold_label:
                label_correct += 1
                for g_set in gold_sets:
                    if set(g_set).issubset(pred_evidence):
                        fever_correct += 1
                        break

            gold_evidence = set().union(*gold_sets) if gold_sets else set()
            overlap = len(gold_evidence.intersection(pred_evidence))
            total_overlap += overlap
            total_pred += len(pred_evidence)
            total_gold += len(gold_evidence)

        precision = (total_overlap / total_pred) if total_pred else 0
        recall = (total_overlap / total_gold) if total_gold else 0
        evidence_f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        fever_score = fever_correct / total if total else 0
        label_accuracy = label_correct / total if total else 0

        return {
            "label_accuracy": label_accuracy,
            "fever_score": fever_score,
            "evidence_precision": precision,
            "evidence_recall": recall,
            "evidence_f1": evidence_f1,
        }