File size: 2,038 Bytes
edd0a90
6bb843b
fb49a5d
 
 
 
 
 
 
6bb843b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb49a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bb843b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from timebench_eval import TimebenchEval
import pytest
from conftest import (
    PREDICTION_1,
    PREDICTION_2,
    PREDICTION_3,
    PREDICTION_4,
    PREDICTION_5,
)


@pytest.mark.parametrize(
    "prediction,reference,task,expected_metrics",
    [
        # TempReason: prediction matches the reference exactly.
        (
            PREDICTION_1,
            "Troyes AC",
            "TempReason",
            {"exact_match": [1], "f1": [1]},
        ),
        # Date Arithmetic reports exact_match only (no F1 for this task).
        (
            PREDICTION_2,
            "Aug, 1804",
            "Date Arithmetic",
            {"exact_match": [1]},
        ),
        # MenatQA: the special "unanswerable" label is scored like any answer.
        (
            PREDICTION_3,
            "unanswerable",
            "MenatQA",
            {"exact_match": [1], "f1": [1]},
        ),
        # MenatQA: ordinary answerable case.
        (
            PREDICTION_4,
            "Cardiff City",
            "MenatQA",
            {"exact_match": [1], "f1": [1]},
        ),
        # TimeDial: multi-option reference joined with "&&"; full credit.
        (
            PREDICTION_5,
            "B. No more than ten minutes && C. No more than five minutes",
            "TimeDial",
            {"exact_match": [1], "f1": [1]},
        ),
        # TimeDial: partial token overlap — EM fails, F1 is fractional.
        (
            PREDICTION_5,
            "B.",
            "TimeDial",
            {"exact_match": [0], "f1": [pytest.approx(2 / 3, rel=1e-6)]},
        ),
        # TimeDial: wrong option — both metrics zero.
        (
            PREDICTION_5,
            "A.",
            "TimeDial",
            {"exact_match": [0], "f1": [0]},
        ),
    ],
)
def test_eval(prediction, reference, task, expected_metrics):
    """Score a single prediction/reference pair and check the metric dict."""
    computed = TimebenchEval()._compute([prediction], [reference], task)
    assert computed == expected_metrics


def test_eval_many():
    """Batch path: two MenatQA items scored in one _compute call yield
    per-item metric lists in input order."""
    predictions = [PREDICTION_3, PREDICTION_4]
    references = ["unanswerable", "Cardiff City"]
    computed = TimebenchEval()._compute(predictions, references, "MenatQA")
    expected = {"exact_match": [1, 1], "f1": [1, 1]}
    assert computed == expected