github-actions[bot] commited on
Commit
8e8639a
·
1 Parent(s): 0c2c0f1

Sync from GitHub main @ e5ca708f9310108380db0252e29edc2f832428bf

Browse files
adapters/llm/openai_provider.py CHANGED
@@ -134,7 +134,7 @@ Create a step-by-step plan to answer this question with SQL."""
134
  user_query: The user's natural language question
135
  schema_preview: Database schema information
136
  plan_text: Query execution plan
137
- clarify_answers: Optional additional context
138
 
139
  Returns:
140
  Tuple of (sql, rationale, prompt_tokens, completion_tokens, cost)
@@ -183,7 +183,7 @@ Wrong: {{"sql": "SELECT COUNT(singer.singer_id) AS total_singers FROM singer", "
183
  Now generate the SQL for the given question:"""
184
 
185
  if clarify_answers:
186
- user_prompt += f"\n\nAdditional context: {clarify_answers}"
187
 
188
  completion = self._create_chat_completion(
189
  model=self.model,
 
134
  user_query: The user's natural language question
135
  schema_preview: Database schema information
136
  plan_text: Query execution plan
137
+ clarify_answers: Optional additional context_engineering
138
 
139
  Returns:
140
  Tuple of (sql, rationale, prompt_tokens, completion_tokens, cost)
 
183
  Now generate the SQL for the given question:"""
184
 
185
  if clarify_answers:
186
+ user_prompt += f"\n\nAdditional context_engineering: {clarify_answers}"
187
 
188
  completion = self._create_chat_completion(
189
  model=self.model,
nl2sql/context_engineering/budgeter.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Tuple
4
+ from .types import SchemaPack, SchemaTable, ContextBudget
5
+
6
+
7
+ def apply_budget(
8
+ pack: SchemaPack, budget: ContextBudget
9
+ ) -> Tuple[SchemaPack, Optional[str]]:
10
+ reason: Optional[str] = None
11
+
12
+ table_names = sorted(pack.tables.keys())
13
+ if len(table_names) > budget.max_tables:
14
+ reason = f"tables_pruned_to_{budget.max_tables}"
15
+ table_names = table_names[: budget.max_tables]
16
+
17
+ new_tables = {}
18
+ for t in table_names:
19
+ tab = pack.tables[t]
20
+ cols = tab.columns[: budget.max_columns_per_table]
21
+ if len(tab.columns) > budget.max_columns_per_table:
22
+ reason = reason or "columns_trimmed_per_table"
23
+ new_tables[t] = SchemaTable(columns=cols, fks=tab.fks)
24
+
25
+ new_pack = SchemaPack(tables=new_tables, version=pack.version)
26
+
27
+ total_cols = sum(len(t.columns) for t in new_pack.tables.values())
28
+ if total_cols > budget.max_total_columns:
29
+ reason = reason or "columns_trimmed_total_cap"
30
+ remaining = budget.max_total_columns
31
+ capped = {}
32
+ for t in sorted(new_pack.tables.keys()):
33
+ tab = new_pack.tables[t]
34
+ if remaining <= 0:
35
+ capped[t] = SchemaTable(columns=[], fks=tab.fks)
36
+ continue
37
+ keep_n = min(len(tab.columns), remaining)
38
+ keep = tab.columns[:keep_n]
39
+ remaining -= len(keep)
40
+ capped[t] = SchemaTable(columns=keep, fks=tab.fks)
41
+ new_pack = SchemaPack(tables=capped, version=new_pack.version)
42
+
43
+ return new_pack, reason
nl2sql/context_engineering/engineer.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from .types import ContextBudget, ContextPacket, SchemaPack, SchemaTable
4
+ from .parse import parse_sqlite_schema_preview
5
+ from .budgeter import apply_budget
6
+
7
+
8
+ DEFAULT_CONSTRAINTS = [
9
+ "SELECT_ONLY",
10
+ "NO_DDL_DML",
11
+ "NO_ATTACH_PRAGMA",
12
+ "SINGLE_STATEMENT",
13
+ "LIMIT_REQUIRED_IF_MISSING",
14
+ ]
15
+
16
+
17
+ class ContextEngineer:
18
+ def __init__(
19
+ self,
20
+ *,
21
+ budget: ContextBudget,
22
+ constraints: list[str] | None = None,
23
+ ) -> None:
24
+ self.budget = budget
25
+ self.constraints = constraints or DEFAULT_CONSTRAINTS
26
+
27
+ def build(self, *, schema_preview: str) -> ContextPacket:
28
+ raw_tables = parse_sqlite_schema_preview(schema_preview)
29
+
30
+ tables_sorted = sorted(raw_tables.keys())
31
+ tables = {t: SchemaTable(columns=raw_tables[t], fks={}) for t in tables_sorted}
32
+ pack = SchemaPack(tables=tables, version="v1")
33
+
34
+ tables_before = len(pack.tables)
35
+ columns_before = sum(len(t.columns) for t in pack.tables.values())
36
+
37
+ packed, reason = apply_budget(pack, self.budget)
38
+
39
+ tables_after = len(packed.tables)
40
+ columns_after = sum(len(t.columns) for t in packed.tables.values())
41
+
42
+ return ContextPacket(
43
+ schema_pack=packed,
44
+ constraints=self.constraints,
45
+ db_hints=None,
46
+ budget=self.budget,
47
+ tables_before=tables_before,
48
+ columns_before=columns_before,
49
+ tables_after=tables_after,
50
+ columns_after=columns_after,
51
+ budget_reason=reason,
52
+ )
nl2sql/context_engineering/parse.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, List
4
+ import re
5
+
6
+
7
+ _LINE_RE = re.compile(r"^\s*([A-Za-z_][A-Za-z0-9_]*)\s*\((.*)\)\s*$")
8
+
9
+
10
+ def parse_sqlite_schema_preview(schema_preview: str) -> Dict[str, List[str]]:
11
+ raw_tables: Dict[str, List[str]] = {}
12
+
13
+ for line in (schema_preview or "").splitlines():
14
+ line = line.strip()
15
+ if not line:
16
+ continue
17
+ m = _LINE_RE.match(line)
18
+ if not m:
19
+ # ignore unknown line formats (future-proof)
20
+ continue
21
+ table = m.group(1)
22
+ cols_blob = m.group(2).strip()
23
+ cols = [c.strip() for c in cols_blob.split(",") if c.strip()]
24
+ # stable order: keep what service produced but also de-dup deterministically
25
+ cols = sorted(set(cols))
26
+ raw_tables[table] = cols
27
+
28
+ # stable order: sort keys by caller later
29
+ return raw_tables
nl2sql/context_engineering/render.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from .types import SchemaPack
4
+
5
+
6
+ def render_schema_pack(pack: SchemaPack) -> str:
7
+ lines: list[str] = []
8
+ for table in sorted(pack.tables.keys()):
9
+ cols = pack.tables[table].columns
10
+ if cols:
11
+ lines.append(f"{table}({', '.join(cols)})")
12
+ else:
13
+ lines.append(f"{table}()")
14
+ return "\n".join(lines)
nl2sql/context_engineering/schema_pack.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, List, Tuple
4
+ from .types import SchemaPack, SchemaTable
5
+
6
+
7
+ def build_schema_pack(
8
+ raw_tables: Dict[str, List[str]],
9
+ raw_fks: Dict[str, List[Tuple[str, str]]],
10
+ version: str = "v1",
11
+ ) -> SchemaPack:
12
+ """
13
+ raw_tables: {"orders": ["id", "user_id", ...], ...}
14
+ raw_fks: {"orders": [("user_id", "users.id"), ...], ...}
15
+ """
16
+ tables_sorted = sorted(raw_tables.keys())
17
+
18
+ tables: Dict[str, SchemaTable] = {}
19
+ for t in tables_sorted:
20
+ cols = sorted(set(raw_tables.get(t, [])))
21
+ fks_list = raw_fks.get(t, [])
22
+ fks = {src: dst for (src, dst) in sorted(fks_list, key=lambda x: (x[0], x[1]))}
23
+ tables[t] = SchemaTable(columns=cols, fks=fks)
24
+
25
+ return SchemaPack(tables=tables, version=version)
26
+
27
+
28
+ def count_columns(pack: SchemaPack) -> int:
29
+ return sum(len(t.columns) for t in pack.tables.values())
nl2sql/context_engineering/types.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Optional
5
+
6
+
7
+ @dataclass(frozen=True)
8
+ class SchemaTable:
9
+ columns: List[str]
10
+ fks: Dict[str, str] # kept for future; sqlite preview has none
11
+
12
+
13
+ @dataclass(frozen=True)
14
+ class SchemaPack:
15
+ tables: Dict[str, SchemaTable]
16
+ version: str = "v1"
17
+
18
+
19
+ @dataclass(frozen=True)
20
+ class ContextBudget:
21
+ max_tables: int
22
+ max_columns_per_table: int
23
+ max_total_columns: int
24
+
25
+
26
+ @dataclass(frozen=True)
27
+ class ContextPacket:
28
+ schema_pack: SchemaPack
29
+ constraints: List[str]
30
+ db_hints: Optional[dict]
31
+ budget: ContextBudget
32
+
33
+ tables_before: int
34
+ columns_before: int
35
+ tables_after: int
36
+ columns_after: int
37
+ budget_reason: Optional[str]
nl2sql/pipeline.py CHANGED
@@ -1,4 +1,5 @@
1
  from __future__ import annotations
 
2
  import traceback
3
  from dataclasses import dataclass
4
  from typing import Dict, Any, Optional, List
@@ -16,6 +17,8 @@ from nl2sql.repair import Repair
16
  from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
17
  from nl2sql.metrics import stage_duration_ms, pipeline_runs_total, repair_attempts_total
18
  from nl2sql.errors.codes import ErrorCode
 
 
19
 
20
 
21
  @dataclass(frozen=True)
@@ -54,6 +57,7 @@ class Pipeline:
54
  executor: Optional[Executor] = None,
55
  verifier: Optional[Verifier] = None,
56
  repair: Optional[Repair] = None,
 
57
  ):
58
  self.detector = detector
59
  self.planner = planner
@@ -64,6 +68,7 @@ class Pipeline:
64
  self.repair = repair or NoOpRepair()
65
  # If the verifier explicitly requires verification, enforce it in finalize.
66
  self.require_verification = bool(getattr(self.verifier, "required", False))
 
67
 
68
  # ---------------------------- helpers ----------------------------
69
  @staticmethod
@@ -283,6 +288,13 @@ class Pipeline:
283
  schema_preview = schema_preview or ""
284
  clarify_answers = clarify_answers or {}
285
 
 
 
 
 
 
 
 
286
  try:
287
  # --- 1) detector ---
288
  t0 = time.perf_counter()
@@ -314,14 +326,24 @@ class Pipeline:
314
 
315
  # --- 2) planner ---
316
  t0 = time.perf_counter()
 
 
 
 
 
 
 
 
 
 
 
 
317
  r_plan = self._run_with_repair(
318
  "planner",
319
  self.planner.run,
320
  repair_input_builder=self._planner_repair_input_builder,
321
  max_attempts=1,
322
- user_query=user_query,
323
- traces=traces,
324
- schema_preview=schema_preview,
325
  )
326
  dt = (time.perf_counter() - t0) * 1000.0
327
  stage_duration_ms.labels("planner").observe(dt)
@@ -345,16 +367,26 @@ class Pipeline:
345
 
346
  # --- 3) generator ---
347
  t0 = time.perf_counter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  r_gen = self._run_with_repair(
349
  "generator",
350
  self.generator.run,
351
  repair_input_builder=self._generator_repair_input_builder,
352
  max_attempts=1,
353
- user_query=user_query,
354
- schema_preview=schema_preview,
355
- plan_text=(r_plan.data or {}).get("plan"),
356
- clarify_answers=clarify_answers,
357
- traces=traces,
358
  )
359
  dt = (time.perf_counter() - t0) * 1000.0
360
  stage_duration_ms.labels("generator").observe(dt)
@@ -447,7 +479,9 @@ class Pipeline:
447
  if not getattr(r_exec, "trace", None):
448
  _fallback_trace("executor", dt, r_exec.ok)
449
  if not r_exec.ok and r_exec.error:
450
- details.extend(r_exec.error) # soft: keep for repair/verifier context
 
 
451
  if r_exec.ok and isinstance(r_exec.data, dict):
452
  exec_result = dict(r_exec.data)
453
 
 
1
  from __future__ import annotations
2
+
3
  import traceback
4
  from dataclasses import dataclass
5
  from typing import Dict, Any, Optional, List
 
17
  from nl2sql.stubs import NoOpExecutor, NoOpRepair, NoOpVerifier
18
  from nl2sql.metrics import stage_duration_ms, pipeline_runs_total, repair_attempts_total
19
  from nl2sql.errors.codes import ErrorCode
20
+ from nl2sql.context_engineering.render import render_schema_pack
21
+ from nl2sql.context_engineering.engineer import ContextEngineer
22
 
23
 
24
  @dataclass(frozen=True)
 
57
  executor: Optional[Executor] = None,
58
  verifier: Optional[Verifier] = None,
59
  repair: Optional[Repair] = None,
60
+ context_engineer: ContextEngineer | None = None,
61
  ):
62
  self.detector = detector
63
  self.planner = planner
 
68
  self.repair = repair or NoOpRepair()
69
  # If the verifier explicitly requires verification, enforce it in finalize.
70
  self.require_verification = bool(getattr(self.verifier, "required", False))
71
+ self.context_engineer = context_engineer
72
 
73
  # ---------------------------- helpers ----------------------------
74
  @staticmethod
 
288
  schema_preview = schema_preview or ""
289
  clarify_answers = clarify_answers or {}
290
 
291
+ # --- Context Engineering
292
+ schema_for_llm = schema_preview
293
+
294
+ if self.context_engineer is not None:
295
+ packet = self.context_engineer.build(schema_preview=schema_preview)
296
+ schema_for_llm = render_schema_pack(packet.schema_pack)
297
+
298
  try:
299
  # --- 1) detector ---
300
  t0 = time.perf_counter()
 
326
 
327
  # --- 2) planner ---
328
  t0 = time.perf_counter()
329
+
330
+ planner_kwargs: Dict[str, Any] = {
331
+ "user_query": user_query,
332
+ "schema_preview": schema_for_llm,
333
+ "traces": traces,
334
+ }
335
+ try:
336
+ if "schema_pack" in inspect.signature(self.planner.run).parameters:
337
+ planner_kwargs["schema_pack"] = schema_for_llm
338
+ except (TypeError, ValueError):
339
+ pass
340
+
341
  r_plan = self._run_with_repair(
342
  "planner",
343
  self.planner.run,
344
  repair_input_builder=self._planner_repair_input_builder,
345
  max_attempts=1,
346
+ **planner_kwargs,
 
 
347
  )
348
  dt = (time.perf_counter() - t0) * 1000.0
349
  stage_duration_ms.labels("planner").observe(dt)
 
367
 
368
  # --- 3) generator ---
369
  t0 = time.perf_counter()
370
+
371
+ gen_kwargs: Dict[str, Any] = {
372
+ "user_query": user_query,
373
+ "schema_preview": schema_for_llm,
374
+ "plan_text": (r_plan.data or {}).get("plan"),
375
+ "clarify_answers": clarify_answers,
376
+ "traces": traces,
377
+ }
378
+ try:
379
+ if "schema_pack" in inspect.signature(self.generator.run).parameters:
380
+ gen_kwargs["schema_pack"] = schema_for_llm
381
+ except (TypeError, ValueError):
382
+ pass
383
+
384
  r_gen = self._run_with_repair(
385
  "generator",
386
  self.generator.run,
387
  repair_input_builder=self._generator_repair_input_builder,
388
  max_attempts=1,
389
+ **gen_kwargs,
 
 
 
 
390
  )
391
  dt = (time.perf_counter() - t0) * 1000.0
392
  stage_duration_ms.labels("generator").observe(dt)
 
479
  if not getattr(r_exec, "trace", None):
480
  _fallback_trace("executor", dt, r_exec.ok)
481
  if not r_exec.ok and r_exec.error:
482
+ details.extend(
483
+ r_exec.error
484
+ ) # soft: keep for repair/verifier context_engineering
485
  if r_exec.ok and isinstance(r_exec.data, dict):
486
  exec_result = dict(r_exec.data)
487
 
nl2sql/pipeline_factory.py CHANGED
@@ -29,6 +29,8 @@ from nl2sql.generator import Generator
29
  from nl2sql.executor import Executor
30
  from nl2sql.verifier import Verifier
31
  from nl2sql.repair import Repair
 
 
32
 
33
  from adapters.db.base import DBAdapter
34
  from adapters.db.sqlite_adapter import SQLiteAdapter
@@ -195,6 +197,14 @@ def pipeline_from_config(path: str) -> Pipeline:
195
  verifier = VERIFIERS[cfg.get("verifier", "basic")]()
196
  repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
197
 
 
 
 
 
 
 
 
 
198
  return Pipeline(
199
  detector=detector,
200
  planner=planner,
@@ -203,6 +213,7 @@ def pipeline_from_config(path: str) -> Pipeline:
203
  executor=executor,
204
  verifier=verifier,
205
  repair=repair,
 
206
  )
207
 
208
 
 
29
  from nl2sql.executor import Executor
30
  from nl2sql.verifier import Verifier
31
  from nl2sql.repair import Repair
32
+ from nl2sql.context_engineering.engineer import ContextEngineer
33
+ from nl2sql.context_engineering.types import ContextBudget
34
 
35
  from adapters.db.base import DBAdapter
36
  from adapters.db.sqlite_adapter import SQLiteAdapter
 
197
  verifier = VERIFIERS[cfg.get("verifier", "basic")]()
198
  repair = REPAIRS[cfg.get("repair", "default")](llm=llm)
199
 
200
+ context_engineer = ContextEngineer(
201
+ budget=ContextBudget(
202
+ max_tables=25,
203
+ max_columns_per_table=25,
204
+ max_total_columns=400,
205
+ )
206
+ )
207
+
208
  return Pipeline(
209
  detector=detector,
210
  planner=planner,
 
213
  executor=executor,
214
  verifier=verifier,
215
  repair=repair,
216
+ context_engineer=context_engineer,
217
  )
218
 
219