vxa8502 committed on
Commit
cfb4413
·
1 Parent(s): 866804f

Harden eval loader

Browse files
Files changed (2) hide show
  1. sage/data/eval.py +59 -9
  2. tests/test_eval.py +190 -0
sage/data/eval.py CHANGED
@@ -20,17 +20,67 @@ def load_eval_cases(filename: str) -> list[EvalCase]:
20
 
21
  Returns:
22
  List of EvalCase objects.
 
 
 
 
23
  """
24
  filepath = EVAL_DIR / filename
25
 
26
- with open(filepath) as f:
27
- data = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- return [
30
- EvalCase(
31
- query=d["query"],
32
- relevant_items=d["relevant_items"],
33
- user_id=d.get("user_id"),
34
  )
35
- for d in data
36
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  Returns:
22
  List of EvalCase objects.
23
+
24
+ Raises:
25
+ FileNotFoundError: If the file does not exist.
26
+ ValueError: If JSON is invalid or data fails validation.
27
  """
28
  filepath = EVAL_DIR / filename
29
 
30
+ # Load JSON with helpful error messages
31
+ try:
32
+ with open(filepath) as f:
33
+ data = json.load(f)
34
+ except FileNotFoundError:
35
+ raise FileNotFoundError(f"Evaluation file not found: {filepath}")
36
+ except json.JSONDecodeError as e:
37
+ raise ValueError(
38
+ f"Invalid JSON format in evaluation file: {filepath} "
39
+ f"(line {e.lineno}, column {e.colno})"
40
+ )
41
+
42
+ # Handle empty list gracefully
43
+ if not data:
44
+ return []
45
 
46
+ # Validate structure is a list
47
+ if not isinstance(data, list):
48
+ raise ValueError(
49
+ f"Evaluation file must contain a JSON array, got {type(data).__name__}: "
50
+ f"{filepath}"
51
  )
52
+
53
+ # Validate each case
54
+ cases = []
55
+ for i, d in enumerate(data):
56
+ # Check required fields
57
+ if "query" not in d:
58
+ raise ValueError(f"Missing 'query' field in case {i}: {filepath}")
59
+ if "relevant_items" not in d:
60
+ raise ValueError(f"Missing 'relevant_items' field in case {i}: {filepath}")
61
+
62
+ # Validate relevant_items is a dict
63
+ relevant_items = d["relevant_items"]
64
+ if not isinstance(relevant_items, dict):
65
+ raise ValueError(
66
+ f"'relevant_items' must be a dict in case {i}, "
67
+ f"got {type(relevant_items).__name__}: {filepath}"
68
+ )
69
+
70
+ # Validate relevance scores are numeric
71
+ for product_id, score in relevant_items.items():
72
+ if not isinstance(score, (int, float)):
73
+ raise ValueError(
74
+ f"Relevance score for '{product_id}' must be numeric in case {i}, "
75
+ f"got {type(score).__name__}: '{score}'"
76
+ )
77
+
78
+ cases.append(
79
+ EvalCase(
80
+ query=d["query"],
81
+ relevant_items=relevant_items,
82
+ user_id=d.get("user_id"),
83
+ )
84
+ )
85
+
86
+ return cases
tests/test_eval.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for sage.data.eval — evaluation dataset loading utilities."""
2
+
3
+ import json
4
+
5
+ import pytest
6
+
7
+ from sage.data.eval import load_eval_cases
8
+
9
+
10
+ class TestLoadEvalCases:
11
+ """Tests for load_eval_cases function."""
12
+
13
+ def test_valid_cases(self, tmp_path, monkeypatch):
14
+ """Valid JSON file returns list of EvalCase objects."""
15
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
16
+
17
+ data = [
18
+ {
19
+ "query": "wireless headphones",
20
+ "relevant_items": {"B001": 3.0, "B002": 2.0},
21
+ },
22
+ {"query": "bluetooth speaker", "relevant_items": {"B003": 1.0}},
23
+ ]
24
+ (tmp_path / "test.json").write_text(json.dumps(data))
25
+
26
+ cases = load_eval_cases("test.json")
27
+
28
+ assert len(cases) == 2
29
+ assert cases[0].query == "wireless headphones"
30
+ assert cases[0].relevant_items == {"B001": 3.0, "B002": 2.0}
31
+ assert cases[1].query == "bluetooth speaker"
32
+
33
+ def test_empty_list_returns_empty(self, tmp_path, monkeypatch):
34
+ """Empty JSON array returns empty list without error."""
35
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
36
+
37
+ (tmp_path / "empty.json").write_text("[]")
38
+
39
+ cases = load_eval_cases("empty.json")
40
+
41
+ assert cases == []
42
+
43
+ def test_file_not_found_raises_clear_error(self, tmp_path, monkeypatch):
44
+ """Missing file raises FileNotFoundError with filepath context."""
45
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
46
+
47
+ with pytest.raises(FileNotFoundError, match="Evaluation file not found"):
48
+ load_eval_cases("nonexistent.json")
49
+
50
+ def test_invalid_json_raises_clear_error(self, tmp_path, monkeypatch):
51
+ """Invalid JSON raises ValueError with line/column info."""
52
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
53
+
54
+ (tmp_path / "bad.json").write_text("{invalid json")
55
+
56
+ with pytest.raises(ValueError, match="Invalid JSON format"):
57
+ load_eval_cases("bad.json")
58
+
59
+ def test_not_array_raises_error(self, tmp_path, monkeypatch):
60
+ """JSON object (not array) raises ValueError."""
61
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
62
+
63
+ (tmp_path / "object.json").write_text('{"query": "test"}')
64
+
65
+ with pytest.raises(ValueError, match="must contain a JSON array"):
66
+ load_eval_cases("object.json")
67
+
68
+ def test_missing_query_first_case(self, tmp_path, monkeypatch):
69
+ """Missing query in first case raises ValueError with index."""
70
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
71
+
72
+ data = [{"relevant_items": {"B001": 1.0}}]
73
+ (tmp_path / "test.json").write_text(json.dumps(data))
74
+
75
+ with pytest.raises(ValueError, match="Missing 'query' field in case 0"):
76
+ load_eval_cases("test.json")
77
+
78
+ def test_missing_query_later_case(self, tmp_path, monkeypatch):
79
+ """Missing query in later case raises ValueError with correct index."""
80
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
81
+
82
+ data = [
83
+ {"query": "valid", "relevant_items": {"B001": 1.0}},
84
+ {"query": "also valid", "relevant_items": {"B002": 2.0}},
85
+ {"relevant_items": {"B003": 3.0}}, # Missing query at index 2
86
+ ]
87
+ (tmp_path / "test.json").write_text(json.dumps(data))
88
+
89
+ with pytest.raises(ValueError, match="Missing 'query' field in case 2"):
90
+ load_eval_cases("test.json")
91
+
92
+ def test_missing_relevant_items(self, tmp_path, monkeypatch):
93
+ """Missing relevant_items raises ValueError with index."""
94
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
95
+
96
+ data = [{"query": "test query"}]
97
+ (tmp_path / "test.json").write_text(json.dumps(data))
98
+
99
+ with pytest.raises(
100
+ ValueError, match="Missing 'relevant_items' field in case 0"
101
+ ):
102
+ load_eval_cases("test.json")
103
+
104
+ def test_relevant_items_not_dict(self, tmp_path, monkeypatch):
105
+ """relevant_items as list raises ValueError."""
106
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
107
+
108
+ data = [{"query": "test", "relevant_items": ["B001", "B002"]}]
109
+ (tmp_path / "test.json").write_text(json.dumps(data))
110
+
111
+ with pytest.raises(ValueError, match="'relevant_items' must be a dict"):
112
+ load_eval_cases("test.json")
113
+
114
+ def test_relevance_score_not_numeric(self, tmp_path, monkeypatch):
115
+ """Non-numeric relevance score raises ValueError with product ID."""
116
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
117
+
118
+ data = [{"query": "test", "relevant_items": {"B001": "high"}}]
119
+ (tmp_path / "test.json").write_text(json.dumps(data))
120
+
121
+ with pytest.raises(
122
+ ValueError, match="Relevance score for 'B001' must be numeric"
123
+ ):
124
+ load_eval_cases("test.json")
125
+
126
+ def test_relevance_score_as_int_accepted(self, tmp_path, monkeypatch):
127
+ """Integer relevance scores are accepted."""
128
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
129
+
130
+ data = [{"query": "test", "relevant_items": {"B001": 3}}]
131
+ (tmp_path / "test.json").write_text(json.dumps(data))
132
+
133
+ cases = load_eval_cases("test.json")
134
+
135
+ assert cases[0].relevant_items["B001"] == 3
136
+
137
+ def test_user_id_optional(self, tmp_path, monkeypatch):
138
+ """user_id field is optional."""
139
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
140
+
141
+ data = [{"query": "test", "relevant_items": {"B001": 1.0}}]
142
+ (tmp_path / "test.json").write_text(json.dumps(data))
143
+
144
+ cases = load_eval_cases("test.json")
145
+
146
+ assert cases[0].user_id is None
147
+
148
+ def test_user_id_preserved(self, tmp_path, monkeypatch):
149
+ """user_id field is preserved when present."""
150
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
151
+
152
+ data = [{"query": "test", "relevant_items": {"B001": 1.0}, "user_id": "U123"}]
153
+ (tmp_path / "test.json").write_text(json.dumps(data))
154
+
155
+ cases = load_eval_cases("test.json")
156
+
157
+ assert cases[0].user_id == "U123"
158
+
159
+ def test_extra_fields_ignored(self, tmp_path, monkeypatch):
160
+ """Extra fields (category, intent) are ignored without error."""
161
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
162
+
163
+ data = [
164
+ {
165
+ "query": "smart speaker",
166
+ "relevant_items": {"B001": 3.0},
167
+ "category": "echo_devices",
168
+ "intent": "feature_specific",
169
+ }
170
+ ]
171
+ (tmp_path / "test.json").write_text(json.dumps(data))
172
+
173
+ cases = load_eval_cases("test.json")
174
+
175
+ assert len(cases) == 1
176
+ assert cases[0].query == "smart speaker"
177
+
178
+ def test_relevant_set_works_after_load(self, tmp_path, monkeypatch):
179
+ """Loaded cases have working relevant_set property."""
180
+ monkeypatch.setattr("sage.data.eval.EVAL_DIR", tmp_path)
181
+
182
+ data = [
183
+ {"query": "test", "relevant_items": {"B001": 3.0, "B002": 0.0, "B003": 1.0}}
184
+ ]
185
+ (tmp_path / "test.json").write_text(json.dumps(data))
186
+
187
+ cases = load_eval_cases("test.json")
188
+
189
+ # relevant_set should only include items with score > 0
190
+ assert cases[0].relevant_set == {"B001", "B003"}