File size: 4,830 Bytes
6bef416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from __future__ import annotations

import pytest

from src.data import DataBundle, DataContractError, validate_bundle


def _candidate(
    bundle: DataBundle,
    *,
    eval_runs=None,
    retrieval_events=None,
    documents=None,
    chunks=None,
    scenarios=None,
    dictionary=None,
) -> DataBundle:
    """Build a DataBundle from ``bundle`` with selected fields replaced.

    Every keyword argument that is given (i.e. not ``None``) is used as-is for
    the matching field. Every field that is not overridden is taken from
    ``bundle`` via ``.copy()`` rather than shared by reference, so edits made
    to the candidate's own tables do not write through to the fixture.

    ``dictionary`` can be overridden like every other field. Callers that do
    not pass it get the same copied value as before.

    Args:
        bundle: Source bundle (normally the ``bundle`` fixture).
        eval_runs, retrieval_events, documents, chunks, scenarios, dictionary:
            Optional replacement for the identically named field.

    Returns:
        A newly constructed ``DataBundle``.
    """

    def _pick(override, original):
        # `is not None` rather than truthiness: an empty replacement is still
        # a replacement, and `bool(DataFrame)` is ambiguous in pandas.
        return override if override is not None else original.copy()

    return DataBundle(
        eval_runs=_pick(eval_runs, bundle.eval_runs),
        retrieval_events=_pick(retrieval_events, bundle.retrieval_events),
        documents=_pick(documents, bundle.documents),
        chunks=_pick(chunks, bundle.chunks),
        scenarios=_pick(scenarios, bundle.scenarios),
        dictionary=_pick(dictionary, bundle.dictionary),
    )


def test_validate_bundle_rejects_negative_cost(bundle) -> None:
    """A negative ``total_cost_usd`` on an eval run must fail validation."""
    runs = bundle.eval_runs.copy()
    first_row = runs.index[0]
    runs.loc[first_row, "total_cost_usd"] = -0.01

    with pytest.raises(DataContractError, match="total_cost_usd"):
        candidate = _candidate(bundle, eval_runs=runs)
        validate_bundle(candidate)


def test_validate_bundle_rejects_rank_zero(bundle) -> None:
    """A retrieval event with ``rank`` 0 must fail validation."""
    events = bundle.retrieval_events.copy()
    target = events.index[0]
    events.loc[target, "rank"] = 0

    with pytest.raises(DataContractError, match="rank"):
        candidate = _candidate(bundle, retrieval_events=events)
        validate_bundle(candidate)


def test_validate_bundle_rejects_fractional_rank(bundle) -> None:
    """A non-integral ``rank`` (1.5) must fail validation with a rank error.

    The match pattern names the column, like every other assertion in this
    module. A bare ``"integer"`` would also accept an integer-related error
    raised for some unrelated column.
    """
    retrieval_events = bundle.retrieval_events.copy()
    # Cast explicitly so 1.5 is stored as-is rather than relying on pandas'
    # implicit upcast (rank is presumably integer-typed in the fixture).
    retrieval_events["rank"] = retrieval_events["rank"].astype(float)
    retrieval_events.loc[retrieval_events.index[0], "rank"] = 1.5

    with pytest.raises(DataContractError, match="rank.*integer"):
        validate_bundle(_candidate(bundle, retrieval_events=retrieval_events))


def test_validate_bundle_rejects_invalid_relevance_flag(bundle) -> None:
    """An ``is_relevant`` flag of 2 on a retrieval event must fail validation."""
    events = bundle.retrieval_events.copy()
    first_label = events.index[0]
    events.loc[first_label, "is_relevant"] = 2

    with pytest.raises(DataContractError, match="is_relevant"):
        candidate = _candidate(bundle, retrieval_events=events)
        validate_bundle(candidate)


def test_validate_bundle_rejects_non_numeric_required_eval_metric(bundle) -> None:
    """A string in the required ``is_correct`` metric is reported as non-numeric."""
    runs = bundle.eval_runs.copy()
    # Widen to object dtype first so the column can hold a string value.
    runs = runs.astype({"is_correct": "object"})
    first_row = runs.index[0]
    runs.loc[first_row, "is_correct"] = "bad"

    with pytest.raises(DataContractError, match="is_correct.*non-numeric"):
        candidate = _candidate(bundle, eval_runs=runs)
        validate_bundle(candidate)


def test_validate_bundle_rejects_missing_required_numeric_metric(bundle) -> None:
    """A null ``recall_at_10`` is reported as a missing numeric value."""
    runs = bundle.eval_runs.copy()
    first_row = runs.index[0]
    runs.loc[first_row, "recall_at_10"] = None

    with pytest.raises(DataContractError, match="recall_at_10.*missing numeric"):
        candidate = _candidate(bundle, eval_runs=runs)
        validate_bundle(candidate)


def test_validate_bundle_rejects_non_numeric_retrieval_rank(bundle) -> None:
    """A textual ``rank`` ("first") is reported as non-numeric."""
    events = bundle.retrieval_events.copy()
    # Widen to object dtype first so the column can hold a string value.
    events = events.astype({"rank": "object"})
    target = events.index[0]
    events.loc[target, "rank"] = "first"

    with pytest.raises(DataContractError, match="rank.*non-numeric"):
        candidate = _candidate(bundle, retrieval_events=events)
        validate_bundle(candidate)


def test_validate_bundle_rejects_missing_primary_key(bundle) -> None:
    """An eval run whose ``example_id`` is None must be reported as missing."""
    runs = bundle.eval_runs.copy()
    first_row = runs.index[0]
    runs.loc[first_row, "example_id"] = None

    with pytest.raises(DataContractError, match="example_id.*missing"):
        candidate = _candidate(bundle, eval_runs=runs)
        validate_bundle(candidate)


def test_validate_bundle_rejects_missing_required_foreign_key(bundle) -> None:
    """An empty-string ``chunk_id`` on a retrieval event counts as missing."""
    events = bundle.retrieval_events.copy()
    target = events.index[0]
    events.loc[target, "chunk_id"] = ""

    with pytest.raises(DataContractError, match="chunk_id.*missing"):
        candidate = _candidate(bundle, retrieval_events=events)
        validate_bundle(candidate)


def test_standardize_eval_rejects_non_numeric_optional_metric(bundle) -> None:
    """``standardize_eval`` rejects a string in the optional ``top1_score`` column."""
    from src.data import standardize_eval

    runs = bundle.eval_runs.copy()
    # Widen to object dtype first so the column can hold a string value.
    runs = runs.astype({"top1_score": "object"})
    first_row = runs.index[0]
    runs.loc[first_row, "top1_score"] = "not-a-score"

    scenarios = bundle.scenarios
    with pytest.raises(DataContractError, match="top1_score.*non-numeric"):
        standardize_eval(runs, scenarios)


def test_standardize_retrieval_events_rejects_blank_required_rank(bundle) -> None:
    """A whitespace-only ``rank`` is reported as a missing numeric value."""
    from src.data import standardize_retrieval_events

    events = bundle.retrieval_events.copy()
    # Widen to object dtype first so the column can hold a string value.
    events = events.astype({"rank": "object"})
    target = events.index[0]
    events.loc[target, "rank"] = " "

    with pytest.raises(DataContractError, match="rank.*missing numeric"):
        standardize_retrieval_events(events)