Spaces:
Sleeping
Sleeping
File size: 1,550 Bytes
052c644 1af43ae 052c644 1af43ae 052c644 1af43ae 052c644 1af43ae 052c644 1af43ae 052c644 1af43ae 052c644 1af43ae 052c644 1af43ae 052c644 1af43ae 052c644 1af43ae 052c644 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
from nl2sql.pipeline import Pipeline
from nl2sql.types import StageResult
class DetectorOK:
    """Detector stub that never reports anything."""

    def detect(self, *args, **kwargs):
        """Ignore every argument and return an empty list of detections."""
        found = []
        return found
class PlannerOK:
    """Planner stub that always succeeds with the fixed plan ``"p"``."""

    def run(self, *args, **kwargs):
        """Ignore all inputs and return a successful StageResult with a plan."""
        payload = {"plan": "p"}
        return StageResult(data=payload, ok=True)
class GeneratorOK:
    """Generator stub that always succeeds with a fixed SQL string and rationale."""

    def run(self, *args, **kwargs):
        """Ignore all inputs and return a successful StageResult with SQL."""
        query = "SELECT * FROM t"
        payload = {"sql": query, "rationale": "ok"}
        return StageResult(data=payload, ok=True)
class SafetyOK:
    """Safety-check stub that approves whatever SQL it is given."""

    def run(self, *args, **kwargs):
        """Return a successful StageResult echoing the ``sql`` keyword argument.

        Falls back to ``"SELECT * FROM t"`` when ``sql`` is not passed as a
        keyword; positional arguments are ignored.
        """
        checked = kwargs["sql"] if "sql" in kwargs else "SELECT * FROM t"
        return StageResult(data={"sql": checked}, ok=True)
class ExecOK:
    """Executor stub that always succeeds with a single row ``{"x": 1}``."""

    def run(self, *args, **kwargs):
        """Ignore all inputs and return a successful StageResult with rows."""
        rows = [{"x": 1}]
        return StageResult(data={"rows": rows}, ok=True)
class VerifierThenOK:
    """Verifier stub that fails the first time and passes after repair.

    The first call to ``run`` returns a failed StageResult; every later call
    returns a successful one. This lets a test drive the pipeline down its
    repair path.
    """

    def __init__(self):
        # Number of times run() has been invoked so far.
        self.calls = 0

    def run(self, *, sql, exec_result):
        """Record the call, then fail on the first invocation only."""
        self.calls = self.calls + 1
        if self.calls != 1:
            return StageResult(ok=True, data={"verified": True})
        return StageResult(ok=False, error=["first verify fail"])
class RepairOK:
    """Repair stub that always succeeds with a fixed repaired query."""

    def run(self, *, sql, error_msg, schema_preview):
        """Ignore the inputs and return a successful StageResult with new SQL."""
        repaired = "SELECT * FROM t LIMIT 1"
        return StageResult(data={"sql": repaired}, ok=True)
def test_pipeline_repair_success_path():
    """A failed first verification should be repaired once and then pass.

    VerifierThenOK fails on its first call and passes afterwards. A successful
    run must therefore invoke repair and then re-run verification.
    """

    class _CountingRepair(RepairOK):
        # RepairOK plus a call counter, so the test can prove repair ran.
        # RepairOK.run is keyword-only, so forwarding **kwargs is sufficient.
        def __init__(self):
            self.calls = 0

        def run(self, **kwargs):
            self.calls += 1
            return super().run(**kwargs)

    verifier = VerifierThenOK()
    repair = _CountingRepair()
    p = Pipeline(
        detector=DetectorOK(),
        planner=PlannerOK(),
        generator=GeneratorOK(),
        safety=SafetyOK(),
        executor=ExecOK(),
        verifier=verifier,
        repair=repair,
    )
    out = p.run(user_query="?", schema_preview="")
    assert out.ok
    assert out.verified
    assert not out.error
    # Without these checks, a pipeline that re-ran the verifier while skipping
    # repair would also satisfy the assertions above.
    assert repair.calls == 1
    assert verifier.calls == 2
|