Melika Kheirieh committed on
Commit
d347376
·
1 Parent(s): ebc7457

test(benchmarks): add black-box tests for evaluate_spider outputs and trace normalization

Browse files
Files changed (1) hide show
  1. tests/test_evaluate_spider.py +135 -0
tests/test_evaluate_spider.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import json
3
+ import types
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import pytest
8
+
9
+
10
@pytest.fixture()
def temp_cwd(tmp_path, monkeypatch):
    """Run each test from a throwaway directory so generated artifacts stay sandboxed.

    Returns the temporary directory as a ``Path`` for convenience.
    """
    monkeypatch.chdir(tmp_path)  # undone automatically when the test ends
    return tmp_path
15
+
16
+
17
def _install_fake_router_module(monkeypatch):
    """
    Register a fake ``app.routers.nl2sql`` module (plus its parent packages)
    in ``sys.modules`` BEFORE importing ``evaluate_spider``, so its top-level
    imports resolve without the real application installed.

    Fixes over the previous version:
    - the ``monkeypatch`` argument was accepted but never used; the function
      wrote straight into ``sys.modules`` and leaked the fake modules into
      every subsequent test. ``monkeypatch.setitem`` makes pytest restore
      ``sys.modules`` automatically.
    - parent packages are now attribute-linked to their children, so
      ``from app.routers import nl2sql`` works as well as
      ``import app.routers.nl2sql``.
    """
    # Create package hierarchy: app, app.routers, app.routers.nl2sql
    app_mod = types.ModuleType("app")
    routers_mod = types.ModuleType("app.routers")
    nl2sql_mod = types.ModuleType("app.routers.nl2sql")

    class _FakeExec:
        def derive_schema_preview(self):
            return "TABLE users(id INT);"

    class _FakeResult:
        def __init__(self, ok=True):
            self.ok = ok
            # mix dicts/objects to exercise _to_stage_list normalization
            self.trace = [
                {"stage": "planner", "duration_ms": 11},
                types.SimpleNamespace(stage="generator", duration_ms=23),
                {"stage": "safety", "duration_ms": 5},
            ]

    class _FakePipeline:
        def __init__(self):
            self.executor = _FakeExec()

        def run(self, *, user_query: str, schema_preview: str = ""):
            return _FakeResult(ok=True)

    # exported symbols used by evaluate_spider
    nl2sql_mod._pipeline = _FakePipeline()
    nl2sql_mod._build_pipeline = lambda adapter: _FakePipeline()
    nl2sql_mod._select_adapter = lambda dbid: object()

    # Link parents to children so attribute access (and
    # ``from app.routers import nl2sql``) works, not just sys.modules lookup.
    app_mod.routers = routers_mod
    routers_mod.nl2sql = nl2sql_mod

    # register in sys.modules (package chain) — via monkeypatch so the
    # entries are removed again when the requesting test finishes
    monkeypatch.setitem(sys.modules, "app", app_mod)
    monkeypatch.setitem(sys.modules, "app.routers", routers_mod)
    monkeypatch.setitem(sys.modules, "app.routers.nl2sql", nl2sql_mod)
57
+
58
+
59
def test_evaluate_spider_writes_outputs(temp_cwd, monkeypatch):
    """End-to-end: main() must emit the jsonl, summary, and csv artifacts."""
    # The fake router has to be in sys.modules before the module under test
    # is imported, or its top-level imports fail.
    _install_fake_router_module(monkeypatch)

    import benchmarks.evaluate_spider as mod

    # Keep the run tiny for speed, and point every artifact at the sandbox.
    monkeypatch.setattr(mod, "DATASET", ["q1", "q2"], raising=True)
    results_root = Path("benchmarks") / "results"
    monkeypatch.setattr(mod, "RESULT_ROOT", results_root, raising=True)
    # Pin RESULT_DIR under the new root (keeps the module's naming scheme).
    results_dir = results_root / "test-run"
    monkeypatch.setattr(mod, "RESULT_DIR", results_dir, raising=True)

    mod.main()

    jsonl_file = results_dir / "spider_eval.jsonl"
    summary_file = results_dir / "metrics_summary.json"
    csv_file = results_dir / "results.csv"

    assert jsonl_file.exists(), "jsonl not written"
    assert summary_file.exists(), "summary not written"
    assert csv_file.exists(), "csv not written"

    # JSONL: one record per query, required keys, normalized trace entries.
    payload = jsonl_file.read_text(encoding="utf-8").strip().splitlines()
    assert len(payload) == 2
    first = json.loads(payload[0])
    assert {"query", "ok", "latency_ms", "trace", "error"} <= set(first.keys())
    assert isinstance(first["ok"], bool)
    assert isinstance(first["latency_ms"], int)
    assert isinstance(first["trace"], list)
    for entry in first["trace"]:
        assert "stage" in entry and "ms" in entry

    # Summary: totals, a sane rate, numeric latency, known pipeline source.
    summary = json.loads(summary_file.read_text(encoding="utf-8"))
    assert summary["queries_total"] == 2
    assert 0.0 <= summary["success_rate"] <= 1.0
    assert isinstance(summary["avg_latency_ms"], (int, float))
    assert summary["pipeline_source"] in {"default", "adapter"}  # per code path

    # CSV: same row count, exact header, emoji status, non-negative latency.
    with csv_file.open(newline="", encoding="utf-8") as handle:
        rows = list(csv.DictReader(handle))
    assert len(rows) == 2
    assert set(rows[0].keys()) == {"query", "ok", "latency_ms"}
    assert rows[0]["ok"] in {"✅", "❌"}
    assert int(rows[0]["latency_ms"]) >= 0
110
+
111
+
112
def test_to_stage_list_normalizes_mixed_items(temp_cwd, monkeypatch):
    """_to_stage_list must flatten dicts and objects into {stage, ms} dicts."""
    _install_fake_router_module(monkeypatch)
    import benchmarks.evaluate_spider as mod

    # Mix of dicts, an attribute-style object, and a stringified duration.
    heterogeneous = [
        {"stage": "planner", "duration_ms": 10},
        types.SimpleNamespace(stage="generator", duration_ms=20),
        {"stage": "safety", "duration_ms": "7"},
    ]
    normalized = mod._to_stage_list(heterogeneous)
    expected = [
        {"stage": "planner", "ms": 10},
        {"stage": "generator", "ms": 20},
        {"stage": "safety", "ms": 7},
    ]
    assert normalized == expected
127
+
128
+
129
def test_int_ms_returns_int(temp_cwd, monkeypatch):
    """_int_ms must yield an int regardless of the start timestamp's value."""
    _install_fake_router_module(monkeypatch)
    import benchmarks.evaluate_spider as mod

    # A synthetic start time suffices: we assert the type, not the magnitude.
    start = 0.0
    assert isinstance(mod._int_ms(start), int)