File size: 4,274 Bytes
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5f9c5f
 
 
abd4352
 
 
 
 
 
 
c5f9c5f
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5f9c5f
 
abd4352
c5f9c5f
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5f9c5f
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5f9c5f
abd4352
 
 
 
 
c5f9c5f
 
 
 
abd4352
 
 
c5f9c5f
abd4352
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
"""
eval/run_eval.py
Runs the agent against eval/datasets/queries.json and prints a results table.
Usage: python eval/run_eval.py
"""

import json
import sys
import time
import uuid
from pathlib import Path

from dotenv import load_dotenv
load_dotenv()

sys.path.insert(0, str(Path(__file__).parent.parent))

from agent.graph import get_graph

# Eval cases are read from datasets/queries.json, resolved relative to this script.
DATASET = Path(__file__).parent / "datasets" / "queries.json"
# Status labels shown in the "Status" column of the results table.
PASS = "PASS"
FAIL = "FAIL"
# NOTE(review): WARN is not referenced anywhere in the visible code — confirm it is still needed.
WARN = "WARN"


def _initial_state(case):
    """Return the fresh agent state for one eval case.

    Only ``user_query`` and ``connector_id`` are taken from *case*; every
    other field starts at a fixed default, and each call gets a new random
    session id.
    """
    return {
        "session_id": str(uuid.uuid4()),
        "user_id": "eval",
        "user_query": case["query"],
        "connector_id": case["connector_id"],
        "intent": "",
        "query_plan": {},
        "relevant_tables": [],
        "schema_context": "",
        "memory_context": "",
        "conversation_history": [],
        "generated_code": "",
        "code_type": "sql",
        "sql_dialect": "postgres",
        "execution_result": None,
        "execution_error": None,
        "from_cache": False,
        "error_class": None,
        "correction_attempts": 0,
        "max_corrections": 3,
        "insight_text": "",
        "chart_spec": None,
        "anomalies": [],
        "history_id": None,
        "latency_ms": None,
        "stream_tokens": [],
    }


def _evaluate_case(graph, case):
    """Run one eval case through *graph* and return its result row (a dict).

    A case passes when the returned intent equals ``expected_intent``, there
    is a truthy ``execution_result`` (not required for "unsupported"
    intents), no ``execution_error`` is set, and every keyword listed in
    ``expected_contains`` occurs (case-insensitively) in the insight text or
    the generated code.

    Any exception raised while invoking the graph or scoring its output is
    recorded on the row under ``error`` and counted as a failure, so one bad
    case does not abort the whole run.
    """
    state = _initial_state(case)
    # perf_counter is monotonic, so the latency cannot go negative or jump
    # if the system clock is adjusted mid-run.
    started = time.perf_counter()
    try:
        result = graph.invoke(state)
        elapsed_ms = int((time.perf_counter() - started) * 1000)

        intent = result.get("intent")
        expected_intent = case["expected_intent"]
        exec_error = result.get("execution_error")

        intent_ok = intent == expected_intent
        # "unsupported" queries are not expected to produce an execution result.
        has_result = (
            bool(result.get("execution_result"))
            or intent == "unsupported"
            or expected_intent == "unsupported"
        )

        # Lower-case the searched texts once instead of once per keyword.
        haystacks = (
            (result.get("insight_text") or "").lower(),
            (result.get("generated_code") or "").lower(),
        )
        contains_ok = all(
            any(kw.lower() in text for text in haystacks)
            for kw in case.get("expected_contains", [])
        )

        passed = intent_ok and has_result and not exec_error and contains_ok

        return {
            "id": case["id"],
            "query": case["query"][:55],
            "intent": intent,
            "expected_intent": expected_intent,
            # `or` also covers keys that are present but set to None, which
            # would otherwise break len() here or the table formatting later.
            "corrections": result.get("correction_attempts") or 0,
            "anomalies": len(result.get("anomalies") or []),
            "latency_ms": elapsed_ms,
            "passed": passed,
            "status": PASS if passed else FAIL,
            "exec_error": exec_error,
        }
    except Exception as exc:
        return {
            "id": case["id"],
            "query": case["query"][:55],
            "intent": "ERROR",
            # .get(): a missing "expected_intent" may be the very error being
            # handled; indexing again would re-raise from inside the handler.
            "expected_intent": case.get("expected_intent"),
            "corrections": 0,
            "anomalies": 0,
            "latency_ms": int((time.perf_counter() - started) * 1000),
            "passed": False,
            "status": FAIL,
            "error": str(exc),
        }


def _print_report(results):
    """Print the per-case table and the pass-rate / average-latency summary.

    *results* must be non-empty (the averages divide by ``len(results)``).
    """
    # Width 6 for Status keeps the header aligned with the rows: the word
    # "Status" itself is six characters wide.
    print(f"{'ID':<5} {'Status':<6} {'Intent':<12} {'Fixes':<6} {'Warns':<6} {'ms':<7} Query")
    print("-" * 90)
    for row in results:
        # !s converts to str before padding; None (e.g. a missing intent)
        # raises TypeError when given a width spec directly.
        print(
            f"{row['id']!s:<5} {row['status']:<6} {row['intent']!s:<12} {row['corrections']!s:<6} "
            f"{row.get('anomalies', 0)!s:<6} {row['latency_ms']:<7} {row['query']}"
        )
        if row.get("error"):
            print(f"      -> SYS ERROR: {row['error']}")
        if row.get("exec_error"):
            print(f"      -> DB ERROR: {row['exec_error']}")

    total = len(results)
    passed = sum(1 for row in results if row["passed"])
    avg_lat = sum(row["latency_ms"] for row in results) // total
    print(f"\n{'-' * 90}")
    print(f"Passed: {passed}/{total} ({100 * passed // total}%) | Avg Latency: {avg_lat}ms")


def run_eval():
    """Run every case in DATASET through the agent graph and print a report.

    Returns:
        True when every case passed; False when any case failed or when the
        dataset contains no cases (an empty run is reported as a failure
        rather than crashing on the average computation).
    """
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    cases = json.loads(DATASET.read_text(encoding="utf-8"))
    if not cases:
        print(f"No eval cases found in {DATASET}")
        return False

    graph = get_graph()
    print(f"\nRunning {len(cases)} eval cases...\n")

    results = [_evaluate_case(graph, case) for case in cases]
    _print_report(results)
    return all(row["passed"] for row in results)



if __name__ == "__main__":
    # Exit status 0 only when run_eval() reports that every case passed; 1 otherwise.
    sys.exit(0 if run_eval() else 1)