nl2sql-copilot / tests /test_pipeline_extra.py
Melika Kheirieh
style: format code with ruff
1af43ae
raw
history blame
1.55 kB
from nl2sql.pipeline import Pipeline
from nl2sql.types import StageResult
class DetectorOK:
def detect(self, *a, **k):
return []
class PlannerOK:
def run(self, *a, **k):
return StageResult(ok=True, data={"plan": "p"})
class GeneratorOK:
def run(self, *a, **k):
return StageResult(ok=True, data={"sql": "SELECT * FROM t", "rationale": "ok"})
class SafetyOK:
def run(self, *a, **k):
sql = k.get("sql", "SELECT * FROM t")
return StageResult(ok=True, data={"sql": sql})
class ExecOK:
def run(self, *a, **k):
return StageResult(ok=True, data={"rows": [{"x": 1}]})
class VerifierThenOK:
"""اولین بار fail، بعد از repair pass می‌کند."""
def __init__(self):
self.calls = 0
def run(self, *, sql, exec_result):
self.calls += 1
if self.calls == 1:
return StageResult(ok=False, error=["first verify fail"])
return StageResult(ok=True, data={"verified": True})
class RepairOK:
def run(self, *, sql, error_msg, schema_preview):
return StageResult(ok=True, data={"sql": "SELECT * FROM t LIMIT 1"})
def test_pipeline_repair_success_path():
p = Pipeline(
detector=DetectorOK(),
planner=PlannerOK(),
generator=GeneratorOK(),
safety=SafetyOK(),
executor=ExecOK(),
verifier=VerifierThenOK(),
repair=RepairOK(),
)
out = p.run(user_query="?", schema_preview="")
assert out.ok
assert out.verified
assert not out.error