github-actions[bot] committed on
Commit
4e73462
·
1 Parent(s): 8e8639a

Sync from GitHub main @ 517739c210f47d8dcf880b0b6b7501a464d6ef4f

Browse files
adapters/llm/base.py CHANGED
@@ -1,14 +1,19 @@
1
  from __future__ import annotations
2
- from typing import Tuple, Dict, Any, Protocol
 
3
 
4
 
5
  class LLMProvider(Protocol):
6
  PROVIDER_ID: str
7
 
8
  def plan(
9
- self, *, user_query: str, schema_preview: str
10
- ) -> Tuple[str, int, int, float]:
11
- """Return (plan_text, token_in, token_out, cost_usd)."""
 
 
 
 
12
 
13
  def generate_sql(
14
  self,
@@ -16,6 +21,7 @@ class LLMProvider(Protocol):
16
  user_query: str,
17
  schema_preview: str,
18
  plan_text: str,
 
19
  clarify_answers: Dict[str, Any] | None = None,
20
  ) -> Tuple[str, str, int, int, float]:
21
  """Return (sql, rationale, token_in, token_out, cost_usd)."""
 
1
  from __future__ import annotations
2
+
3
+ from typing import Any, Dict, List, Protocol, Tuple
4
 
5
 
6
  class LLMProvider(Protocol):
7
  PROVIDER_ID: str
8
 
9
  def plan(
10
+ self,
11
+ *,
12
+ user_query: str,
13
+ schema_preview: str,
14
+ constraints: List[str] | None = None,
15
+ ) -> Tuple[str, List[str], int, int, float]:
16
+ """Return (plan_text, used_tables, token_in, token_out, cost_usd)."""
17
 
18
  def generate_sql(
19
  self,
 
21
  user_query: str,
22
  schema_preview: str,
23
  plan_text: str,
24
+ constraints: List[str] | None = None,
25
  clarify_answers: Dict[str, Any] | None = None,
26
  ) -> Tuple[str, str, int, int, float]:
27
  """Return (sql, rationale, token_in, token_out, cost_usd)."""
adapters/llm/openai_provider.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
  import json
4
  import os
5
  import re
6
- from typing import Any, List, Tuple
7
 
8
  from adapters.llm.base import LLMProvider
9
  from openai import OpenAI
@@ -35,17 +35,15 @@ def _resolve_api_config() -> tuple[str, str, str]:
35
 
36
 
37
  class OpenAIProvider(LLMProvider):
38
- """OpenAI LLM provider implementation."""
39
 
40
- PROVIDER_ID = "openai"
41
-
42
- def get_last_usage(self) -> dict[str, Any]:
43
- """Return metadata of the last LLM call (tokens, cost, sql_length, kind)."""
44
- return dict(self._last_usage)
45
 
46
- def _create_chat_completion(self, **kwargs):
47
- """OpenAI SDK seam for stable unit testing."""
48
- return self.client.chat.completions.create(**kwargs)
49
 
50
  def __init__(self) -> None:
51
  """Initialize OpenAI client with config from environment."""
@@ -54,21 +52,114 @@ class OpenAIProvider(LLMProvider):
54
  os.environ["OPENAI_BASE_URL"] = base_url
55
  self.client = OpenAI(timeout=120.0)
56
  self.model = model
57
- # last call usage/metadata for tracing
58
  self._last_usage: dict[str, Any] = {}
59
 
60
- def plan(
61
- self, *, user_query: str, schema_preview: str
62
- ) -> Tuple[str, int, int, float]:
63
- """Generate a query plan for the SQL generation.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- Args:
66
- user_query: The user's natural language question
67
- schema_preview: Database schema information
 
68
 
69
- Returns:
70
- Tuple of (plan_text, prompt_tokens, completion_tokens, cost)
71
- """
 
 
 
 
 
 
 
 
72
  system_prompt = """You are a SQL query planning expert. Analyze the user's question and database schema to create a clear execution plan.
73
 
74
  Your plan should:
@@ -86,6 +177,9 @@ Be concise but thorough."""
86
  Database Schema:
87
  {schema_preview}
88
 
 
 
 
89
  Create a step-by-step plan to answer this question with SQL."""
90
 
91
  completion = self._create_chat_completion(
@@ -100,6 +194,9 @@ Create a step-by-step plan to answer this question with SQL."""
100
  msg = completion.choices[0].message.content or ""
101
  usage = completion.usage
102
 
 
 
 
103
  if usage:
104
  prompt_tokens = usage.prompt_tokens
105
  completion_tokens = usage.completion_tokens
@@ -110,15 +207,15 @@ Create a step-by-step plan to answer this question with SQL."""
110
  "completion_tokens": completion_tokens,
111
  "cost_usd": cost,
112
  }
113
- return (msg, prompt_tokens, completion_tokens, cost)
114
- else:
115
- self._last_usage = {
116
- "kind": "plan",
117
- "prompt_tokens": 0,
118
- "completion_tokens": 0,
119
- "cost_usd": 0.0,
120
- }
121
- return (msg, 0, 0, 0.0)
122
 
123
  def generate_sql(
124
  self,
@@ -126,21 +223,11 @@ Create a step-by-step plan to answer this question with SQL."""
126
  user_query: str,
127
  schema_preview: str,
128
  plan_text: str,
129
- clarify_answers: dict[str, Any] | None = None,
 
130
  ) -> Tuple[str, str, int, int, float]:
131
- """Generate SQL with improved prompt for Spider benchmark.
132
-
133
- Args:
134
- user_query: The user's natural language question
135
- schema_preview: Database schema information
136
- plan_text: Query execution plan
137
- clarify_answers: Optional additional context_engineering
138
-
139
- Returns:
140
- Tuple of (sql, rationale, prompt_tokens, completion_tokens, cost)
141
- """
142
- system_prompt = """You are an expert SQL query generator for SQLite databases.
143
- You must follow these STRICT rules to generate clean, simple SQL:
144
 
145
  CRITICAL RULES:
146
  1. Write the SIMPLEST possible SQL that answers the question
@@ -173,6 +260,9 @@ Database Schema:
173
  Query Plan:
174
  {plan_text}
175
 
 
 
 
176
  Remember: Generate the SIMPLEST possible SQL. Avoid table prefixes, aliases, and unnecessary clauses.
177
 
178
  Example of what we want:
@@ -199,7 +289,6 @@ Now generate the SQL for the given question:"""
199
  content = text.strip() if text else ""
200
  usage = completion.usage
201
 
202
- # Parse JSON response
203
  try:
204
  parsed = json.loads(content)
205
  except json.JSONDecodeError:
@@ -208,21 +297,21 @@ Now generate the SQL for the given question:"""
208
  if start != -1 and end != -1:
209
  try:
210
  parsed = json.loads(content[start : end + 1])
211
- except Exception:
212
- raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
213
  else:
214
  raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
215
 
216
- sql = (parsed.get("sql") or "").strip()
217
- rationale = parsed.get("rationale") or ""
218
 
219
- # Post-process SQL to ensure simplicity
220
  sql = self._simplify_sql(sql)
221
-
222
  if not sql:
223
  raise ValueError("LLM returned empty 'sql'")
224
 
 
225
  sql_length = len(sql)
 
226
  if usage:
227
  prompt_tokens = usage.prompt_tokens
228
  completion_tokens = usage.completion_tokens
@@ -233,35 +322,33 @@ Now generate the SQL for the given question:"""
233
  "completion_tokens": completion_tokens,
234
  "cost_usd": cost,
235
  "sql_length": sql_length,
 
236
  }
237
  return (sql, rationale, prompt_tokens, completion_tokens, cost)
238
- else:
239
- self._last_usage = {
240
- "kind": "generate",
241
- "prompt_tokens": 0,
242
- "completion_tokens": 0,
243
- "cost_usd": 0.0,
244
- "sql_length": sql_length,
245
- }
246
- return (sql, rationale, 0, 0, 0.0)
 
247
 
248
  def _simplify_sql(self, sql: str) -> str:
249
  """Post-process SQL to remove common unnecessary additions."""
250
  if not sql:
251
  return sql
252
 
253
- # Remove trailing semicolon
254
  sql = sql.rstrip(";")
255
 
256
- # Remove unnecessary table prefixes in simple queries
257
- # e.g., "singer.name" -> "name" when there's only one table
258
  if sql.lower().count(" from ") == 1 and " join " not in sql.lower():
259
  match = re.search(r"\bfrom\s+(\w+)", sql, re.IGNORECASE)
260
  if match:
261
  table = match.group(1)
262
  sql = re.sub(rf"\b{table}\.(\w+)\b", r"\1", sql)
263
 
264
- # Remove unnecessary DISTINCT in COUNT(*)
265
  sql = re.sub(
266
  r"count\s*\(\s*distinct\s+\*\s*\)",
267
  "count(*)",
@@ -269,7 +356,6 @@ Now generate the SQL for the given question:"""
269
  flags=re.IGNORECASE,
270
  )
271
 
272
- # Remove big default LIMITs that weren't requested
273
  sql = re.sub(
274
  r"\s+limit\s+(100|1000|10000)\b",
275
  "",
@@ -286,16 +372,7 @@ Now generate the SQL for the given question:"""
286
  error_msg: str,
287
  schema_preview: str,
288
  ) -> Tuple[str, int, int, float]:
289
- """Repair SQL with focus on simplicity.
290
-
291
- Args:
292
- sql: Broken SQL query
293
- error_msg: Error message from execution
294
- schema_preview: Database schema information
295
-
296
- Returns:
297
- Tuple of (fixed_sql, prompt_tokens, completion_tokens, cost)
298
- """
299
  system_prompt = """You are a SQL repair expert. Fix the given SQL query to resolve the error.
300
 
301
  IMPORTANT RULES:
@@ -332,7 +409,6 @@ Return the corrected SQL (keep it simple):"""
332
  text = completion.choices[0].message.content
333
  fixed_sql = text.strip() if text else ""
334
 
335
- # Clean up accidental code fences
336
  if fixed_sql.startswith("```sql"):
337
  fixed_sql = fixed_sql[6:]
338
  if fixed_sql.startswith("```"):
@@ -344,7 +420,6 @@ Return the corrected SQL (keep it simple):"""
344
  fixed_sql = self._simplify_sql(fixed_sql)
345
 
346
  usage = completion.usage
347
-
348
  if usage:
349
  prompt_tokens = usage.prompt_tokens
350
  completion_tokens = usage.completion_tokens
@@ -357,88 +432,12 @@ Return the corrected SQL (keep it simple):"""
357
  "sql_length": len(fixed_sql),
358
  }
359
  return (fixed_sql, prompt_tokens, completion_tokens, cost)
360
- else:
361
- self._last_usage = {
362
- "kind": "repair",
363
- "prompt_tokens": 0,
364
- "completion_tokens": 0,
365
- "cost_usd": 0.0,
366
- "sql_length": len(fixed_sql),
367
- }
368
- return (fixed_sql, 0, 0, 0.0)
369
-
370
- def _estimate_cost(self, usage: Any) -> float:
371
- """Estimate cost based on token usage.
372
 
373
- Args:
374
- usage: OpenAI usage object with token counts
375
-
376
- Returns:
377
- Estimated cost in USD
378
- """
379
- if not usage:
380
- return 0.0
381
-
382
- # Pricing per 1K tokens (adjust based on model)
383
- pricing = {
384
- "gpt-4": {"input": 0.03, "output": 0.06},
385
- "gpt-4-turbo": {"input": 0.01, "output": 0.03},
386
- "gpt-4o": {"input": 0.005, "output": 0.015},
387
- "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
388
- "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
389
  }
390
-
391
- model_pricing = pricing.get(self.model, pricing["gpt-4o-mini"])
392
-
393
- input_cost = (usage.prompt_tokens / 1000) * model_pricing["input"]
394
- output_cost = (usage.completion_tokens / 1000) * model_pricing["output"]
395
-
396
- return input_cost + output_cost
397
-
398
- def clarify(
399
- self,
400
- *,
401
- user_query: str,
402
- schema_preview: str,
403
- questions: List[str],
404
- ) -> Tuple[str, int, int, float]:
405
- """Clarify ambiguities in the user query.
406
-
407
- Args:
408
- user_query: The user's natural language question
409
- schema_preview: Database schema information
410
- questions: List of clarification questions
411
-
412
- Returns:
413
- Tuple of (answers, prompt_tokens, completion_tokens, cost)
414
- """
415
- system_prompt = """You are a helpful assistant that clarifies SQL query requirements.
416
- Answer the questions clearly and concisely based on the user's query and database schema."""
417
-
418
- user_prompt = f"""User Query: {user_query}
419
-
420
- Database Schema:
421
- {schema_preview}
422
-
423
- Please answer these clarification questions:
424
- {chr(10).join(f"{i + 1}. {q}" for i, q in enumerate(questions))}"""
425
-
426
- completion = self._create_chat_completion(
427
- model=self.model,
428
- messages=[
429
- {"role": "system", "content": system_prompt},
430
- {"role": "user", "content": user_prompt},
431
- ],
432
- temperature=0.3,
433
- )
434
-
435
- answers = completion.choices[0].message.content or ""
436
- usage = completion.usage
437
-
438
- if usage:
439
- prompt_tokens = usage.prompt_tokens
440
- completion_tokens = usage.completion_tokens
441
- cost = self._estimate_cost(usage)
442
- return (answers, prompt_tokens, completion_tokens, cost)
443
- else:
444
- return (answers, 0, 0, 0.0)
 
3
  import json
4
  import os
5
  import re
6
+ from typing import Any, Dict, List, Tuple
7
 
8
  from adapters.llm.base import LLMProvider
9
  from openai import OpenAI
 
35
 
36
 
37
  class OpenAIProvider(LLMProvider):
38
+ """OpenAI LLM provider implementation.
39
 
40
+ Goals for this implementation:
41
+ - Keep prompts and behavior as close as possible to the current repo version.
42
+ - Align method signatures + return shapes with the updated LLMProvider Protocol.
43
+ - Provide a lightweight `used_tables` signal for observability/drift checks.
44
+ """
45
 
46
+ PROVIDER_ID = "openai"
 
 
47
 
48
  def __init__(self) -> None:
49
  """Initialize OpenAI client with config from environment."""
 
52
  os.environ["OPENAI_BASE_URL"] = base_url
53
  self.client = OpenAI(timeout=120.0)
54
  self.model = model
 
55
  self._last_usage: dict[str, Any] = {}
56
 
57
+ def get_last_usage(self) -> dict[str, Any]:
58
+ """Return metadata of the last LLM call (tokens, cost, sql_length, kind)."""
59
+ return dict(self._last_usage)
60
+
61
+ def _create_chat_completion(self, **kwargs):
62
+ """OpenAI SDK seam for stable unit testing."""
63
+ return self.client.chat.completions.create(**kwargs)
64
+
65
+ # ---------------------------------------------------------------------
66
+ # Table extraction helpers (best-effort; no heavy parsing).
67
+ # ---------------------------------------------------------------------
68
+ def _extract_schema_tables(self, schema_preview: str) -> List[str]:
69
+ """Extract likely table names from the schema preview string."""
70
+ if not schema_preview:
71
+ return []
72
+
73
+ tables: List[str] = []
74
+
75
+ for m in re.finditer(
76
+ r"(?im)^\s*(?:-\s*)?table\s*[: ]\s*([A-Za-z_][A-Za-z0-9_]*)\b",
77
+ schema_preview,
78
+ ):
79
+ tables.append(m.group(1))
80
+
81
+ for m in re.finditer(
82
+ r"(?im)^\s*create\s+table\s+`?([A-Za-z_][A-Za-z0-9_]*)`?\b", schema_preview
83
+ ):
84
+ tables.append(m.group(1))
85
+
86
+ seen = set()
87
+ uniq: List[str] = []
88
+ for t in tables:
89
+ if t not in seen:
90
+ uniq.append(t)
91
+ seen.add(t)
92
+ return uniq
93
+
94
+ def _extract_tables_from_sql(self, sql: str) -> List[str]:
95
+ """Very lightweight table extraction from FROM/JOIN clauses."""
96
+ if not sql:
97
+ return []
98
+ pairs = re.findall(
99
+ r"\bfrom\s+([A-Za-z_][A-Za-z0-9_]*)|\bjoin\s+([A-Za-z_][A-Za-z0-9_]*)",
100
+ sql,
101
+ flags=re.IGNORECASE,
102
+ )
103
+ out: List[str] = []
104
+ for t1, t2 in pairs:
105
+ if t1:
106
+ out.append(t1)
107
+ if t2:
108
+ out.append(t2)
109
+
110
+ seen = set()
111
+ uniq: List[str] = []
112
+ for t in out:
113
+ if t not in seen:
114
+ uniq.append(t)
115
+ seen.add(t)
116
+ return uniq
117
+
118
+ def _extract_used_tables_from_plan(
119
+ self, plan_text: str, schema_preview: str
120
+ ) -> List[str]:
121
+ """Best-effort used table list from plan text by intersecting with schema table names."""
122
+ candidates = self._extract_schema_tables(schema_preview)
123
+ if not candidates or not plan_text:
124
+ return []
125
+ used: List[str] = []
126
+ for t in candidates:
127
+ if re.search(rf"\b{re.escape(t)}\b", plan_text, flags=re.IGNORECASE):
128
+ used.append(t)
129
+ return used
130
+
131
+ # ---------------------------------------------------------------------
132
+ # Cost estimation
133
+ # ---------------------------------------------------------------------
134
+ def _estimate_cost(self, usage: Any) -> float:
135
+ """Estimate cost based on token usage."""
136
+ if not usage:
137
+ return 0.0
138
+
139
+ pricing = {
140
+ "gpt-4": {"input": 0.03, "output": 0.06},
141
+ "gpt-4-turbo": {"input": 0.01, "output": 0.03},
142
+ "gpt-4o": {"input": 0.005, "output": 0.015},
143
+ "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
144
+ "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
145
+ }
146
 
147
+ model_pricing = pricing.get(self.model, pricing["gpt-4o-mini"])
148
+ input_cost = (usage.prompt_tokens / 1000) * model_pricing["input"]
149
+ output_cost = (usage.completion_tokens / 1000) * model_pricing["output"]
150
+ return input_cost + output_cost
151
 
152
+ # ---------------------------------------------------------------------
153
+ # LLMProvider API
154
+ # ---------------------------------------------------------------------
155
+ def plan(
156
+ self,
157
+ *,
158
+ user_query: str,
159
+ schema_preview: str,
160
+ constraints: List[str] | None = None,
161
+ ) -> Tuple[str, List[str], int, int, float]:
162
+ """Return (plan_text, used_tables, token_in, token_out, cost_usd)."""
163
  system_prompt = """You are a SQL query planning expert. Analyze the user's question and database schema to create a clear execution plan.
164
 
165
  Your plan should:
 
177
  Database Schema:
178
  {schema_preview}
179
 
180
+ Constraints:
181
+ {constraints or []}
182
+
183
  Create a step-by-step plan to answer this question with SQL."""
184
 
185
  completion = self._create_chat_completion(
 
194
  msg = completion.choices[0].message.content or ""
195
  usage = completion.usage
196
 
197
+ plan_text = msg.strip()
198
+ used_tables = self._extract_used_tables_from_plan(plan_text, schema_preview)
199
+
200
  if usage:
201
  prompt_tokens = usage.prompt_tokens
202
  completion_tokens = usage.completion_tokens
 
207
  "completion_tokens": completion_tokens,
208
  "cost_usd": cost,
209
  }
210
+ return (plan_text, used_tables, prompt_tokens, completion_tokens, cost)
211
+
212
+ self._last_usage = {
213
+ "kind": "plan",
214
+ "prompt_tokens": 0,
215
+ "completion_tokens": 0,
216
+ "cost_usd": 0.0,
217
+ }
218
+ return (plan_text, used_tables, 0, 0, 0.0)
219
 
220
  def generate_sql(
221
  self,
 
223
  user_query: str,
224
  schema_preview: str,
225
  plan_text: str,
226
+ constraints: List[str] | None = None,
227
+ clarify_answers: Dict[str, Any] | None = None,
228
  ) -> Tuple[str, str, int, int, float]:
229
+ """Return (sql, rationale, token_in, token_out, cost_usd)."""
230
+ system_prompt = """You are an expert SQL generator.
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  CRITICAL RULES:
233
  1. Write the SIMPLEST possible SQL that answers the question
 
260
  Query Plan:
261
  {plan_text}
262
 
263
+ Constraints:
264
+ {constraints or []}
265
+
266
  Remember: Generate the SIMPLEST possible SQL. Avoid table prefixes, aliases, and unnecessary clauses.
267
 
268
  Example of what we want:
 
289
  content = text.strip() if text else ""
290
  usage = completion.usage
291
 
 
292
  try:
293
  parsed = json.loads(content)
294
  except json.JSONDecodeError:
 
297
  if start != -1 and end != -1:
298
  try:
299
  parsed = json.loads(content[start : end + 1])
300
+ except Exception as e:
301
+ raise ValueError(f"Invalid LLM JSON output: {content[:200]}") from e
302
  else:
303
  raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
304
 
305
+ sql = str(parsed.get("sql") or "").strip()
306
+ rationale = str(parsed.get("rationale") or "")
307
 
 
308
  sql = self._simplify_sql(sql)
 
309
  if not sql:
310
  raise ValueError("LLM returned empty 'sql'")
311
 
312
+ used_tables = self._extract_tables_from_sql(sql)
313
  sql_length = len(sql)
314
+
315
  if usage:
316
  prompt_tokens = usage.prompt_tokens
317
  completion_tokens = usage.completion_tokens
 
322
  "completion_tokens": completion_tokens,
323
  "cost_usd": cost,
324
  "sql_length": sql_length,
325
+ "used_tables": used_tables,
326
  }
327
  return (sql, rationale, prompt_tokens, completion_tokens, cost)
328
+
329
+ self._last_usage = {
330
+ "kind": "generate",
331
+ "prompt_tokens": 0,
332
+ "completion_tokens": 0,
333
+ "cost_usd": 0.0,
334
+ "sql_length": sql_length,
335
+ "used_tables": used_tables,
336
+ }
337
+ return (sql, rationale, 0, 0, 0.0)
338
 
339
  def _simplify_sql(self, sql: str) -> str:
340
  """Post-process SQL to remove common unnecessary additions."""
341
  if not sql:
342
  return sql
343
 
 
344
  sql = sql.rstrip(";")
345
 
 
 
346
  if sql.lower().count(" from ") == 1 and " join " not in sql.lower():
347
  match = re.search(r"\bfrom\s+(\w+)", sql, re.IGNORECASE)
348
  if match:
349
  table = match.group(1)
350
  sql = re.sub(rf"\b{table}\.(\w+)\b", r"\1", sql)
351
 
 
352
  sql = re.sub(
353
  r"count\s*\(\s*distinct\s+\*\s*\)",
354
  "count(*)",
 
356
  flags=re.IGNORECASE,
357
  )
358
 
 
359
  sql = re.sub(
360
  r"\s+limit\s+(100|1000|10000)\b",
361
  "",
 
372
  error_msg: str,
373
  schema_preview: str,
374
  ) -> Tuple[str, int, int, float]:
375
+ """Return (patched_sql, token_in, token_out, cost_usd)."""
 
 
 
 
 
 
 
 
 
376
  system_prompt = """You are a SQL repair expert. Fix the given SQL query to resolve the error.
377
 
378
  IMPORTANT RULES:
 
409
  text = completion.choices[0].message.content
410
  fixed_sql = text.strip() if text else ""
411
 
 
412
  if fixed_sql.startswith("```sql"):
413
  fixed_sql = fixed_sql[6:]
414
  if fixed_sql.startswith("```"):
 
420
  fixed_sql = self._simplify_sql(fixed_sql)
421
 
422
  usage = completion.usage
 
423
  if usage:
424
  prompt_tokens = usage.prompt_tokens
425
  completion_tokens = usage.completion_tokens
 
432
  "sql_length": len(fixed_sql),
433
  }
434
  return (fixed_sql, prompt_tokens, completion_tokens, cost)
 
 
 
 
 
 
 
 
 
 
 
 
435
 
436
+ self._last_usage = {
437
+ "kind": "repair",
438
+ "prompt_tokens": 0,
439
+ "completion_tokens": 0,
440
+ "cost_usd": 0.0,
441
+ "sql_length": len(fixed_sql),
 
 
 
 
 
 
 
 
 
 
442
  }
443
+ return (fixed_sql, 0, 0, 0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
nl2sql/errors/codes.py CHANGED
@@ -14,6 +14,7 @@ class ErrorCode(str, Enum):
14
  # --- Executor / DB ---
15
  DB_LOCKED = "DB_LOCKED"
16
  DB_TIMEOUT = "DB_TIMEOUT"
 
17
 
18
  # --- LLM ---
19
  LLM_TIMEOUT = "LLM_TIMEOUT"
 
14
  # --- Executor / DB ---
15
  DB_LOCKED = "DB_LOCKED"
16
  DB_TIMEOUT = "DB_TIMEOUT"
17
+ LLM_FAILURE = "LLM_FAILURE"
18
 
19
  # --- LLM ---
20
  LLM_TIMEOUT = "LLM_TIMEOUT"
nl2sql/generator.py CHANGED
@@ -20,7 +20,9 @@ class Generator:
20
  user_query: str,
21
  schema_preview: str,
22
  plan_text: str,
 
23
  clarify_answers: Optional[Dict[str, Any]] = None,
 
24
  ) -> StageResult:
25
  t0 = time.perf_counter()
26
 
@@ -29,10 +31,11 @@ class Generator:
29
  user_query=user_query,
30
  schema_preview=schema_preview,
31
  plan_text=plan_text,
 
32
  clarify_answers=clarify_answers or {},
33
  )
34
  except Exception as e:
35
- # Provider/transport errors or unexpected runtime issues.
36
  return StageResult(
37
  ok=False,
38
  error=[f"Generator failed: {e}"],
@@ -40,18 +43,22 @@ class Generator:
40
  trace=None,
41
  )
42
 
43
- # Contract: expect a 5-tuple (sql, rationale, token_in, token_out, cost_usd)
44
- if not isinstance(res, tuple) or len(res) != 5:
45
  return StageResult(
46
  ok=False,
47
  error=[
48
- "Generator contract violation: expected 5-tuple (sql, rationale, t_in, t_out, cost)"
49
  ],
50
  error_code=ErrorCode.LLM_BAD_OUTPUT,
51
  trace=None,
52
  )
53
 
54
- sql, rationale, t_in, t_out, cost = res
 
 
 
 
 
55
 
56
  # Type/shape checks
57
  if not isinstance(sql, str) or not sql.strip():
@@ -73,18 +80,20 @@ class Generator:
73
 
74
  # Normalize rationale to a string
75
  rationale = rationale or ""
 
76
  trace = StageTrace(
77
  stage=self.name,
 
78
  duration_ms=(time.perf_counter() - t0) * 1000.0,
79
  token_in=t_in,
80
  token_out=t_out,
81
  cost_usd=cost,
82
- notes={"rationale_len": len(rationale)},
83
  )
84
 
85
  return StageResult(
86
  ok=True,
87
- data={"sql": sql, "rationale": rationale},
88
  trace=trace,
89
  error_code=None,
90
  retryable=None,
 
20
  user_query: str,
21
  schema_preview: str,
22
  plan_text: str,
23
+ constraints: Optional[list[str]] = None,
24
  clarify_answers: Optional[Dict[str, Any]] = None,
25
+ traces: Optional[list[dict]] = None,
26
  ) -> StageResult:
27
  t0 = time.perf_counter()
28
 
 
31
  user_query=user_query,
32
  schema_preview=schema_preview,
33
  plan_text=plan_text,
34
+ constraints=constraints or [],
35
  clarify_answers=clarify_answers or {},
36
  )
37
  except Exception as e:
38
+ # Provider/transport errors or unexpected runtime exceptions.
39
  return StageResult(
40
  ok=False,
41
  error=[f"Generator failed: {e}"],
 
43
  trace=None,
44
  )
45
 
46
+ if not isinstance(res, tuple) or len(res) not in (5, 6):
 
47
  return StageResult(
48
  ok=False,
49
  error=[
50
+ "Generator contract violation: expected 5/6-tuple (sql, rationale, [used_tables], t_in, t_out, cost)"
51
  ],
52
  error_code=ErrorCode.LLM_BAD_OUTPUT,
53
  trace=None,
54
  )
55
 
56
+ used_tables: list[str] = []
57
+
58
+ if len(res) == 6:
59
+ sql, rationale, used_tables, t_in, t_out, cost = res
60
+ else:
61
+ sql, rationale, t_in, t_out, cost = res
62
 
63
  # Type/shape checks
64
  if not isinstance(sql, str) or not sql.strip():
 
80
 
81
  # Normalize rationale to a string
82
  rationale = rationale or ""
83
+
84
  trace = StageTrace(
85
  stage=self.name,
86
+ summary="Generated SQL",
87
  duration_ms=(time.perf_counter() - t0) * 1000.0,
88
  token_in=t_in,
89
  token_out=t_out,
90
  cost_usd=cost,
91
+ notes={"rationale_len": len(rationale), "used_tables": used_tables},
92
  )
93
 
94
  return StageResult(
95
  ok=True,
96
+ data={"sql": sql, "rationale": rationale, "used_tables": used_tables},
97
  trace=trace,
98
  error_code=None,
99
  retryable=None,
nl2sql/pipeline.py CHANGED
@@ -276,6 +276,17 @@ class Pipeline:
276
  details: List[str] = []
277
  exec_result: Dict[str, Any] = {}
278
 
 
 
 
 
 
 
 
 
 
 
 
279
  def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
280
  traces.append(
281
  self._mk_trace(
@@ -411,6 +422,33 @@ class Pipeline:
411
  sql = (r_gen.data or {}).get("sql")
412
  rationale = (r_gen.data or {}).get("rationale")
413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  # Guard: empty SQL
415
  if not sql or not str(sql).strip():
416
  pipeline_runs_total.labels(status="error").inc()
@@ -485,44 +523,39 @@ class Pipeline:
485
  if r_exec.ok and isinstance(r_exec.data, dict):
486
  exec_result = dict(r_exec.data)
487
 
488
- # --- 6) verifier (run with repair for consistency) ---
489
- t0 = time.perf_counter()
490
- r_ver = self._run_with_repair(
491
- "verifier",
492
- self._call_verifier,
493
- repair_input_builder=self._sql_repair_input_builder,
494
- max_attempts=1,
495
- sql=sql,
496
- exec_result=(r_exec.data or {}),
497
- traces=traces,
498
- )
499
- dt = (time.perf_counter() - t0) * 1000.0
500
- stage_duration_ms.labels("verifier").observe(dt)
 
 
 
 
501
 
502
- # Traces
503
- traces.extend(self._trace_list(r_ver))
504
- if not getattr(r_ver, "trace", None):
505
- _fallback_trace("verifier", dt, r_ver.ok)
 
506
 
507
- # If verifier (or its repair) produced a new SQL, consume it
508
- if r_ver.data and isinstance(r_ver.data, dict):
509
- repaired_sql = r_ver.data.get("sql")
510
- if repaired_sql:
511
- sql = repaired_sql
512
 
513
  # Verified flag
514
- verified = (
515
- bool(
516
- r_ver.data
517
- and isinstance(r_ver.data, dict)
518
- and r_ver.data.get("verified")
519
- )
520
- or r_ver.ok
521
- )
522
 
523
  # consume repaired SQL from verifier if any
524
- if r_ver.data and "sql" in r_ver.data and r_ver.data["sql"]:
525
- sql = r_ver.data["sql"]
 
526
 
527
  # --- 7) repair loop (if not verified) ---
528
  if not verified:
@@ -534,11 +567,12 @@ class Pipeline:
534
  self.repair.run,
535
  sql=sql,
536
  error_msg="; ".join(details or ["unknown"]),
537
- schema_preview=schema_preview,
538
  )
539
  dt = (time.perf_counter() - t0) * 1000.0
540
  stage_duration_ms.labels("repair").observe(dt)
541
  traces.extend(self._trace_list(r_fix))
 
542
  if not getattr(r_fix, "trace", None):
543
  _fallback_trace("repair", dt, r_fix.ok)
544
  if not r_fix.ok:
@@ -553,6 +587,7 @@ class Pipeline:
553
  dt2 = (time.perf_counter() - t0) * 1000.0
554
  stage_duration_ms.labels("safety").observe(dt2)
555
  traces.extend(self._trace_list(r_safe2))
 
556
  if not getattr(r_safe2, "trace", None):
557
  _fallback_trace("safety", dt2, r_safe2.ok)
558
  if not r_safe2.ok:
@@ -567,6 +602,7 @@ class Pipeline:
567
  dt2 = (time.perf_counter() - t0) * 1000.0
568
  stage_duration_ms.labels("executor").observe(dt2)
569
  traces.extend(self._trace_list(r_exec2))
 
570
  if not getattr(r_exec2, "trace", None):
571
  _fallback_trace("executor", dt2, r_exec2.ok)
572
  if not r_exec2.ok:
@@ -586,11 +622,10 @@ class Pipeline:
586
  dt2 = (time.perf_counter() - t0) * 1000.0
587
  stage_duration_ms.labels("verifier").observe(dt2)
588
  traces.extend(self._trace_list(r_ver2))
 
589
  if not getattr(r_ver2, "trace", None):
590
  _fallback_trace("verifier", dt2, r_ver2.ok)
591
- verified = (
592
- bool(r_ver2.data and r_ver2.data.get("verified")) or r_ver2.ok
593
- )
594
  if r_ver2.data and "sql" in r_ver2.data and r_ver2.data["sql"]:
595
  sql = r_ver2.data["sql"]
596
  if verified:
 
276
  details: List[str] = []
277
  exec_result: Dict[str, Any] = {}
278
 
279
+ def _tag_last_trace_attempt(stage_name: str, attempt: int) -> None:
280
+ # Attach attempt metadata to the most recent trace entry for this stage.
281
+ for t in reversed(traces):
282
+ if t.get("stage") == stage_name:
283
+ notes = t.get("notes") or {}
284
+ if not isinstance(notes, dict):
285
+ notes = {}
286
+ notes["attempt"] = attempt
287
+ t["notes"] = notes
288
+ return
289
+
290
  def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
291
  traces.append(
292
  self._mk_trace(
 
422
  sql = (r_gen.data or {}).get("sql")
423
  rationale = (r_gen.data or {}).get("rationale")
424
 
425
+ # --- schema drift signal (planner vs generator table usage)
426
+ planner_used_tables = (
427
+ (r_plan.data or {}).get("used_tables")
428
+ or (r_plan.data or {}).get("tables")
429
+ or []
430
+ )
431
+ generator_used_tables = (
432
+ (r_gen.data or {}).get("used_tables")
433
+ or (r_gen.data or {}).get("tables")
434
+ or []
435
+ )
436
+ planner_set = set(planner_used_tables)
437
+ generator_set = set(generator_used_tables)
438
+ schema_drift = bool(generator_set - planner_set)
439
+ traces.append(
440
+ self._mk_trace(
441
+ stage="schema_drift_check",
442
+ duration_ms=0.0,
443
+ summary="compare planner vs generator table usage",
444
+ notes={
445
+ "planner_used_tables": sorted(planner_set),
446
+ "generator_used_tables": sorted(generator_set),
447
+ "schema_drift": schema_drift,
448
+ },
449
+ )
450
+ )
451
+
452
  # Guard: empty SQL
453
  if not sql or not str(sql).strip():
454
  pipeline_runs_total.labels(status="error").inc()
 
523
  if r_exec.ok and isinstance(r_exec.data, dict):
524
  exec_result = dict(r_exec.data)
525
 
526
+ # --- 6) verifier (only if execution succeeded) ---
527
+ r_ver = None
528
+ if r_exec.ok:
529
+ t0 = time.perf_counter()
530
+ r_ver = self._run_with_repair(
531
+ "verifier",
532
+ self._call_verifier,
533
+ repair_input_builder=self._sql_repair_input_builder,
534
+ max_attempts=1,
535
+ sql=sql,
536
+ exec_result=(r_exec.data or {}),
537
+ traces=traces,
538
+ )
539
+ dt = (time.perf_counter() - t0) * 1000.0
540
+ stage_duration_ms.labels("verifier").observe(dt)
541
+
542
+ # Traces
543
 
544
+ # If verifier (or its repair) produced a new SQL, consume it
545
+ if r_ver.data and isinstance(r_ver.data, dict):
546
+ repaired_sql = r_ver.data.get("sql")
547
+ if repaired_sql:
548
+ sql = repaired_sql
549
 
550
+ data = r_ver.data if (r_ver and isinstance(r_ver.data, dict)) else {}
 
 
 
 
551
 
552
  # Verified flag
553
+ verified = bool(data.get("verified") is True)
 
 
 
 
 
 
 
554
 
555
  # consume repaired SQL from verifier if any
556
+ repaired_sql = data.get("sql")
557
+ if repaired_sql:
558
+ sql = repaired_sql
559
 
560
  # --- 7) repair loop (if not verified) ---
561
  if not verified:
 
567
  self.repair.run,
568
  sql=sql,
569
  error_msg="; ".join(details or ["unknown"]),
570
+ schema_preview=schema_for_llm,
571
  )
572
  dt = (time.perf_counter() - t0) * 1000.0
573
  stage_duration_ms.labels("repair").observe(dt)
574
  traces.extend(self._trace_list(r_fix))
575
+ _tag_last_trace_attempt("repair", _attempt)
576
  if not getattr(r_fix, "trace", None):
577
  _fallback_trace("repair", dt, r_fix.ok)
578
  if not r_fix.ok:
 
587
  dt2 = (time.perf_counter() - t0) * 1000.0
588
  stage_duration_ms.labels("safety").observe(dt2)
589
  traces.extend(self._trace_list(r_safe2))
590
+ _tag_last_trace_attempt("safety", _attempt)
591
  if not getattr(r_safe2, "trace", None):
592
  _fallback_trace("safety", dt2, r_safe2.ok)
593
  if not r_safe2.ok:
 
602
  dt2 = (time.perf_counter() - t0) * 1000.0
603
  stage_duration_ms.labels("executor").observe(dt2)
604
  traces.extend(self._trace_list(r_exec2))
605
+ _tag_last_trace_attempt("executor", _attempt)
606
  if not getattr(r_exec2, "trace", None):
607
  _fallback_trace("executor", dt2, r_exec2.ok)
608
  if not r_exec2.ok:
 
622
  dt2 = (time.perf_counter() - t0) * 1000.0
623
  stage_duration_ms.labels("verifier").observe(dt2)
624
  traces.extend(self._trace_list(r_ver2))
625
+ _tag_last_trace_attempt("verifier", _attempt)
626
  if not getattr(r_ver2, "trace", None):
627
  _fallback_trace("verifier", dt2, r_ver2.ok)
628
+ verified = bool(r_ver2.data and r_ver2.data.get("verified") is True)
 
 
629
  if r_ver2.data and "sql" in r_ver2.data and r_ver2.data["sql"]:
630
  sql = r_ver2.data["sql"]
631
  if verified:
nl2sql/planner.py CHANGED
@@ -6,6 +6,23 @@ from typing import Any, Dict, List, Tuple, Optional
6
  __all__ = ["Planner"]
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # --------- Heuristic schema trimming (safe, mypy-clean) ---------
10
  def _tokenize_lower(s: str) -> List[str]:
11
  return re.findall(r"[a-z_]+", (s or "").lower())
@@ -14,41 +31,33 @@ def _tokenize_lower(s: str) -> List[str]:
14
  def _table_blocks(schema_text: str) -> List[Tuple[str, List[str]]]:
15
  """
16
  Parse plain-text schema into [(table_name, lines)] blocks,
17
- supporting both 'Table: name' and 'CREATE TABLE name (' styles.
 
 
 
18
  """
19
  blocks: List[Tuple[str, List[str]]] = []
20
  cur_name: Optional[str] = None
21
  cur_lines: List[str] = []
22
 
23
- def _flush() -> None:
24
  nonlocal cur_name, cur_lines
25
- if cur_name is not None and cur_lines:
26
- blocks.append((cur_name, cur_lines[:]))
27
  cur_name, cur_lines = None, []
28
 
29
- for line in (schema_text or "").splitlines():
30
- m = re.search(r"Table:\s*(\w+)", line, flags=re.IGNORECASE)
31
- m2 = re.search(r"CREATE\s+TABLE\s+(\w+)\s*\(", line, flags=re.IGNORECASE)
32
-
33
- started = False
34
- name: Optional[str] = None
35
- if m is not None:
36
- name = m.group(1)
37
- started = True
38
- elif m2 is not None:
39
- name = m2.group(1)
40
- started = True
41
-
42
- if started and name:
43
  _flush()
44
- cur_name = name
45
- cur_lines.append(line)
46
  else:
47
  if cur_name is not None:
48
- cur_lines.append(line)
49
-
50
- if cur_name is not None and line.strip().endswith(");"):
51
- _flush()
52
 
53
  _flush()
54
  return blocks
@@ -64,29 +73,22 @@ def _pick_relevant_tables(schema_text: str, question: str, k: int = 3) -> str:
64
  q_toks = set(_tokenize_lower(question))
65
  scored: List[Tuple[int, str, List[str]]] = []
66
  for name, lines in blocks:
67
- score = sum(1 for w in _tokenize_lower(name) if w in q_toks)
68
- cols_line = " ".join(lines)
69
- cols = re.findall(r"\b([A-Za-z_]\w*)\b", cols_line)
70
- score += min(2, sum(1 for c in cols if c.lower() in q_toks))
71
  scored.append((score, name, lines))
72
 
73
- scored.sort(key=lambda t: t[0], reverse=True)
74
- keep = [b for b in scored[: max(1, k)] if b[0] > 0]
75
- if not keep:
76
- keep = scored[: max(1, k)]
77
-
78
  out_lines: List[str] = []
79
- for _, _, lines in keep:
80
  out_lines.extend(lines)
81
- if lines and lines[-1].strip() != "":
82
- out_lines.append("")
83
- trimmed = "\n".join(out_lines).strip()
84
- return trimmed if trimmed else schema_text
85
  except Exception:
86
  return schema_text
87
 
88
 
89
- # ------------------------------ Planner ------------------------------
90
  class Planner:
91
  """Planner wrapper around the LLM provider."""
92
 
@@ -95,26 +97,65 @@ class Planner:
95
  # ensure model_id is always a str (for mypy)
96
  self.model_id: str = str(model_id or getattr(llm, "model", "unknown"))
97
  # in-memory cache: (model, hash(q), hash(trimmed)) → (plan, pin, pout, cost)
98
- self._plan_cache: dict[tuple[str, int, int], tuple[str, int, int, float]] = {}
99
-
100
- def run(self, *, user_query: str, schema_preview: str) -> Dict[str, Any]:
101
- trimmed = _pick_relevant_tables(schema_preview or "", user_query or "", k=3)
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  key: tuple[str, int, int] = (
104
  self.model_id,
105
  hash(user_query or ""),
106
- hash(trimmed),
107
  )
 
108
  if key in self._plan_cache:
109
- plan_text, pin, pout, cost = self._plan_cache[key]
110
  else:
111
- plan_text, pin, pout, cost = self.llm.plan(
112
- user_query=user_query, schema_preview=trimmed
113
- )
114
- self._plan_cache[key] = (plan_text, pin, pout, cost)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  return {
117
  "plan": plan_text,
 
118
  "usage": {
119
  "prompt_tokens": pin,
120
  "completion_tokens": pout,
 
6
  __all__ = ["Planner"]
7
 
8
 
9
+ def _extract_table_names_from_schema(schema_text: str) -> List[str]:
10
+ """Best-effort table name extraction from schema preview."""
11
+ if not schema_text:
12
+ return []
13
+ names = re.findall(
14
+ r"(?im)^\s*create\s+table\s+`?([A-Za-z_][A-Za-z0-9_]*)`?\b", schema_text
15
+ )
16
+ # de-dup preserving order
17
+ seen: set[str] = set()
18
+ out: List[str] = []
19
+ for n in names:
20
+ if n not in seen:
21
+ out.append(n)
22
+ seen.add(n)
23
+ return out
24
+
25
+
26
  # --------- Heuristic schema trimming (safe, mypy-clean) ---------
27
  def _tokenize_lower(s: str) -> List[str]:
28
  return re.findall(r"[a-z_]+", (s or "").lower())
 
31
  def _table_blocks(schema_text: str) -> List[Tuple[str, List[str]]]:
32
  """
33
  Parse plain-text schema into [(table_name, lines)] blocks,
34
+ assuming SQLite preview format like:
35
+ Table: users
36
+ - id
37
+ - name
38
  """
39
  blocks: List[Tuple[str, List[str]]] = []
40
  cur_name: Optional[str] = None
41
  cur_lines: List[str] = []
42
 
43
+ def _flush():
44
  nonlocal cur_name, cur_lines
45
+ if cur_name is not None:
46
+ blocks.append((cur_name, cur_lines))
47
  cur_name, cur_lines = None, []
48
 
49
+ for raw in (schema_text or "").splitlines():
50
+ line = raw.strip()
51
+ if not line:
52
+ continue
53
+ m = re.match(r"^table:\s*([a-zA-Z0-9_]+)\s*$", line, re.IGNORECASE)
54
+ if m:
 
 
 
 
 
 
 
 
55
  _flush()
56
+ cur_name = m.group(1)
57
+ cur_lines = [raw]
58
  else:
59
  if cur_name is not None:
60
+ cur_lines.append(raw)
 
 
 
61
 
62
  _flush()
63
  return blocks
 
73
  q_toks = set(_tokenize_lower(question))
74
  scored: List[Tuple[int, str, List[str]]] = []
75
  for name, lines in blocks:
76
+ score = sum(1 for tok in _tokenize_lower(" ".join(lines)) if tok in q_toks)
 
 
 
77
  scored.append((score, name, lines))
78
 
79
+ scored.sort(key=lambda x: (-x[0], x[1]))
80
+ top = scored[:k]
81
+ # Keep stable order by original appearance? We'll keep by score then name for determinism.
 
 
82
  out_lines: List[str] = []
83
+ for _, _, lines in top:
84
  out_lines.extend(lines)
85
+ out_lines.append("") # spacing
86
+
87
+ return "\n".join(out_lines).strip() if out_lines else schema_text
 
88
  except Exception:
89
  return schema_text
90
 
91
 
 
92
  class Planner:
93
  """Planner wrapper around the LLM provider."""
94
 
 
97
  # ensure model_id is always a str (for mypy)
98
  self.model_id: str = str(model_id or getattr(llm, "model", "unknown"))
99
  # in-memory cache: (model, hash(q), hash(schema_preview)) → (plan, used_tables, pin, pout, cost)
100
+ self._plan_cache: dict[
101
+ tuple[str, int, int], tuple[str, List[str], int, int, float]
102
+ ] = {}
103
+
104
+ def run(
105
+ self,
106
+ *,
107
+ user_query: str,
108
+ schema_preview: str,
109
+ constraints: Optional[List[str]] = None,
110
+ traces: Optional[List[dict]] = None,
111
+ ) -> Dict[str, Any]:
112
+ """Plan the query. Assumes schema_preview is already budgeted upstream."""
113
+ schema_preview = schema_preview or ""
114
+ constraints = constraints or []
115
 
116
  key: tuple[str, int, int] = (
117
  self.model_id,
118
  hash(user_query or ""),
119
+ hash(schema_preview),
120
  )
121
+
122
  if key in self._plan_cache:
123
+ plan_text, used_tables, pin, pout, cost = self._plan_cache[key]
124
  else:
125
+ # Call provider with backward-compatible kwargs
126
+ try:
127
+ res = self.llm.plan(
128
+ user_query=user_query,
129
+ schema_preview=schema_preview,
130
+ constraints=constraints,
131
+ )
132
+ except TypeError:
133
+ # Older fakes/providers may not accept `constraints`
134
+ res = self.llm.plan(
135
+ user_query=user_query,
136
+ schema_preview=schema_preview,
137
+ )
138
+
139
+ if not isinstance(res, tuple):
140
+ raise TypeError("LLM plan() must return a tuple")
141
+
142
+ if len(res) == 5:
143
+ plan_text, used_tables, pin, pout, cost = res
144
+ elif len(res) == 4:
145
+ plan_text, pin, pout, cost = res
146
+ used_tables = _extract_table_names_from_schema(schema_preview)
147
+ else:
148
+ raise TypeError("LLM plan() must return 4- or 5-tuple")
149
+
150
+ # Ensure used_tables is always a list[str]
151
+ if not isinstance(used_tables, list):
152
+ used_tables = _extract_table_names_from_schema(schema_preview)
153
+
154
+ self._plan_cache[key] = (plan_text, used_tables, pin, pout, cost)
155
 
156
  return {
157
  "plan": plan_text,
158
+ "used_tables": used_tables,
159
  "usage": {
160
  "prompt_tokens": pin,
161
  "completion_tokens": pout,
nl2sql/prompts/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prompt contracts for LLM-facing stages."""
2
+
3
+ from .contracts import (
4
+ PlannerPromptInput,
5
+ PlannerPromptOutput,
6
+ GeneratorPromptInput,
7
+ GeneratorPromptOutput,
8
+ )
9
+
10
+ __all__ = [
11
+ "PlannerPromptInput",
12
+ "PlannerPromptOutput",
13
+ "GeneratorPromptInput",
14
+ "GeneratorPromptOutput",
15
+ ]
nl2sql/prompts/contracts.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional
5
+
6
+
7
+ # NOTE:
8
+ # These are *prompt contracts* (input/output shapes) for LLM-facing stages.
9
+ # They are intentionally lightweight to keep Block C minimal and low-risk.
10
+
11
+
12
@dataclass(frozen=True)
class PlannerPromptInput:
    """Input shape of the planner prompt contract (frozen dataclass).

    NOTE(review): presumably mirrors the keyword arguments passed to the
    planner stage — confirm against the caller.
    """

    user_query: str
    schema_preview: str  # already budgeted at pipeline boundary
    constraints: List[str]  # required; no default is provided
17
+
18
+
19
@dataclass(frozen=True)
class PlannerPromptOutput:
    """Output shape of the planner prompt contract (frozen dataclass)."""

    plan: str
    used_tables: List[str]
23
+
24
+
25
@dataclass(frozen=True)
class GeneratorPromptInput:
    """Input shape of the SQL generator prompt contract (frozen dataclass).

    ``clarify_answers`` is the only optional field and defaults to ``None``.
    """

    user_query: str
    schema_preview: str  # already budgeted at pipeline boundary
    plan: str
    constraints: List[str]
    clarify_answers: Optional[Dict[str, Any]] = None
32
+
33
+
34
@dataclass(frozen=True)
class GeneratorPromptOutput:
    """Output shape of the SQL generator prompt contract (frozen dataclass)."""

    sql: str
    rationale: str
    used_tables: List[str]