File size: 1,550 Bytes
052c644
 
 
1af43ae
052c644
 
 
 
1af43ae
052c644
 
 
 
1af43ae
052c644
 
 
 
1af43ae
052c644
 
 
 
 
1af43ae
052c644
 
 
 
1af43ae
052c644
 
1af43ae
052c644
 
1af43ae
052c644
 
 
 
 
 
1af43ae
052c644
 
 
 
1af43ae
052c644
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from nl2sql.pipeline import Pipeline
from nl2sql.types import StageResult


class DetectorOK:
    """Detector stub that never reports anything.

    ``detect`` accepts and ignores any arguments and always returns a fresh
    empty list.
    """

    def detect(self, *args, **kwargs):
        findings = []
        return findings


class PlannerOK:
    """Planner stub that always succeeds with the fixed plan ``"p"``.

    Any arguments passed to ``run`` are ignored.
    """

    def run(self, *args, **kwargs):
        planned = {"plan": "p"}
        return StageResult(ok=True, data=planned)


class GeneratorOK:
    """SQL-generator stub that always succeeds with a fixed query.

    ``run`` ignores its arguments and returns ``SELECT * FROM t`` together
    with the rationale string ``"ok"``.
    """

    def run(self, *args, **kwargs):
        payload = {"sql": "SELECT * FROM t", "rationale": "ok"}
        return StageResult(ok=True, data=payload)


class SafetyOK:
    """Safety-check stub that approves whatever SQL it is given.

    The ``sql`` keyword argument is echoed back unchanged. If it is absent,
    ``SELECT * FROM t`` is used. Positional arguments are ignored.
    """

    def run(self, *args, **kwargs):
        checked_sql = kwargs["sql"] if "sql" in kwargs else "SELECT * FROM t"
        return StageResult(ok=True, data={"sql": checked_sql})


class ExecOK:
    """Executor stub that always succeeds with the single row ``{"x": 1}``.

    Any arguments passed to ``run`` are ignored. A new row list is built on
    every call.
    """

    def run(self, *args, **kwargs):
        fake_rows = [{"x": 1}]
        return StageResult(ok=True, data={"rows": fake_rows})


class VerifierThenOK:
    """Verifier stub that fails the first time and passes after repair.

    The first ``run`` call returns a failed result carrying the error
    ``["first verify fail"]``. Every later call returns a successful result
    with ``{"verified": True}``.
    """

    def __init__(self):
        # Number of times ``run`` has been invoked so far.
        self.calls = 0

    def run(self, *, sql, exec_result):
        self.calls += 1
        is_first_attempt = self.calls == 1
        if not is_first_attempt:
            return StageResult(ok=True, data={"verified": True})
        return StageResult(ok=False, error=["first verify fail"])


class RepairOK:
    """Repair stub that always succeeds with a fixed replacement query.

    The incoming ``sql``, ``error_msg`` and ``schema_preview`` are accepted
    (keyword-only) but ignored. The repaired SQL is always
    ``SELECT * FROM t LIMIT 1``.
    """

    def run(self, *, sql, error_msg, schema_preview):
        repaired_sql = "SELECT * FROM t LIMIT 1"
        return StageResult(ok=True, data={"sql": repaired_sql})


def test_pipeline_repair_success_path():
    """A failed first verification followed by a repair should end verified.

    ``VerifierThenOK`` fails on its first call and passes afterwards. The run
    is expected to finish ok, verified, and without errors, and the verifier
    must have been consulted again after its initial failure.
    """
    verifier = VerifierThenOK()
    p = Pipeline(
        detector=DetectorOK(),
        planner=PlannerOK(),
        generator=GeneratorOK(),
        safety=SafetyOK(),
        executor=ExecOK(),
        verifier=verifier,
        repair=RepairOK(),
    )
    out = p.run(user_query="?", schema_preview="")
    assert out.ok
    assert out.verified
    assert not out.error
    # Pin that the failing first verification and a post-repair
    # re-verification actually happened. Otherwise this test could pass
    # without ever exercising the repair path it is named after.
    assert verifier.calls >= 2