Spaces:
Sleeping
Sleeping
github-actions[bot]
commited on
Commit
·
8e8639a
1
Parent(s):
0c2c0f1
Sync from GitHub main @ e5ca708f9310108380db0252e29edc2f832428bf
Browse files- adapters/llm/openai_provider.py +2 -2
- nl2sql/context_engineering/budgeter.py +43 -0
- nl2sql/context_engineering/engineer.py +52 -0
- nl2sql/context_engineering/parse.py +29 -0
- nl2sql/context_engineering/render.py +14 -0
- nl2sql/context_engineering/schema_pack.py +29 -0
- nl2sql/context_engineering/types.py +37 -0
- nl2sql/pipeline.py +43 -9
- nl2sql/pipeline_factory.py +11 -0
adapters/llm/openai_provider.py
CHANGED
|
@@ -134,7 +134,7 @@ Create a step-by-step plan to answer this question with SQL."""
|
|
| 134 |
user_query: The user's natural language question
|
| 135 |
schema_preview: Database schema information
|
| 136 |
plan_text: Query execution plan
|
| 137 |
-
clarify_answers: Optional additional
|
| 138 |
|
| 139 |
Returns:
|
| 140 |
Tuple of (sql, rationale, prompt_tokens, completion_tokens, cost)
|
|
@@ -183,7 +183,7 @@ Wrong: {{"sql": "SELECT COUNT(singer.singer_id) AS total_singers FROM singer", "
|
|
| 183 |
Now generate the SQL for the given question:"""
|
| 184 |
|
| 185 |
if clarify_answers:
|
| 186 |
-
user_prompt += f"\n\nAdditional
|
| 187 |
|
| 188 |
completion = self._create_chat_completion(
|
| 189 |
model=self.model,
|
|
|
|
| 134 |
user_query: The user's natural language question
|
| 135 |
schema_preview: Database schema information
|
| 136 |
plan_text: Query execution plan
|
| 137 |
+
clarify_answers: Optional additional context_engineering
|
| 138 |
|
| 139 |
Returns:
|
| 140 |
Tuple of (sql, rationale, prompt_tokens, completion_tokens, cost)
|
|
|
|
| 183 |
Now generate the SQL for the given question:"""
|
| 184 |
|
| 185 |
if clarify_answers:
|
| 186 |
+
user_prompt += f"\n\nAdditional context_engineering: {clarify_answers}"
|
| 187 |
|
| 188 |
completion = self._create_chat_completion(
|
| 189 |
model=self.model,
|
nl2sql/context_engineering/budgeter.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
from .types import SchemaPack, SchemaTable, ContextBudget
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def apply_budget(
|
| 8 |
+
pack: SchemaPack, budget: ContextBudget
|
| 9 |
+
) -> Tuple[SchemaPack, Optional[str]]:
|
| 10 |
+
reason: Optional[str] = None
|
| 11 |
+
|
| 12 |
+
table_names = sorted(pack.tables.keys())
|
| 13 |
+
if len(table_names) > budget.max_tables:
|
| 14 |
+
reason = f"tables_pruned_to_{budget.max_tables}"
|
| 15 |
+
table_names = table_names[: budget.max_tables]
|
| 16 |
+
|
| 17 |
+
new_tables = {}
|
| 18 |
+
for t in table_names:
|
| 19 |
+
tab = pack.tables[t]
|
| 20 |
+
cols = tab.columns[: budget.max_columns_per_table]
|
| 21 |
+
if len(tab.columns) > budget.max_columns_per_table:
|
| 22 |
+
reason = reason or "columns_trimmed_per_table"
|
| 23 |
+
new_tables[t] = SchemaTable(columns=cols, fks=tab.fks)
|
| 24 |
+
|
| 25 |
+
new_pack = SchemaPack(tables=new_tables, version=pack.version)
|
| 26 |
+
|
| 27 |
+
total_cols = sum(len(t.columns) for t in new_pack.tables.values())
|
| 28 |
+
if total_cols > budget.max_total_columns:
|
| 29 |
+
reason = reason or "columns_trimmed_total_cap"
|
| 30 |
+
remaining = budget.max_total_columns
|
| 31 |
+
capped = {}
|
| 32 |
+
for t in sorted(new_pack.tables.keys()):
|
| 33 |
+
tab = new_pack.tables[t]
|
| 34 |
+
if remaining <= 0:
|
| 35 |
+
capped[t] = SchemaTable(columns=[], fks=tab.fks)
|
| 36 |
+
continue
|
| 37 |
+
keep_n = min(len(tab.columns), remaining)
|
| 38 |
+
keep = tab.columns[:keep_n]
|
| 39 |
+
remaining -= len(keep)
|
| 40 |
+
capped[t] = SchemaTable(columns=keep, fks=tab.fks)
|
| 41 |
+
new_pack = SchemaPack(tables=capped, version=new_pack.version)
|
| 42 |
+
|
| 43 |
+
return new_pack, reason
|
nl2sql/context_engineering/engineer.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from .types import ContextBudget, ContextPacket, SchemaPack, SchemaTable
|
| 4 |
+
from .parse import parse_sqlite_schema_preview
|
| 5 |
+
from .budgeter import apply_budget
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
DEFAULT_CONSTRAINTS = [
|
| 9 |
+
"SELECT_ONLY",
|
| 10 |
+
"NO_DDL_DML",
|
| 11 |
+
"NO_ATTACH_PRAGMA",
|
| 12 |
+
"SINGLE_STATEMENT",
|
| 13 |
+
"LIMIT_REQUIRED_IF_MISSING",
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ContextEngineer:
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
*,
|
| 21 |
+
budget: ContextBudget,
|
| 22 |
+
constraints: list[str] | None = None,
|
| 23 |
+
) -> None:
|
| 24 |
+
self.budget = budget
|
| 25 |
+
self.constraints = constraints or DEFAULT_CONSTRAINTS
|
| 26 |
+
|
| 27 |
+
def build(self, *, schema_preview: str) -> ContextPacket:
|
| 28 |
+
raw_tables = parse_sqlite_schema_preview(schema_preview)
|
| 29 |
+
|
| 30 |
+
tables_sorted = sorted(raw_tables.keys())
|
| 31 |
+
tables = {t: SchemaTable(columns=raw_tables[t], fks={}) for t in tables_sorted}
|
| 32 |
+
pack = SchemaPack(tables=tables, version="v1")
|
| 33 |
+
|
| 34 |
+
tables_before = len(pack.tables)
|
| 35 |
+
columns_before = sum(len(t.columns) for t in pack.tables.values())
|
| 36 |
+
|
| 37 |
+
packed, reason = apply_budget(pack, self.budget)
|
| 38 |
+
|
| 39 |
+
tables_after = len(packed.tables)
|
| 40 |
+
columns_after = sum(len(t.columns) for t in packed.tables.values())
|
| 41 |
+
|
| 42 |
+
return ContextPacket(
|
| 43 |
+
schema_pack=packed,
|
| 44 |
+
constraints=self.constraints,
|
| 45 |
+
db_hints=None,
|
| 46 |
+
budget=self.budget,
|
| 47 |
+
tables_before=tables_before,
|
| 48 |
+
columns_before=columns_before,
|
| 49 |
+
tables_after=tables_after,
|
| 50 |
+
columns_after=columns_after,
|
| 51 |
+
budget_reason=reason,
|
| 52 |
+
)
|
nl2sql/context_engineering/parse.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Dict, List
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
_LINE_RE = re.compile(r"^\s*([A-Za-z_][A-Za-z0-9_]*)\s*\((.*)\)\s*$")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def parse_sqlite_schema_preview(schema_preview: str) -> Dict[str, List[str]]:
|
| 11 |
+
raw_tables: Dict[str, List[str]] = {}
|
| 12 |
+
|
| 13 |
+
for line in (schema_preview or "").splitlines():
|
| 14 |
+
line = line.strip()
|
| 15 |
+
if not line:
|
| 16 |
+
continue
|
| 17 |
+
m = _LINE_RE.match(line)
|
| 18 |
+
if not m:
|
| 19 |
+
# ignore unknown line formats (future-proof)
|
| 20 |
+
continue
|
| 21 |
+
table = m.group(1)
|
| 22 |
+
cols_blob = m.group(2).strip()
|
| 23 |
+
cols = [c.strip() for c in cols_blob.split(",") if c.strip()]
|
| 24 |
+
# stable order: keep what service produced but also de-dup deterministically
|
| 25 |
+
cols = sorted(set(cols))
|
| 26 |
+
raw_tables[table] = cols
|
| 27 |
+
|
| 28 |
+
# stable order: sort keys by caller later
|
| 29 |
+
return raw_tables
|
nl2sql/context_engineering/render.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from .types import SchemaPack
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def render_schema_pack(pack: SchemaPack) -> str:
|
| 7 |
+
lines: list[str] = []
|
| 8 |
+
for table in sorted(pack.tables.keys()):
|
| 9 |
+
cols = pack.tables[table].columns
|
| 10 |
+
if cols:
|
| 11 |
+
lines.append(f"{table}({', '.join(cols)})")
|
| 12 |
+
else:
|
| 13 |
+
lines.append(f"{table}()")
|
| 14 |
+
return "\n".join(lines)
|
nl2sql/context_engineering/schema_pack.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Dict, List, Tuple
|
| 4 |
+
from .types import SchemaPack, SchemaTable
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def build_schema_pack(
|
| 8 |
+
raw_tables: Dict[str, List[str]],
|
| 9 |
+
raw_fks: Dict[str, List[Tuple[str, str]]],
|
| 10 |
+
version: str = "v1",
|
| 11 |
+
) -> SchemaPack:
|
| 12 |
+
"""
|
| 13 |
+
raw_tables: {"orders": ["id", "user_id", ...], ...}
|
| 14 |
+
raw_fks: {"orders": [("user_id", "users.id"), ...], ...}
|
| 15 |
+
"""
|
| 16 |
+
tables_sorted = sorted(raw_tables.keys())
|
| 17 |
+
|
| 18 |
+
tables: Dict[str, SchemaTable] = {}
|
| 19 |
+
for t in tables_sorted:
|
| 20 |
+
cols = sorted(set(raw_tables.get(t, [])))
|
| 21 |
+
fks_list = raw_fks.get(t, [])
|
| 22 |
+
fks = {src: dst for (src, dst) in sorted(fks_list, key=lambda x: (x[0], x[1]))}
|
| 23 |
+
tables[t] = SchemaTable(columns=cols, fks=fks)
|
| 24 |
+
|
| 25 |
+
return SchemaPack(tables=tables, version=version)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def count_columns(pack: SchemaPack) -> int:
|
| 29 |
+
return sum(len(t.columns) for t in pack.tables.values())
|
nl2sql/context_engineering/types.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass(frozen=True)
|
| 8 |
+
class SchemaTable:
|
| 9 |
+
columns: List[str]
|
| 10 |
+
fks: Dict[str, str] # kept for future; sqlite preview has none
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass(frozen=True)
|
| 14 |
+
class SchemaPack:
|
| 15 |
+
tables: Dict[str, SchemaTable]
|
| 16 |
+
version: str = "v1"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass(frozen=True)
|
| 20 |
+
class ContextBudget:
|
| 21 |
+
max_tables: int
|
| 22 |
+
max_columns_per_table: int
|
| 23 |
+
max_total_columns: int
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass(frozen=True)
|
| 27 |
+
class ContextPacket:
|
| 28 |
+
schema_pack: SchemaPack
|
| 29 |
+
constraints: List[str]
|
| 30 |
+
db_hints: Optional[dict]
|
| 31 |
+
budget: ContextBudget
|
| 32 |
+
|
| 33 |
+
tables_before: int
|
| 34 |
+
columns_before: int
|
| 35 |
+
tables_after: int
|
| 36 |
+
columns_after: int
|
| 37 |
+
budget_reason: Optional[str]
|
nl2sql/pipeline.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from __future__ import annotations
|
|
|
|
| 2 |
import traceback
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from typing import Dict, Any, Optional, List
|
|
@@ -16,6 +17,8 @@ from nl2sql.repair import Repair
|
|
| 16 |
from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
|
| 17 |
from nl2sql.metrics import stage_duration_ms, pipeline_runs_total, repair_attempts_total
|
| 18 |
from nl2sql.errors.codes import ErrorCode
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
@dataclass(frozen=True)
|
|
@@ -54,6 +57,7 @@ class Pipeline:
|
|
| 54 |
executor: Optional[Executor] = None,
|
| 55 |
verifier: Optional[Verifier] = None,
|
| 56 |
repair: Optional[Repair] = None,
|
|
|
|
| 57 |
):
|
| 58 |
self.detector = detector
|
| 59 |
self.planner = planner
|
|
@@ -64,6 +68,7 @@ class Pipeline:
|
|
| 64 |
self.repair = repair or NoOpRepair()
|
| 65 |
# If the verifier explicitly requires verification, enforce it in finalize.
|
| 66 |
self.require_verification = bool(getattr(self.verifier, "required", False))
|
|
|
|
| 67 |
|
| 68 |
# ---------------------------- helpers ----------------------------
|
| 69 |
@staticmethod
|
|
@@ -283,6 +288,13 @@ class Pipeline:
|
|
| 283 |
schema_preview = schema_preview or ""
|
| 284 |
clarify_answers = clarify_answers or {}
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
try:
|
| 287 |
# --- 1) detector ---
|
| 288 |
t0 = time.perf_counter()
|
|
@@ -314,14 +326,24 @@ class Pipeline:
|
|
| 314 |
|
| 315 |
# --- 2) planner ---
|
| 316 |
t0 = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
r_plan = self._run_with_repair(
|
| 318 |
"planner",
|
| 319 |
self.planner.run,
|
| 320 |
repair_input_builder=self._planner_repair_input_builder,
|
| 321 |
max_attempts=1,
|
| 322 |
-
|
| 323 |
-
traces=traces,
|
| 324 |
-
schema_preview=schema_preview,
|
| 325 |
)
|
| 326 |
dt = (time.perf_counter() - t0) * 1000.0
|
| 327 |
stage_duration_ms.labels("planner").observe(dt)
|
|
@@ -345,16 +367,26 @@ class Pipeline:
|
|
| 345 |
|
| 346 |
# --- 3) generator ---
|
| 347 |
t0 = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
r_gen = self._run_with_repair(
|
| 349 |
"generator",
|
| 350 |
self.generator.run,
|
| 351 |
repair_input_builder=self._generator_repair_input_builder,
|
| 352 |
max_attempts=1,
|
| 353 |
-
|
| 354 |
-
schema_preview=schema_preview,
|
| 355 |
-
plan_text=(r_plan.data or {}).get("plan"),
|
| 356 |
-
clarify_answers=clarify_answers,
|
| 357 |
-
traces=traces,
|
| 358 |
)
|
| 359 |
dt = (time.perf_counter() - t0) * 1000.0
|
| 360 |
stage_duration_ms.labels("generator").observe(dt)
|
|
@@ -447,7 +479,9 @@ class Pipeline:
|
|
| 447 |
if not getattr(r_exec, "trace", None):
|
| 448 |
_fallback_trace("executor", dt, r_exec.ok)
|
| 449 |
if not r_exec.ok and r_exec.error:
|
| 450 |
-
details.extend(
|
|
|
|
|
|
|
| 451 |
if r_exec.ok and isinstance(r_exec.data, dict):
|
| 452 |
exec_result = dict(r_exec.data)
|
| 453 |
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import traceback
|
| 4 |
from dataclasses import dataclass
|
| 5 |
from typing import Dict, Any, Optional, List
|
|
|
|
| 17 |
from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
|
| 18 |
from nl2sql.metrics import stage_duration_ms, pipeline_runs_total, repair_attempts_total
|
| 19 |
from nl2sql.errors.codes import ErrorCode
|
| 20 |
+
from nl2sql.context_engineering.render import render_schema_pack
|
| 21 |
+
from nl2sql.context_engineering.engineer import ContextEngineer
|
| 22 |
|
| 23 |
|
| 24 |
@dataclass(frozen=True)
|
|
|
|
| 57 |
executor: Optional[Executor] = None,
|
| 58 |
verifier: Optional[Verifier] = None,
|
| 59 |
repair: Optional[Repair] = None,
|
| 60 |
+
context_engineer: ContextEngineer | None = None,
|
| 61 |
):
|
| 62 |
self.detector = detector
|
| 63 |
self.planner = planner
|
|
|
|
| 68 |
self.repair = repair or NoOpRepair()
|
| 69 |
# If the verifier explicitly requires verification, enforce it in finalize.
|
| 70 |
self.require_verification = bool(getattr(self.verifier, "required", False))
|
| 71 |
+
self.context_engineer = context_engineer
|
| 72 |
|
| 73 |
# ---------------------------- helpers ----------------------------
|
| 74 |
@staticmethod
|
|
|
|
| 288 |
schema_preview = schema_preview or ""
|
| 289 |
clarify_answers = clarify_answers or {}
|
| 290 |
|
| 291 |
+
# --- Context Engineering
|
| 292 |
+
schema_for_llm = schema_preview
|
| 293 |
+
|
| 294 |
+
if self.context_engineer is not None:
|
| 295 |
+
packet = self.context_engineer.build(schema_preview=schema_preview)
|
| 296 |
+
schema_for_llm = render_schema_pack(packet.schema_pack)
|
| 297 |
+
|
| 298 |
try:
|
| 299 |
# --- 1) detector ---
|
| 300 |
t0 = time.perf_counter()
|
|
|
|
| 326 |
|
| 327 |
# --- 2) planner ---
|
| 328 |
t0 = time.perf_counter()
|
| 329 |
+
|
| 330 |
+
planner_kwargs: Dict[str, Any] = {
|
| 331 |
+
"user_query": user_query,
|
| 332 |
+
"schema_preview": schema_for_llm,
|
| 333 |
+
"traces": traces,
|
| 334 |
+
}
|
| 335 |
+
try:
|
| 336 |
+
if "schema_pack" in inspect.signature(self.planner.run).parameters:
|
| 337 |
+
planner_kwargs["schema_pack"] = schema_for_llm
|
| 338 |
+
except (TypeError, ValueError):
|
| 339 |
+
pass
|
| 340 |
+
|
| 341 |
r_plan = self._run_with_repair(
|
| 342 |
"planner",
|
| 343 |
self.planner.run,
|
| 344 |
repair_input_builder=self._planner_repair_input_builder,
|
| 345 |
max_attempts=1,
|
| 346 |
+
**planner_kwargs,
|
|
|
|
|
|
|
| 347 |
)
|
| 348 |
dt = (time.perf_counter() - t0) * 1000.0
|
| 349 |
stage_duration_ms.labels("planner").observe(dt)
|
|
|
|
| 367 |
|
| 368 |
# --- 3) generator ---
|
| 369 |
t0 = time.perf_counter()
|
| 370 |
+
|
| 371 |
+
gen_kwargs: Dict[str, Any] = {
|
| 372 |
+
"user_query": user_query,
|
| 373 |
+
"schema_preview": schema_for_llm,
|
| 374 |
+
"plan_text": (r_plan.data or {}).get("plan"),
|
| 375 |
+
"clarify_answers": clarify_answers,
|
| 376 |
+
"traces": traces,
|
| 377 |
+
}
|
| 378 |
+
try:
|
| 379 |
+
if "schema_pack" in inspect.signature(self.generator.run).parameters:
|
| 380 |
+
gen_kwargs["schema_pack"] = schema_for_llm
|
| 381 |
+
except (TypeError, ValueError):
|
| 382 |
+
pass
|
| 383 |
+
|
| 384 |
r_gen = self._run_with_repair(
|
| 385 |
"generator",
|
| 386 |
self.generator.run,
|
| 387 |
repair_input_builder=self._generator_repair_input_builder,
|
| 388 |
max_attempts=1,
|
| 389 |
+
**gen_kwargs,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
)
|
| 391 |
dt = (time.perf_counter() - t0) * 1000.0
|
| 392 |
stage_duration_ms.labels("generator").observe(dt)
|
|
|
|
| 479 |
if not getattr(r_exec, "trace", None):
|
| 480 |
_fallback_trace("executor", dt, r_exec.ok)
|
| 481 |
if not r_exec.ok and r_exec.error:
|
| 482 |
+
details.extend(
|
| 483 |
+
r_exec.error
|
| 484 |
+
) # soft: keep for repair/verifier context_engineering
|
| 485 |
if r_exec.ok and isinstance(r_exec.data, dict):
|
| 486 |
exec_result = dict(r_exec.data)
|
| 487 |
|
nl2sql/pipeline_factory.py
CHANGED
|
@@ -29,6 +29,8 @@ from nl2sql.generator import Generator
|
|
| 29 |
from nl2sql.executor import Executor
|
| 30 |
from nl2sql.verifier import Verifier
|
| 31 |
from nl2sql.repair import Repair
|
|
|
|
|
|
|
| 32 |
|
| 33 |
from adapters.db.base import DBAdapter
|
| 34 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
|
@@ -195,6 +197,14 @@ def pipeline_from_config(path: str) -> Pipeline:
|
|
| 195 |
verifier = VERIFIERS[cfg.get("verifier", "basic")]()
|
| 196 |
repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
return Pipeline(
|
| 199 |
detector=detector,
|
| 200 |
planner=planner,
|
|
@@ -203,6 +213,7 @@ def pipeline_from_config(path: str) -> Pipeline:
|
|
| 203 |
executor=executor,
|
| 204 |
verifier=verifier,
|
| 205 |
repair=repair,
|
|
|
|
| 206 |
)
|
| 207 |
|
| 208 |
|
|
|
|
| 29 |
from nl2sql.executor import Executor
|
| 30 |
from nl2sql.verifier import Verifier
|
| 31 |
from nl2sql.repair import Repair
|
| 32 |
+
from nl2sql.context_engineering.engineer import ContextEngineer
|
| 33 |
+
from nl2sql.context_engineering.types import ContextBudget
|
| 34 |
|
| 35 |
from adapters.db.base import DBAdapter
|
| 36 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
|
|
|
| 197 |
verifier = VERIFIERS[cfg.get("verifier", "basic")]()
|
| 198 |
repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
|
| 199 |
|
| 200 |
+
context_engineer = ContextEngineer(
|
| 201 |
+
budget=ContextBudget(
|
| 202 |
+
max_tables=25,
|
| 203 |
+
max_columns_per_table=25,
|
| 204 |
+
max_total_columns=400,
|
| 205 |
+
)
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
return Pipeline(
|
| 209 |
detector=detector,
|
| 210 |
planner=planner,
|
|
|
|
| 213 |
executor=executor,
|
| 214 |
verifier=verifier,
|
| 215 |
repair=repair,
|
| 216 |
+
context_engineer=context_engineer,
|
| 217 |
)
|
| 218 |
|
| 219 |
|