Spaces:
Sleeping
Sleeping
github-actions[bot]
commited on
Commit
·
4e73462
1
Parent(s):
8e8639a
Sync from GitHub main @ 517739c210f47d8dcf880b0b6b7501a464d6ef4f
Browse files- adapters/llm/base.py +10 -4
- adapters/llm/openai_provider.py +159 -160
- nl2sql/errors/codes.py +1 -0
- nl2sql/generator.py +16 -7
- nl2sql/pipeline.py +71 -36
- nl2sql/planner.py +90 -49
- nl2sql/prompts/__init__.py +15 -0
- nl2sql/prompts/contracts.py +38 -0
adapters/llm/base.py
CHANGED
|
@@ -1,14 +1,19 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
-
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
class LLMProvider(Protocol):
|
| 6 |
PROVIDER_ID: str
|
| 7 |
|
| 8 |
def plan(
|
| 9 |
-
self,
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def generate_sql(
|
| 14 |
self,
|
|
@@ -16,6 +21,7 @@ class LLMProvider(Protocol):
|
|
| 16 |
user_query: str,
|
| 17 |
schema_preview: str,
|
| 18 |
plan_text: str,
|
|
|
|
| 19 |
clarify_answers: Dict[str, Any] | None = None,
|
| 20 |
) -> Tuple[str, str, int, int, float]:
|
| 21 |
"""Return (sql, rationale, token_in, token_out, cost_usd)."""
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List, Protocol, Tuple
|
| 4 |
|
| 5 |
|
| 6 |
class LLMProvider(Protocol):
|
| 7 |
PROVIDER_ID: str
|
| 8 |
|
| 9 |
def plan(
|
| 10 |
+
self,
|
| 11 |
+
*,
|
| 12 |
+
user_query: str,
|
| 13 |
+
schema_preview: str,
|
| 14 |
+
constraints: List[str] | None = None,
|
| 15 |
+
) -> Tuple[str, List[str], int, int, float]:
|
| 16 |
+
"""Return (plan_text, used_tables, token_in, token_out, cost_usd)."""
|
| 17 |
|
| 18 |
def generate_sql(
|
| 19 |
self,
|
|
|
|
| 21 |
user_query: str,
|
| 22 |
schema_preview: str,
|
| 23 |
plan_text: str,
|
| 24 |
+
constraints: List[str] | None = None,
|
| 25 |
clarify_answers: Dict[str, Any] | None = None,
|
| 26 |
) -> Tuple[str, str, int, int, float]:
|
| 27 |
"""Return (sql, rationale, token_in, token_out, cost_usd)."""
|
adapters/llm/openai_provider.py
CHANGED
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
| 3 |
import json
|
| 4 |
import os
|
| 5 |
import re
|
| 6 |
-
from typing import Any, List, Tuple
|
| 7 |
|
| 8 |
from adapters.llm.base import LLMProvider
|
| 9 |
from openai import OpenAI
|
|
@@ -35,17 +35,15 @@ def _resolve_api_config() -> tuple[str, str, str]:
|
|
| 35 |
|
| 36 |
|
| 37 |
class OpenAIProvider(LLMProvider):
|
| 38 |
-
"""OpenAI LLM provider implementation.
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
|
| 46 |
-
|
| 47 |
-
"""OpenAI SDK seam for stable unit testing."""
|
| 48 |
-
return self.client.chat.completions.create(**kwargs)
|
| 49 |
|
| 50 |
def __init__(self) -> None:
|
| 51 |
"""Initialize OpenAI client with config from environment."""
|
|
@@ -54,21 +52,114 @@ class OpenAIProvider(LLMProvider):
|
|
| 54 |
os.environ["OPENAI_BASE_URL"] = base_url
|
| 55 |
self.client = OpenAI(timeout=120.0)
|
| 56 |
self.model = model
|
| 57 |
-
# last call usage/metadata for tracing
|
| 58 |
self._last_usage: dict[str, Any] = {}
|
| 59 |
|
| 60 |
-
def
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
system_prompt = """You are a SQL query planning expert. Analyze the user's question and database schema to create a clear execution plan.
|
| 73 |
|
| 74 |
Your plan should:
|
|
@@ -86,6 +177,9 @@ Be concise but thorough."""
|
|
| 86 |
Database Schema:
|
| 87 |
{schema_preview}
|
| 88 |
|
|
|
|
|
|
|
|
|
|
| 89 |
Create a step-by-step plan to answer this question with SQL."""
|
| 90 |
|
| 91 |
completion = self._create_chat_completion(
|
|
@@ -100,6 +194,9 @@ Create a step-by-step plan to answer this question with SQL."""
|
|
| 100 |
msg = completion.choices[0].message.content or ""
|
| 101 |
usage = completion.usage
|
| 102 |
|
|
|
|
|
|
|
|
|
|
| 103 |
if usage:
|
| 104 |
prompt_tokens = usage.prompt_tokens
|
| 105 |
completion_tokens = usage.completion_tokens
|
|
@@ -110,15 +207,15 @@ Create a step-by-step plan to answer this question with SQL."""
|
|
| 110 |
"completion_tokens": completion_tokens,
|
| 111 |
"cost_usd": cost,
|
| 112 |
}
|
| 113 |
-
return (
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
|
| 123 |
def generate_sql(
|
| 124 |
self,
|
|
@@ -126,21 +223,11 @@ Create a step-by-step plan to answer this question with SQL."""
|
|
| 126 |
user_query: str,
|
| 127 |
schema_preview: str,
|
| 128 |
plan_text: str,
|
| 129 |
-
|
|
|
|
| 130 |
) -> Tuple[str, str, int, int, float]:
|
| 131 |
-
"""
|
| 132 |
-
|
| 133 |
-
Args:
|
| 134 |
-
user_query: The user's natural language question
|
| 135 |
-
schema_preview: Database schema information
|
| 136 |
-
plan_text: Query execution plan
|
| 137 |
-
clarify_answers: Optional additional context_engineering
|
| 138 |
-
|
| 139 |
-
Returns:
|
| 140 |
-
Tuple of (sql, rationale, prompt_tokens, completion_tokens, cost)
|
| 141 |
-
"""
|
| 142 |
-
system_prompt = """You are an expert SQL query generator for SQLite databases.
|
| 143 |
-
You must follow these STRICT rules to generate clean, simple SQL:
|
| 144 |
|
| 145 |
CRITICAL RULES:
|
| 146 |
1. Write the SIMPLEST possible SQL that answers the question
|
|
@@ -173,6 +260,9 @@ Database Schema:
|
|
| 173 |
Query Plan:
|
| 174 |
{plan_text}
|
| 175 |
|
|
|
|
|
|
|
|
|
|
| 176 |
Remember: Generate the SIMPLEST possible SQL. Avoid table prefixes, aliases, and unnecessary clauses.
|
| 177 |
|
| 178 |
Example of what we want:
|
|
@@ -199,7 +289,6 @@ Now generate the SQL for the given question:"""
|
|
| 199 |
content = text.strip() if text else ""
|
| 200 |
usage = completion.usage
|
| 201 |
|
| 202 |
-
# Parse JSON response
|
| 203 |
try:
|
| 204 |
parsed = json.loads(content)
|
| 205 |
except json.JSONDecodeError:
|
|
@@ -208,21 +297,21 @@ Now generate the SQL for the given question:"""
|
|
| 208 |
if start != -1 and end != -1:
|
| 209 |
try:
|
| 210 |
parsed = json.loads(content[start : end + 1])
|
| 211 |
-
except Exception:
|
| 212 |
-
raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
|
| 213 |
else:
|
| 214 |
raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
|
| 215 |
|
| 216 |
-
sql = (parsed.get("sql") or "").strip()
|
| 217 |
-
rationale = parsed.get("rationale") or ""
|
| 218 |
|
| 219 |
-
# Post-process SQL to ensure simplicity
|
| 220 |
sql = self._simplify_sql(sql)
|
| 221 |
-
|
| 222 |
if not sql:
|
| 223 |
raise ValueError("LLM returned empty 'sql'")
|
| 224 |
|
|
|
|
| 225 |
sql_length = len(sql)
|
|
|
|
| 226 |
if usage:
|
| 227 |
prompt_tokens = usage.prompt_tokens
|
| 228 |
completion_tokens = usage.completion_tokens
|
|
@@ -233,35 +322,33 @@ Now generate the SQL for the given question:"""
|
|
| 233 |
"completion_tokens": completion_tokens,
|
| 234 |
"cost_usd": cost,
|
| 235 |
"sql_length": sql_length,
|
|
|
|
| 236 |
}
|
| 237 |
return (sql, rationale, prompt_tokens, completion_tokens, cost)
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
|
|
|
| 247 |
|
| 248 |
def _simplify_sql(self, sql: str) -> str:
|
| 249 |
"""Post-process SQL to remove common unnecessary additions."""
|
| 250 |
if not sql:
|
| 251 |
return sql
|
| 252 |
|
| 253 |
-
# Remove trailing semicolon
|
| 254 |
sql = sql.rstrip(";")
|
| 255 |
|
| 256 |
-
# Remove unnecessary table prefixes in simple queries
|
| 257 |
-
# e.g., "singer.name" -> "name" when there's only one table
|
| 258 |
if sql.lower().count(" from ") == 1 and " join " not in sql.lower():
|
| 259 |
match = re.search(r"\bfrom\s+(\w+)", sql, re.IGNORECASE)
|
| 260 |
if match:
|
| 261 |
table = match.group(1)
|
| 262 |
sql = re.sub(rf"\b{table}\.(\w+)\b", r"\1", sql)
|
| 263 |
|
| 264 |
-
# Remove unnecessary DISTINCT in COUNT(*)
|
| 265 |
sql = re.sub(
|
| 266 |
r"count\s*\(\s*distinct\s+\*\s*\)",
|
| 267 |
"count(*)",
|
|
@@ -269,7 +356,6 @@ Now generate the SQL for the given question:"""
|
|
| 269 |
flags=re.IGNORECASE,
|
| 270 |
)
|
| 271 |
|
| 272 |
-
# Remove big default LIMITs that weren't requested
|
| 273 |
sql = re.sub(
|
| 274 |
r"\s+limit\s+(100|1000|10000)\b",
|
| 275 |
"",
|
|
@@ -286,16 +372,7 @@ Now generate the SQL for the given question:"""
|
|
| 286 |
error_msg: str,
|
| 287 |
schema_preview: str,
|
| 288 |
) -> Tuple[str, int, int, float]:
|
| 289 |
-
"""
|
| 290 |
-
|
| 291 |
-
Args:
|
| 292 |
-
sql: Broken SQL query
|
| 293 |
-
error_msg: Error message from execution
|
| 294 |
-
schema_preview: Database schema information
|
| 295 |
-
|
| 296 |
-
Returns:
|
| 297 |
-
Tuple of (fixed_sql, prompt_tokens, completion_tokens, cost)
|
| 298 |
-
"""
|
| 299 |
system_prompt = """You are a SQL repair expert. Fix the given SQL query to resolve the error.
|
| 300 |
|
| 301 |
IMPORTANT RULES:
|
|
@@ -332,7 +409,6 @@ Return the corrected SQL (keep it simple):"""
|
|
| 332 |
text = completion.choices[0].message.content
|
| 333 |
fixed_sql = text.strip() if text else ""
|
| 334 |
|
| 335 |
-
# Clean up accidental code fences
|
| 336 |
if fixed_sql.startswith("```sql"):
|
| 337 |
fixed_sql = fixed_sql[6:]
|
| 338 |
if fixed_sql.startswith("```"):
|
|
@@ -344,7 +420,6 @@ Return the corrected SQL (keep it simple):"""
|
|
| 344 |
fixed_sql = self._simplify_sql(fixed_sql)
|
| 345 |
|
| 346 |
usage = completion.usage
|
| 347 |
-
|
| 348 |
if usage:
|
| 349 |
prompt_tokens = usage.prompt_tokens
|
| 350 |
completion_tokens = usage.completion_tokens
|
|
@@ -357,88 +432,12 @@ Return the corrected SQL (keep it simple):"""
|
|
| 357 |
"sql_length": len(fixed_sql),
|
| 358 |
}
|
| 359 |
return (fixed_sql, prompt_tokens, completion_tokens, cost)
|
| 360 |
-
else:
|
| 361 |
-
self._last_usage = {
|
| 362 |
-
"kind": "repair",
|
| 363 |
-
"prompt_tokens": 0,
|
| 364 |
-
"completion_tokens": 0,
|
| 365 |
-
"cost_usd": 0.0,
|
| 366 |
-
"sql_length": len(fixed_sql),
|
| 367 |
-
}
|
| 368 |
-
return (fixed_sql, 0, 0, 0.0)
|
| 369 |
-
|
| 370 |
-
def _estimate_cost(self, usage: Any) -> float:
|
| 371 |
-
"""Estimate cost based on token usage.
|
| 372 |
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
if not usage:
|
| 380 |
-
return 0.0
|
| 381 |
-
|
| 382 |
-
# Pricing per 1K tokens (adjust based on model)
|
| 383 |
-
pricing = {
|
| 384 |
-
"gpt-4": {"input": 0.03, "output": 0.06},
|
| 385 |
-
"gpt-4-turbo": {"input": 0.01, "output": 0.03},
|
| 386 |
-
"gpt-4o": {"input": 0.005, "output": 0.015},
|
| 387 |
-
"gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
|
| 388 |
-
"gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
|
| 389 |
}
|
| 390 |
-
|
| 391 |
-
model_pricing = pricing.get(self.model, pricing["gpt-4o-mini"])
|
| 392 |
-
|
| 393 |
-
input_cost = (usage.prompt_tokens / 1000) * model_pricing["input"]
|
| 394 |
-
output_cost = (usage.completion_tokens / 1000) * model_pricing["output"]
|
| 395 |
-
|
| 396 |
-
return input_cost + output_cost
|
| 397 |
-
|
| 398 |
-
def clarify(
|
| 399 |
-
self,
|
| 400 |
-
*,
|
| 401 |
-
user_query: str,
|
| 402 |
-
schema_preview: str,
|
| 403 |
-
questions: List[str],
|
| 404 |
-
) -> Tuple[str, int, int, float]:
|
| 405 |
-
"""Clarify ambiguities in the user query.
|
| 406 |
-
|
| 407 |
-
Args:
|
| 408 |
-
user_query: The user's natural language question
|
| 409 |
-
schema_preview: Database schema information
|
| 410 |
-
questions: List of clarification questions
|
| 411 |
-
|
| 412 |
-
Returns:
|
| 413 |
-
Tuple of (answers, prompt_tokens, completion_tokens, cost)
|
| 414 |
-
"""
|
| 415 |
-
system_prompt = """You are a helpful assistant that clarifies SQL query requirements.
|
| 416 |
-
Answer the questions clearly and concisely based on the user's query and database schema."""
|
| 417 |
-
|
| 418 |
-
user_prompt = f"""User Query: {user_query}
|
| 419 |
-
|
| 420 |
-
Database Schema:
|
| 421 |
-
{schema_preview}
|
| 422 |
-
|
| 423 |
-
Please answer these clarification questions:
|
| 424 |
-
{chr(10).join(f"{i + 1}. {q}" for i, q in enumerate(questions))}"""
|
| 425 |
-
|
| 426 |
-
completion = self._create_chat_completion(
|
| 427 |
-
model=self.model,
|
| 428 |
-
messages=[
|
| 429 |
-
{"role": "system", "content": system_prompt},
|
| 430 |
-
{"role": "user", "content": user_prompt},
|
| 431 |
-
],
|
| 432 |
-
temperature=0.3,
|
| 433 |
-
)
|
| 434 |
-
|
| 435 |
-
answers = completion.choices[0].message.content or ""
|
| 436 |
-
usage = completion.usage
|
| 437 |
-
|
| 438 |
-
if usage:
|
| 439 |
-
prompt_tokens = usage.prompt_tokens
|
| 440 |
-
completion_tokens = usage.completion_tokens
|
| 441 |
-
cost = self._estimate_cost(usage)
|
| 442 |
-
return (answers, prompt_tokens, completion_tokens, cost)
|
| 443 |
-
else:
|
| 444 |
-
return (answers, 0, 0, 0.0)
|
|
|
|
| 3 |
import json
|
| 4 |
import os
|
| 5 |
import re
|
| 6 |
+
from typing import Any, Dict, List, Tuple
|
| 7 |
|
| 8 |
from adapters.llm.base import LLMProvider
|
| 9 |
from openai import OpenAI
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
class OpenAIProvider(LLMProvider):
|
| 38 |
+
"""OpenAI LLM provider implementation.
|
| 39 |
|
| 40 |
+
Goals for this implementation:
|
| 41 |
+
- Keep prompts and behavior as close as possible to the current repo version.
|
| 42 |
+
- Align method signatures + return shapes with the updated LLMProvider Protocol.
|
| 43 |
+
- Provide a lightweight `used_tables` signal for observability/drift checks.
|
| 44 |
+
"""
|
| 45 |
|
| 46 |
+
PROVIDER_ID = "openai"
|
|
|
|
|
|
|
| 47 |
|
| 48 |
def __init__(self) -> None:
|
| 49 |
"""Initialize OpenAI client with config from environment."""
|
|
|
|
| 52 |
os.environ["OPENAI_BASE_URL"] = base_url
|
| 53 |
self.client = OpenAI(timeout=120.0)
|
| 54 |
self.model = model
|
|
|
|
| 55 |
self._last_usage: dict[str, Any] = {}
|
| 56 |
|
| 57 |
+
def get_last_usage(self) -> dict[str, Any]:
|
| 58 |
+
"""Return metadata of the last LLM call (tokens, cost, sql_length, kind)."""
|
| 59 |
+
return dict(self._last_usage)
|
| 60 |
+
|
| 61 |
+
def _create_chat_completion(self, **kwargs):
|
| 62 |
+
"""OpenAI SDK seam for stable unit testing."""
|
| 63 |
+
return self.client.chat.completions.create(**kwargs)
|
| 64 |
+
|
| 65 |
+
# ---------------------------------------------------------------------
|
| 66 |
+
# Table extraction helpers (best-effort; no heavy parsing).
|
| 67 |
+
# ---------------------------------------------------------------------
|
| 68 |
+
def _extract_schema_tables(self, schema_preview: str) -> List[str]:
|
| 69 |
+
"""Extract likely table names from the schema preview string."""
|
| 70 |
+
if not schema_preview:
|
| 71 |
+
return []
|
| 72 |
+
|
| 73 |
+
tables: List[str] = []
|
| 74 |
+
|
| 75 |
+
for m in re.finditer(
|
| 76 |
+
r"(?im)^\s*(?:-\s*)?table\s*[: ]\s*([A-Za-z_][A-Za-z0-9_]*)\b",
|
| 77 |
+
schema_preview,
|
| 78 |
+
):
|
| 79 |
+
tables.append(m.group(1))
|
| 80 |
+
|
| 81 |
+
for m in re.finditer(
|
| 82 |
+
r"(?im)^\s*create\s+table\s+`?([A-Za-z_][A-Za-z0-9_]*)`?\b", schema_preview
|
| 83 |
+
):
|
| 84 |
+
tables.append(m.group(1))
|
| 85 |
+
|
| 86 |
+
seen = set()
|
| 87 |
+
uniq: List[str] = []
|
| 88 |
+
for t in tables:
|
| 89 |
+
if t not in seen:
|
| 90 |
+
uniq.append(t)
|
| 91 |
+
seen.add(t)
|
| 92 |
+
return uniq
|
| 93 |
+
|
| 94 |
+
def _extract_tables_from_sql(self, sql: str) -> List[str]:
|
| 95 |
+
"""Very lightweight table extraction from FROM/JOIN clauses."""
|
| 96 |
+
if not sql:
|
| 97 |
+
return []
|
| 98 |
+
pairs = re.findall(
|
| 99 |
+
r"\bfrom\s+([A-Za-z_][A-Za-z0-9_]*)|\bjoin\s+([A-Za-z_][A-Za-z0-9_]*)",
|
| 100 |
+
sql,
|
| 101 |
+
flags=re.IGNORECASE,
|
| 102 |
+
)
|
| 103 |
+
out: List[str] = []
|
| 104 |
+
for t1, t2 in pairs:
|
| 105 |
+
if t1:
|
| 106 |
+
out.append(t1)
|
| 107 |
+
if t2:
|
| 108 |
+
out.append(t2)
|
| 109 |
+
|
| 110 |
+
seen = set()
|
| 111 |
+
uniq: List[str] = []
|
| 112 |
+
for t in out:
|
| 113 |
+
if t not in seen:
|
| 114 |
+
uniq.append(t)
|
| 115 |
+
seen.add(t)
|
| 116 |
+
return uniq
|
| 117 |
+
|
| 118 |
+
def _extract_used_tables_from_plan(
|
| 119 |
+
self, plan_text: str, schema_preview: str
|
| 120 |
+
) -> List[str]:
|
| 121 |
+
"""Best-effort used table list from plan text by intersecting with schema table names."""
|
| 122 |
+
candidates = self._extract_schema_tables(schema_preview)
|
| 123 |
+
if not candidates or not plan_text:
|
| 124 |
+
return []
|
| 125 |
+
used: List[str] = []
|
| 126 |
+
for t in candidates:
|
| 127 |
+
if re.search(rf"\b{re.escape(t)}\b", plan_text, flags=re.IGNORECASE):
|
| 128 |
+
used.append(t)
|
| 129 |
+
return used
|
| 130 |
+
|
| 131 |
+
# ---------------------------------------------------------------------
|
| 132 |
+
# Cost estimation
|
| 133 |
+
# ---------------------------------------------------------------------
|
| 134 |
+
def _estimate_cost(self, usage: Any) -> float:
|
| 135 |
+
"""Estimate cost based on token usage."""
|
| 136 |
+
if not usage:
|
| 137 |
+
return 0.0
|
| 138 |
+
|
| 139 |
+
pricing = {
|
| 140 |
+
"gpt-4": {"input": 0.03, "output": 0.06},
|
| 141 |
+
"gpt-4-turbo": {"input": 0.01, "output": 0.03},
|
| 142 |
+
"gpt-4o": {"input": 0.005, "output": 0.015},
|
| 143 |
+
"gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
|
| 144 |
+
"gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
|
| 145 |
+
}
|
| 146 |
|
| 147 |
+
model_pricing = pricing.get(self.model, pricing["gpt-4o-mini"])
|
| 148 |
+
input_cost = (usage.prompt_tokens / 1000) * model_pricing["input"]
|
| 149 |
+
output_cost = (usage.completion_tokens / 1000) * model_pricing["output"]
|
| 150 |
+
return input_cost + output_cost
|
| 151 |
|
| 152 |
+
# ---------------------------------------------------------------------
|
| 153 |
+
# LLMProvider API
|
| 154 |
+
# ---------------------------------------------------------------------
|
| 155 |
+
def plan(
|
| 156 |
+
self,
|
| 157 |
+
*,
|
| 158 |
+
user_query: str,
|
| 159 |
+
schema_preview: str,
|
| 160 |
+
constraints: List[str] | None = None,
|
| 161 |
+
) -> Tuple[str, List[str], int, int, float]:
|
| 162 |
+
"""Return (plan_text, used_tables, token_in, token_out, cost_usd)."""
|
| 163 |
system_prompt = """You are a SQL query planning expert. Analyze the user's question and database schema to create a clear execution plan.
|
| 164 |
|
| 165 |
Your plan should:
|
|
|
|
| 177 |
Database Schema:
|
| 178 |
{schema_preview}
|
| 179 |
|
| 180 |
+
Constraints:
|
| 181 |
+
{constraints or []}
|
| 182 |
+
|
| 183 |
Create a step-by-step plan to answer this question with SQL."""
|
| 184 |
|
| 185 |
completion = self._create_chat_completion(
|
|
|
|
| 194 |
msg = completion.choices[0].message.content or ""
|
| 195 |
usage = completion.usage
|
| 196 |
|
| 197 |
+
plan_text = msg.strip()
|
| 198 |
+
used_tables = self._extract_used_tables_from_plan(plan_text, schema_preview)
|
| 199 |
+
|
| 200 |
if usage:
|
| 201 |
prompt_tokens = usage.prompt_tokens
|
| 202 |
completion_tokens = usage.completion_tokens
|
|
|
|
| 207 |
"completion_tokens": completion_tokens,
|
| 208 |
"cost_usd": cost,
|
| 209 |
}
|
| 210 |
+
return (plan_text, used_tables, prompt_tokens, completion_tokens, cost)
|
| 211 |
+
|
| 212 |
+
self._last_usage = {
|
| 213 |
+
"kind": "plan",
|
| 214 |
+
"prompt_tokens": 0,
|
| 215 |
+
"completion_tokens": 0,
|
| 216 |
+
"cost_usd": 0.0,
|
| 217 |
+
}
|
| 218 |
+
return (plan_text, used_tables, 0, 0, 0.0)
|
| 219 |
|
| 220 |
def generate_sql(
|
| 221 |
self,
|
|
|
|
| 223 |
user_query: str,
|
| 224 |
schema_preview: str,
|
| 225 |
plan_text: str,
|
| 226 |
+
constraints: List[str] | None = None,
|
| 227 |
+
clarify_answers: Dict[str, Any] | None = None,
|
| 228 |
) -> Tuple[str, str, int, int, float]:
|
| 229 |
+
"""Return (sql, rationale, token_in, token_out, cost_usd)."""
|
| 230 |
+
system_prompt = """You are an expert SQL generator.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
CRITICAL RULES:
|
| 233 |
1. Write the SIMPLEST possible SQL that answers the question
|
|
|
|
| 260 |
Query Plan:
|
| 261 |
{plan_text}
|
| 262 |
|
| 263 |
+
Constraints:
|
| 264 |
+
{constraints or []}
|
| 265 |
+
|
| 266 |
Remember: Generate the SIMPLEST possible SQL. Avoid table prefixes, aliases, and unnecessary clauses.
|
| 267 |
|
| 268 |
Example of what we want:
|
|
|
|
| 289 |
content = text.strip() if text else ""
|
| 290 |
usage = completion.usage
|
| 291 |
|
|
|
|
| 292 |
try:
|
| 293 |
parsed = json.loads(content)
|
| 294 |
except json.JSONDecodeError:
|
|
|
|
| 297 |
if start != -1 and end != -1:
|
| 298 |
try:
|
| 299 |
parsed = json.loads(content[start : end + 1])
|
| 300 |
+
except Exception as e:
|
| 301 |
+
raise ValueError(f"Invalid LLM JSON output: {content[:200]}") from e
|
| 302 |
else:
|
| 303 |
raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
|
| 304 |
|
| 305 |
+
sql = str(parsed.get("sql") or "").strip()
|
| 306 |
+
rationale = str(parsed.get("rationale") or "")
|
| 307 |
|
|
|
|
| 308 |
sql = self._simplify_sql(sql)
|
|
|
|
| 309 |
if not sql:
|
| 310 |
raise ValueError("LLM returned empty 'sql'")
|
| 311 |
|
| 312 |
+
used_tables = self._extract_tables_from_sql(sql)
|
| 313 |
sql_length = len(sql)
|
| 314 |
+
|
| 315 |
if usage:
|
| 316 |
prompt_tokens = usage.prompt_tokens
|
| 317 |
completion_tokens = usage.completion_tokens
|
|
|
|
| 322 |
"completion_tokens": completion_tokens,
|
| 323 |
"cost_usd": cost,
|
| 324 |
"sql_length": sql_length,
|
| 325 |
+
"used_tables": used_tables,
|
| 326 |
}
|
| 327 |
return (sql, rationale, prompt_tokens, completion_tokens, cost)
|
| 328 |
+
|
| 329 |
+
self._last_usage = {
|
| 330 |
+
"kind": "generate",
|
| 331 |
+
"prompt_tokens": 0,
|
| 332 |
+
"completion_tokens": 0,
|
| 333 |
+
"cost_usd": 0.0,
|
| 334 |
+
"sql_length": sql_length,
|
| 335 |
+
"used_tables": used_tables,
|
| 336 |
+
}
|
| 337 |
+
return (sql, rationale, 0, 0, 0.0)
|
| 338 |
|
| 339 |
def _simplify_sql(self, sql: str) -> str:
|
| 340 |
"""Post-process SQL to remove common unnecessary additions."""
|
| 341 |
if not sql:
|
| 342 |
return sql
|
| 343 |
|
|
|
|
| 344 |
sql = sql.rstrip(";")
|
| 345 |
|
|
|
|
|
|
|
| 346 |
if sql.lower().count(" from ") == 1 and " join " not in sql.lower():
|
| 347 |
match = re.search(r"\bfrom\s+(\w+)", sql, re.IGNORECASE)
|
| 348 |
if match:
|
| 349 |
table = match.group(1)
|
| 350 |
sql = re.sub(rf"\b{table}\.(\w+)\b", r"\1", sql)
|
| 351 |
|
|
|
|
| 352 |
sql = re.sub(
|
| 353 |
r"count\s*\(\s*distinct\s+\*\s*\)",
|
| 354 |
"count(*)",
|
|
|
|
| 356 |
flags=re.IGNORECASE,
|
| 357 |
)
|
| 358 |
|
|
|
|
| 359 |
sql = re.sub(
|
| 360 |
r"\s+limit\s+(100|1000|10000)\b",
|
| 361 |
"",
|
|
|
|
| 372 |
error_msg: str,
|
| 373 |
schema_preview: str,
|
| 374 |
) -> Tuple[str, int, int, float]:
|
| 375 |
+
"""Return (patched_sql, token_in, token_out, cost_usd)."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
system_prompt = """You are a SQL repair expert. Fix the given SQL query to resolve the error.
|
| 377 |
|
| 378 |
IMPORTANT RULES:
|
|
|
|
| 409 |
text = completion.choices[0].message.content
|
| 410 |
fixed_sql = text.strip() if text else ""
|
| 411 |
|
|
|
|
| 412 |
if fixed_sql.startswith("```sql"):
|
| 413 |
fixed_sql = fixed_sql[6:]
|
| 414 |
if fixed_sql.startswith("```"):
|
|
|
|
| 420 |
fixed_sql = self._simplify_sql(fixed_sql)
|
| 421 |
|
| 422 |
usage = completion.usage
|
|
|
|
| 423 |
if usage:
|
| 424 |
prompt_tokens = usage.prompt_tokens
|
| 425 |
completion_tokens = usage.completion_tokens
|
|
|
|
| 432 |
"sql_length": len(fixed_sql),
|
| 433 |
}
|
| 434 |
return (fixed_sql, prompt_tokens, completion_tokens, cost)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
|
| 436 |
+
self._last_usage = {
|
| 437 |
+
"kind": "repair",
|
| 438 |
+
"prompt_tokens": 0,
|
| 439 |
+
"completion_tokens": 0,
|
| 440 |
+
"cost_usd": 0.0,
|
| 441 |
+
"sql_length": len(fixed_sql),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
}
|
| 443 |
+
return (fixed_sql, 0, 0, 0.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nl2sql/errors/codes.py
CHANGED
|
@@ -14,6 +14,7 @@ class ErrorCode(str, Enum):
|
|
| 14 |
# --- Executor / DB ---
|
| 15 |
DB_LOCKED = "DB_LOCKED"
|
| 16 |
DB_TIMEOUT = "DB_TIMEOUT"
|
|
|
|
| 17 |
|
| 18 |
# --- LLM ---
|
| 19 |
LLM_TIMEOUT = "LLM_TIMEOUT"
|
|
|
|
| 14 |
# --- Executor / DB ---
|
| 15 |
DB_LOCKED = "DB_LOCKED"
|
| 16 |
DB_TIMEOUT = "DB_TIMEOUT"
|
| 17 |
+
LLM_FAILURE = "LLM_FAILURE"
|
| 18 |
|
| 19 |
# --- LLM ---
|
| 20 |
LLM_TIMEOUT = "LLM_TIMEOUT"
|
nl2sql/generator.py
CHANGED
|
@@ -20,7 +20,9 @@ class Generator:
|
|
| 20 |
user_query: str,
|
| 21 |
schema_preview: str,
|
| 22 |
plan_text: str,
|
|
|
|
| 23 |
clarify_answers: Optional[Dict[str, Any]] = None,
|
|
|
|
| 24 |
) -> StageResult:
|
| 25 |
t0 = time.perf_counter()
|
| 26 |
|
|
@@ -29,10 +31,11 @@ class Generator:
|
|
| 29 |
user_query=user_query,
|
| 30 |
schema_preview=schema_preview,
|
| 31 |
plan_text=plan_text,
|
|
|
|
| 32 |
clarify_answers=clarify_answers or {},
|
| 33 |
)
|
| 34 |
except Exception as e:
|
| 35 |
-
# Provider/transport errors or unexpected runtime
|
| 36 |
return StageResult(
|
| 37 |
ok=False,
|
| 38 |
error=[f"Generator failed: {e}"],
|
|
@@ -40,18 +43,22 @@ class Generator:
|
|
| 40 |
trace=None,
|
| 41 |
)
|
| 42 |
|
| 43 |
-
|
| 44 |
-
if not isinstance(res, tuple) or len(res) != 5:
|
| 45 |
return StageResult(
|
| 46 |
ok=False,
|
| 47 |
error=[
|
| 48 |
-
"Generator contract violation: expected 5-tuple (sql, rationale, t_in, t_out, cost)"
|
| 49 |
],
|
| 50 |
error_code=ErrorCode.LLM_BAD_OUTPUT,
|
| 51 |
trace=None,
|
| 52 |
)
|
| 53 |
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
# Type/shape checks
|
| 57 |
if not isinstance(sql, str) or not sql.strip():
|
|
@@ -73,18 +80,20 @@ class Generator:
|
|
| 73 |
|
| 74 |
# Normalize rationale to a string
|
| 75 |
rationale = rationale or ""
|
|
|
|
| 76 |
trace = StageTrace(
|
| 77 |
stage=self.name,
|
|
|
|
| 78 |
duration_ms=(time.perf_counter() - t0) * 1000.0,
|
| 79 |
token_in=t_in,
|
| 80 |
token_out=t_out,
|
| 81 |
cost_usd=cost,
|
| 82 |
-
notes={"rationale_len": len(rationale)},
|
| 83 |
)
|
| 84 |
|
| 85 |
return StageResult(
|
| 86 |
ok=True,
|
| 87 |
-
data={"sql": sql, "rationale": rationale},
|
| 88 |
trace=trace,
|
| 89 |
error_code=None,
|
| 90 |
retryable=None,
|
|
|
|
| 20 |
user_query: str,
|
| 21 |
schema_preview: str,
|
| 22 |
plan_text: str,
|
| 23 |
+
constraints: Optional[list[str]] = None,
|
| 24 |
clarify_answers: Optional[Dict[str, Any]] = None,
|
| 25 |
+
traces: Optional[list[dict]] = None,
|
| 26 |
) -> StageResult:
|
| 27 |
t0 = time.perf_counter()
|
| 28 |
|
|
|
|
| 31 |
user_query=user_query,
|
| 32 |
schema_preview=schema_preview,
|
| 33 |
plan_text=plan_text,
|
| 34 |
+
constraints=constraints or [],
|
| 35 |
clarify_answers=clarify_answers or {},
|
| 36 |
)
|
| 37 |
except Exception as e:
|
| 38 |
+
# Provider/transport errors or unexpected runtime exceptions.
|
| 39 |
return StageResult(
|
| 40 |
ok=False,
|
| 41 |
error=[f"Generator failed: {e}"],
|
|
|
|
| 43 |
trace=None,
|
| 44 |
)
|
| 45 |
|
| 46 |
+
if not isinstance(res, tuple) or len(res) not in (5, 6):
|
|
|
|
| 47 |
return StageResult(
|
| 48 |
ok=False,
|
| 49 |
error=[
|
| 50 |
+
"Generator contract violation: expected 5/6-tuple (sql, rationale, [used_tables], t_in, t_out, cost)"
|
| 51 |
],
|
| 52 |
error_code=ErrorCode.LLM_BAD_OUTPUT,
|
| 53 |
trace=None,
|
| 54 |
)
|
| 55 |
|
| 56 |
+
used_tables: list[str] = []
|
| 57 |
+
|
| 58 |
+
if len(res) == 6:
|
| 59 |
+
sql, rationale, used_tables, t_in, t_out, cost = res
|
| 60 |
+
else:
|
| 61 |
+
sql, rationale, t_in, t_out, cost = res
|
| 62 |
|
| 63 |
# Type/shape checks
|
| 64 |
if not isinstance(sql, str) or not sql.strip():
|
|
|
|
| 80 |
|
| 81 |
# Normalize rationale to a string
|
| 82 |
rationale = rationale or ""
|
| 83 |
+
|
| 84 |
trace = StageTrace(
|
| 85 |
stage=self.name,
|
| 86 |
+
summary="Generated SQL",
|
| 87 |
duration_ms=(time.perf_counter() - t0) * 1000.0,
|
| 88 |
token_in=t_in,
|
| 89 |
token_out=t_out,
|
| 90 |
cost_usd=cost,
|
| 91 |
+
notes={"rationale_len": len(rationale), "used_tables": used_tables},
|
| 92 |
)
|
| 93 |
|
| 94 |
return StageResult(
|
| 95 |
ok=True,
|
| 96 |
+
data={"sql": sql, "rationale": rationale, "used_tables": used_tables},
|
| 97 |
trace=trace,
|
| 98 |
error_code=None,
|
| 99 |
retryable=None,
|
nl2sql/pipeline.py
CHANGED
|
@@ -276,6 +276,17 @@ class Pipeline:
|
|
| 276 |
details: List[str] = []
|
| 277 |
exec_result: Dict[str, Any] = {}
|
| 278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
|
| 280 |
traces.append(
|
| 281 |
self._mk_trace(
|
|
@@ -411,6 +422,33 @@ class Pipeline:
|
|
| 411 |
sql = (r_gen.data or {}).get("sql")
|
| 412 |
rationale = (r_gen.data or {}).get("rationale")
|
| 413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
# Guard: empty SQL
|
| 415 |
if not sql or not str(sql).strip():
|
| 416 |
pipeline_runs_total.labels(status="error").inc()
|
|
@@ -485,44 +523,39 @@ class Pipeline:
|
|
| 485 |
if r_exec.ok and isinstance(r_exec.data, dict):
|
| 486 |
exec_result = dict(r_exec.data)
|
| 487 |
|
| 488 |
-
# --- 6) verifier (
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
self.
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
|
|
|
| 506 |
|
| 507 |
-
|
| 508 |
-
if r_ver.data and isinstance(r_ver.data, dict):
|
| 509 |
-
repaired_sql = r_ver.data.get("sql")
|
| 510 |
-
if repaired_sql:
|
| 511 |
-
sql = repaired_sql
|
| 512 |
|
| 513 |
# Verified flag
|
| 514 |
-
verified = (
|
| 515 |
-
bool(
|
| 516 |
-
r_ver.data
|
| 517 |
-
and isinstance(r_ver.data, dict)
|
| 518 |
-
and r_ver.data.get("verified")
|
| 519 |
-
)
|
| 520 |
-
or r_ver.ok
|
| 521 |
-
)
|
| 522 |
|
| 523 |
# consume repaired SQL from verifier if any
|
| 524 |
-
|
| 525 |
-
|
|
|
|
| 526 |
|
| 527 |
# --- 7) repair loop (if not verified) ---
|
| 528 |
if not verified:
|
|
@@ -534,11 +567,12 @@ class Pipeline:
|
|
| 534 |
self.repair.run,
|
| 535 |
sql=sql,
|
| 536 |
error_msg="; ".join(details or ["unknown"]),
|
| 537 |
-
schema_preview=
|
| 538 |
)
|
| 539 |
dt = (time.perf_counter() - t0) * 1000.0
|
| 540 |
stage_duration_ms.labels("repair").observe(dt)
|
| 541 |
traces.extend(self._trace_list(r_fix))
|
|
|
|
| 542 |
if not getattr(r_fix, "trace", None):
|
| 543 |
_fallback_trace("repair", dt, r_fix.ok)
|
| 544 |
if not r_fix.ok:
|
|
@@ -553,6 +587,7 @@ class Pipeline:
|
|
| 553 |
dt2 = (time.perf_counter() - t0) * 1000.0
|
| 554 |
stage_duration_ms.labels("safety").observe(dt2)
|
| 555 |
traces.extend(self._trace_list(r_safe2))
|
|
|
|
| 556 |
if not getattr(r_safe2, "trace", None):
|
| 557 |
_fallback_trace("safety", dt2, r_safe2.ok)
|
| 558 |
if not r_safe2.ok:
|
|
@@ -567,6 +602,7 @@ class Pipeline:
|
|
| 567 |
dt2 = (time.perf_counter() - t0) * 1000.0
|
| 568 |
stage_duration_ms.labels("executor").observe(dt2)
|
| 569 |
traces.extend(self._trace_list(r_exec2))
|
|
|
|
| 570 |
if not getattr(r_exec2, "trace", None):
|
| 571 |
_fallback_trace("executor", dt2, r_exec2.ok)
|
| 572 |
if not r_exec2.ok:
|
|
@@ -586,11 +622,10 @@ class Pipeline:
|
|
| 586 |
dt2 = (time.perf_counter() - t0) * 1000.0
|
| 587 |
stage_duration_ms.labels("verifier").observe(dt2)
|
| 588 |
traces.extend(self._trace_list(r_ver2))
|
|
|
|
| 589 |
if not getattr(r_ver2, "trace", None):
|
| 590 |
_fallback_trace("verifier", dt2, r_ver2.ok)
|
| 591 |
-
verified = (
|
| 592 |
-
bool(r_ver2.data and r_ver2.data.get("verified")) or r_ver2.ok
|
| 593 |
-
)
|
| 594 |
if r_ver2.data and "sql" in r_ver2.data and r_ver2.data["sql"]:
|
| 595 |
sql = r_ver2.data["sql"]
|
| 596 |
if verified:
|
|
|
|
| 276 |
details: List[str] = []
|
| 277 |
exec_result: Dict[str, Any] = {}
|
| 278 |
|
| 279 |
+
def _tag_last_trace_attempt(stage_name: str, attempt: int) -> None:
|
| 280 |
+
# Attach attempt metadata to the most recent trace entry for this stage.
|
| 281 |
+
for t in reversed(traces):
|
| 282 |
+
if t.get("stage") == stage_name:
|
| 283 |
+
notes = t.get("notes") or {}
|
| 284 |
+
if not isinstance(notes, dict):
|
| 285 |
+
notes = {}
|
| 286 |
+
notes["attempt"] = attempt
|
| 287 |
+
t["notes"] = notes
|
| 288 |
+
return
|
| 289 |
+
|
| 290 |
def _fallback_trace(stage_name: str, dt_ms: float, ok: bool) -> None:
|
| 291 |
traces.append(
|
| 292 |
self._mk_trace(
|
|
|
|
| 422 |
sql = (r_gen.data or {}).get("sql")
|
| 423 |
rationale = (r_gen.data or {}).get("rationale")
|
| 424 |
|
| 425 |
+
# --- schema drift signal (planner vs generator table usage)
|
| 426 |
+
planner_used_tables = (
|
| 427 |
+
(r_plan.data or {}).get("used_tables")
|
| 428 |
+
or (r_plan.data or {}).get("tables")
|
| 429 |
+
or []
|
| 430 |
+
)
|
| 431 |
+
generator_used_tables = (
|
| 432 |
+
(r_gen.data or {}).get("used_tables")
|
| 433 |
+
or (r_gen.data or {}).get("tables")
|
| 434 |
+
or []
|
| 435 |
+
)
|
| 436 |
+
planner_set = set(planner_used_tables)
|
| 437 |
+
generator_set = set(generator_used_tables)
|
| 438 |
+
schema_drift = bool(generator_set - planner_set)
|
| 439 |
+
traces.append(
|
| 440 |
+
self._mk_trace(
|
| 441 |
+
stage="schema_drift_check",
|
| 442 |
+
duration_ms=0.0,
|
| 443 |
+
summary="compare planner vs generator table usage",
|
| 444 |
+
notes={
|
| 445 |
+
"planner_used_tables": sorted(planner_set),
|
| 446 |
+
"generator_used_tables": sorted(generator_set),
|
| 447 |
+
"schema_drift": schema_drift,
|
| 448 |
+
},
|
| 449 |
+
)
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
# Guard: empty SQL
|
| 453 |
if not sql or not str(sql).strip():
|
| 454 |
pipeline_runs_total.labels(status="error").inc()
|
|
|
|
| 523 |
if r_exec.ok and isinstance(r_exec.data, dict):
|
| 524 |
exec_result = dict(r_exec.data)
|
| 525 |
|
| 526 |
+
# --- 6) verifier (only if execution succeeded) ---
|
| 527 |
+
r_ver = None
|
| 528 |
+
if r_exec.ok:
|
| 529 |
+
t0 = time.perf_counter()
|
| 530 |
+
r_ver = self._run_with_repair(
|
| 531 |
+
"verifier",
|
| 532 |
+
self._call_verifier,
|
| 533 |
+
repair_input_builder=self._sql_repair_input_builder,
|
| 534 |
+
max_attempts=1,
|
| 535 |
+
sql=sql,
|
| 536 |
+
exec_result=(r_exec.data or {}),
|
| 537 |
+
traces=traces,
|
| 538 |
+
)
|
| 539 |
+
dt = (time.perf_counter() - t0) * 1000.0
|
| 540 |
+
stage_duration_ms.labels("verifier").observe(dt)
|
| 541 |
+
|
| 542 |
+
# Traces
|
| 543 |
|
| 544 |
+
# If verifier (or its repair) produced a new SQL, consume it
|
| 545 |
+
if r_ver.data and isinstance(r_ver.data, dict):
|
| 546 |
+
repaired_sql = r_ver.data.get("sql")
|
| 547 |
+
if repaired_sql:
|
| 548 |
+
sql = repaired_sql
|
| 549 |
|
| 550 |
+
data = r_ver.data if (r_ver and isinstance(r_ver.data, dict)) else {}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
|
| 552 |
# Verified flag
|
| 553 |
+
verified = bool(data.get("verified") is True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
|
| 555 |
# consume repaired SQL from verifier if any
|
| 556 |
+
repaired_sql = data.get("sql")
|
| 557 |
+
if repaired_sql:
|
| 558 |
+
sql = repaired_sql
|
| 559 |
|
| 560 |
# --- 7) repair loop (if not verified) ---
|
| 561 |
if not verified:
|
|
|
|
| 567 |
self.repair.run,
|
| 568 |
sql=sql,
|
| 569 |
error_msg="; ".join(details or ["unknown"]),
|
| 570 |
+
schema_preview=schema_for_llm,
|
| 571 |
)
|
| 572 |
dt = (time.perf_counter() - t0) * 1000.0
|
| 573 |
stage_duration_ms.labels("repair").observe(dt)
|
| 574 |
traces.extend(self._trace_list(r_fix))
|
| 575 |
+
_tag_last_trace_attempt("repair", _attempt)
|
| 576 |
if not getattr(r_fix, "trace", None):
|
| 577 |
_fallback_trace("repair", dt, r_fix.ok)
|
| 578 |
if not r_fix.ok:
|
|
|
|
| 587 |
dt2 = (time.perf_counter() - t0) * 1000.0
|
| 588 |
stage_duration_ms.labels("safety").observe(dt2)
|
| 589 |
traces.extend(self._trace_list(r_safe2))
|
| 590 |
+
_tag_last_trace_attempt("safety", _attempt)
|
| 591 |
if not getattr(r_safe2, "trace", None):
|
| 592 |
_fallback_trace("safety", dt2, r_safe2.ok)
|
| 593 |
if not r_safe2.ok:
|
|
|
|
| 602 |
dt2 = (time.perf_counter() - t0) * 1000.0
|
| 603 |
stage_duration_ms.labels("executor").observe(dt2)
|
| 604 |
traces.extend(self._trace_list(r_exec2))
|
| 605 |
+
_tag_last_trace_attempt("executor", _attempt)
|
| 606 |
if not getattr(r_exec2, "trace", None):
|
| 607 |
_fallback_trace("executor", dt2, r_exec2.ok)
|
| 608 |
if not r_exec2.ok:
|
|
|
|
| 622 |
dt2 = (time.perf_counter() - t0) * 1000.0
|
| 623 |
stage_duration_ms.labels("verifier").observe(dt2)
|
| 624 |
traces.extend(self._trace_list(r_ver2))
|
| 625 |
+
_tag_last_trace_attempt("verifier", _attempt)
|
| 626 |
if not getattr(r_ver2, "trace", None):
|
| 627 |
_fallback_trace("verifier", dt2, r_ver2.ok)
|
| 628 |
+
verified = bool(r_ver2.data and r_ver2.data.get("verified") is True)
|
|
|
|
|
|
|
| 629 |
if r_ver2.data and "sql" in r_ver2.data and r_ver2.data["sql"]:
|
| 630 |
sql = r_ver2.data["sql"]
|
| 631 |
if verified:
|
nl2sql/planner.py
CHANGED
|
@@ -6,6 +6,23 @@ from typing import Any, Dict, List, Tuple, Optional
|
|
| 6 |
__all__ = ["Planner"]
|
| 7 |
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
# --------- Heuristic schema trimming (safe, mypy-clean) ---------
|
| 10 |
def _tokenize_lower(s: str) -> List[str]:
|
| 11 |
return re.findall(r"[a-z_]+", (s or "").lower())
|
|
@@ -14,41 +31,33 @@ def _tokenize_lower(s: str) -> List[str]:
|
|
| 14 |
def _table_blocks(schema_text: str) -> List[Tuple[str, List[str]]]:
|
| 15 |
"""
|
| 16 |
Parse plain-text schema into [(table_name, lines)] blocks,
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
| 18 |
"""
|
| 19 |
blocks: List[Tuple[str, List[str]]] = []
|
| 20 |
cur_name: Optional[str] = None
|
| 21 |
cur_lines: List[str] = []
|
| 22 |
|
| 23 |
-
def _flush()
|
| 24 |
nonlocal cur_name, cur_lines
|
| 25 |
-
if cur_name is not None
|
| 26 |
-
blocks.append((cur_name, cur_lines
|
| 27 |
cur_name, cur_lines = None, []
|
| 28 |
|
| 29 |
-
for
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
if m is not None:
|
| 36 |
-
name = m.group(1)
|
| 37 |
-
started = True
|
| 38 |
-
elif m2 is not None:
|
| 39 |
-
name = m2.group(1)
|
| 40 |
-
started = True
|
| 41 |
-
|
| 42 |
-
if started and name:
|
| 43 |
_flush()
|
| 44 |
-
cur_name =
|
| 45 |
-
cur_lines
|
| 46 |
else:
|
| 47 |
if cur_name is not None:
|
| 48 |
-
cur_lines.append(
|
| 49 |
-
|
| 50 |
-
if cur_name is not None and line.strip().endswith(");"):
|
| 51 |
-
_flush()
|
| 52 |
|
| 53 |
_flush()
|
| 54 |
return blocks
|
|
@@ -64,29 +73,22 @@ def _pick_relevant_tables(schema_text: str, question: str, k: int = 3) -> str:
|
|
| 64 |
q_toks = set(_tokenize_lower(question))
|
| 65 |
scored: List[Tuple[int, str, List[str]]] = []
|
| 66 |
for name, lines in blocks:
|
| 67 |
-
score = sum(1 for
|
| 68 |
-
cols_line = " ".join(lines)
|
| 69 |
-
cols = re.findall(r"\b([A-Za-z_]\w*)\b", cols_line)
|
| 70 |
-
score += min(2, sum(1 for c in cols if c.lower() in q_toks))
|
| 71 |
scored.append((score, name, lines))
|
| 72 |
|
| 73 |
-
scored.sort(key=lambda
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
keep = scored[: max(1, k)]
|
| 77 |
-
|
| 78 |
out_lines: List[str] = []
|
| 79 |
-
for _, _, lines in
|
| 80 |
out_lines.extend(lines)
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
return trimmed if trimmed else schema_text
|
| 85 |
except Exception:
|
| 86 |
return schema_text
|
| 87 |
|
| 88 |
|
| 89 |
-
# ------------------------------ Planner ------------------------------
|
| 90 |
class Planner:
|
| 91 |
"""Planner wrapper around the LLM provider."""
|
| 92 |
|
|
@@ -95,26 +97,65 @@ class Planner:
|
|
| 95 |
# ensure model_id is always a str (for mypy)
|
| 96 |
self.model_id: str = str(model_id or getattr(llm, "model", "unknown"))
|
| 97 |
# in-memory cache: (model, hash(q), hash(trimmed)) → (plan, pin, pout, cost)
|
| 98 |
-
self._plan_cache: dict[
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
key: tuple[str, int, int] = (
|
| 104 |
self.model_id,
|
| 105 |
hash(user_query or ""),
|
| 106 |
-
hash(
|
| 107 |
)
|
|
|
|
| 108 |
if key in self._plan_cache:
|
| 109 |
-
plan_text, pin, pout, cost = self._plan_cache[key]
|
| 110 |
else:
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
return {
|
| 117 |
"plan": plan_text,
|
|
|
|
| 118 |
"usage": {
|
| 119 |
"prompt_tokens": pin,
|
| 120 |
"completion_tokens": pout,
|
|
|
|
| 6 |
__all__ = ["Planner"]
|
| 7 |
|
| 8 |
|
| 9 |
+
def _extract_table_names_from_schema(schema_text: str) -> List[str]:
|
| 10 |
+
"""Best-effort table name extraction from schema preview."""
|
| 11 |
+
if not schema_text:
|
| 12 |
+
return []
|
| 13 |
+
names = re.findall(
|
| 14 |
+
r"(?im)^\s*create\s+table\s+`?([A-Za-z_][A-Za-z0-9_]*)`?\b", schema_text
|
| 15 |
+
)
|
| 16 |
+
# de-dup preserving order
|
| 17 |
+
seen: set[str] = set()
|
| 18 |
+
out: List[str] = []
|
| 19 |
+
for n in names:
|
| 20 |
+
if n not in seen:
|
| 21 |
+
out.append(n)
|
| 22 |
+
seen.add(n)
|
| 23 |
+
return out
|
| 24 |
+
|
| 25 |
+
|
| 26 |
# --------- Heuristic schema trimming (safe, mypy-clean) ---------
|
| 27 |
def _tokenize_lower(s: str) -> List[str]:
|
| 28 |
return re.findall(r"[a-z_]+", (s or "").lower())
|
|
|
|
| 31 |
def _table_blocks(schema_text: str) -> List[Tuple[str, List[str]]]:
|
| 32 |
"""
|
| 33 |
Parse plain-text schema into [(table_name, lines)] blocks,
|
| 34 |
+
assuming SQLite preview format like:
|
| 35 |
+
Table: users
|
| 36 |
+
- id
|
| 37 |
+
- name
|
| 38 |
"""
|
| 39 |
blocks: List[Tuple[str, List[str]]] = []
|
| 40 |
cur_name: Optional[str] = None
|
| 41 |
cur_lines: List[str] = []
|
| 42 |
|
| 43 |
+
def _flush():
|
| 44 |
nonlocal cur_name, cur_lines
|
| 45 |
+
if cur_name is not None:
|
| 46 |
+
blocks.append((cur_name, cur_lines))
|
| 47 |
cur_name, cur_lines = None, []
|
| 48 |
|
| 49 |
+
for raw in (schema_text or "").splitlines():
|
| 50 |
+
line = raw.strip()
|
| 51 |
+
if not line:
|
| 52 |
+
continue
|
| 53 |
+
m = re.match(r"^table:\s*([a-zA-Z0-9_]+)\s*$", line, re.IGNORECASE)
|
| 54 |
+
if m:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
_flush()
|
| 56 |
+
cur_name = m.group(1)
|
| 57 |
+
cur_lines = [raw]
|
| 58 |
else:
|
| 59 |
if cur_name is not None:
|
| 60 |
+
cur_lines.append(raw)
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
_flush()
|
| 63 |
return blocks
|
|
|
|
| 73 |
q_toks = set(_tokenize_lower(question))
|
| 74 |
scored: List[Tuple[int, str, List[str]]] = []
|
| 75 |
for name, lines in blocks:
|
| 76 |
+
score = sum(1 for tok in _tokenize_lower(" ".join(lines)) if tok in q_toks)
|
|
|
|
|
|
|
|
|
|
| 77 |
scored.append((score, name, lines))
|
| 78 |
|
| 79 |
+
scored.sort(key=lambda x: (-x[0], x[1]))
|
| 80 |
+
top = scored[:k]
|
| 81 |
+
# Deterministic selection: sorted by descending score, then table name as tie-breaker.
|
|
|
|
|
|
|
| 82 |
out_lines: List[str] = []
|
| 83 |
+
for _, _, lines in top:
|
| 84 |
out_lines.extend(lines)
|
| 85 |
+
out_lines.append("") # spacing
|
| 86 |
+
|
| 87 |
+
return "\n".join(out_lines).strip() if out_lines else schema_text
|
|
|
|
| 88 |
except Exception:
|
| 89 |
return schema_text
|
| 90 |
|
| 91 |
|
|
|
|
| 92 |
class Planner:
|
| 93 |
"""Planner wrapper around the LLM provider."""
|
| 94 |
|
|
|
|
| 97 |
# ensure model_id is always a str (for mypy)
|
| 98 |
self.model_id: str = str(model_id or getattr(llm, "model", "unknown"))
|
| 99 |
# in-memory cache: (model, hash(q), hash(trimmed)) → (plan, pin, pout, cost)
|
| 100 |
+
self._plan_cache: dict[
|
| 101 |
+
tuple[str, int, int], tuple[str, List[str], int, int, float]
|
| 102 |
+
] = {}
|
| 103 |
+
|
| 104 |
+
def run(
|
| 105 |
+
self,
|
| 106 |
+
*,
|
| 107 |
+
user_query: str,
|
| 108 |
+
schema_preview: str,
|
| 109 |
+
constraints: Optional[List[str]] = None,
|
| 110 |
+
traces: Optional[List[dict]] = None,
|
| 111 |
+
) -> Dict[str, Any]:
|
| 112 |
+
"""Plan the query. Assumes schema_preview is already budgeted upstream."""
|
| 113 |
+
schema_preview = schema_preview or ""
|
| 114 |
+
constraints = constraints or []
|
| 115 |
|
| 116 |
key: tuple[str, int, int] = (
|
| 117 |
self.model_id,
|
| 118 |
hash(user_query or ""),
|
| 119 |
+
hash(schema_preview),
|
| 120 |
)
|
| 121 |
+
|
| 122 |
if key in self._plan_cache:
|
| 123 |
+
plan_text, used_tables, pin, pout, cost = self._plan_cache[key]
|
| 124 |
else:
|
| 125 |
+
# Call provider with backward-compatible kwargs
|
| 126 |
+
try:
|
| 127 |
+
res = self.llm.plan(
|
| 128 |
+
user_query=user_query,
|
| 129 |
+
schema_preview=schema_preview,
|
| 130 |
+
constraints=constraints,
|
| 131 |
+
)
|
| 132 |
+
except TypeError:
|
| 133 |
+
# Older fakes/providers may not accept `constraints`
|
| 134 |
+
res = self.llm.plan(
|
| 135 |
+
user_query=user_query,
|
| 136 |
+
schema_preview=schema_preview,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
if not isinstance(res, tuple):
|
| 140 |
+
raise TypeError("LLM plan() must return a tuple")
|
| 141 |
+
|
| 142 |
+
if len(res) == 5:
|
| 143 |
+
plan_text, used_tables, pin, pout, cost = res
|
| 144 |
+
elif len(res) == 4:
|
| 145 |
+
plan_text, pin, pout, cost = res
|
| 146 |
+
used_tables = _extract_table_names_from_schema(schema_preview)
|
| 147 |
+
else:
|
| 148 |
+
raise TypeError("LLM plan() must return 4- or 5-tuple")
|
| 149 |
+
|
| 150 |
+
# Ensure used_tables is always a list[str]
|
| 151 |
+
if not isinstance(used_tables, list):
|
| 152 |
+
used_tables = _extract_table_names_from_schema(schema_preview)
|
| 153 |
+
|
| 154 |
+
self._plan_cache[key] = (plan_text, used_tables, pin, pout, cost)
|
| 155 |
|
| 156 |
return {
|
| 157 |
"plan": plan_text,
|
| 158 |
+
"used_tables": used_tables,
|
| 159 |
"usage": {
|
| 160 |
"prompt_tokens": pin,
|
| 161 |
"completion_tokens": pout,
|
nl2sql/prompts/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Prompt contracts for LLM-facing stages.

Re-exports the dataclass contracts from :mod:`.contracts` as this
package's public API.
"""

from .contracts import (
    GeneratorPromptInput,
    GeneratorPromptOutput,
    PlannerPromptInput,
    PlannerPromptOutput,
)

__all__ = [
    "PlannerPromptInput",
    "PlannerPromptOutput",
    "GeneratorPromptInput",
    "GeneratorPromptOutput",
]
|
nl2sql/prompts/contracts.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional


# NOTE:
# These are *prompt contracts* (input/output shapes) for LLM-facing stages.
# They are intentionally lightweight dataclasses — not validation models —
# to keep this layer minimal and low-risk.


@dataclass(frozen=True)
class PlannerPromptInput:
    """Inputs handed to the planner prompt."""

    user_query: str
    schema_preview: str  # already budgeted at the pipeline boundary
    constraints: List[str]


@dataclass(frozen=True)
class PlannerPromptOutput:
    """Shape of the planner's response."""

    plan: str
    used_tables: List[str]


@dataclass(frozen=True)
class GeneratorPromptInput:
    """Inputs handed to the SQL-generator prompt."""

    user_query: str
    schema_preview: str  # already budgeted at the pipeline boundary
    plan: str
    constraints: List[str]
    clarify_answers: Optional[Dict[str, Any]] = None


@dataclass(frozen=True)
class GeneratorPromptOutput:
    """Shape of the generator's response."""

    sql: str
    rationale: str
    used_tables: List[str]
|