github-actions[bot] committed on
Commit
e70c579
·
1 Parent(s): c758204

Sync from GitHub main

Browse files
adapters/llm/openai_provider.py CHANGED
@@ -43,6 +43,10 @@ class OpenAIProvider(LLMProvider):
43
  """Return metadata of the last LLM call (tokens, cost, sql_length, kind)."""
44
  return dict(self._last_usage)
45
 
 
 
 
 
46
  def __init__(self) -> None:
47
  """Initialize OpenAI client with config from environment."""
48
  api_key, base_url, model = _resolve_api_config()
@@ -84,7 +88,7 @@ Database Schema:
84
 
85
  Create a step-by-step plan to answer this question with SQL."""
86
 
87
- completion = self.client.chat.completions.create(
88
  model=self.model,
89
  messages=[
90
  {"role": "system", "content": system_prompt},
@@ -181,7 +185,7 @@ Now generate the SQL for the given question:"""
181
  if clarify_answers:
182
  user_prompt += f"\n\nAdditional context: {clarify_answers}"
183
 
184
- completion = self.client.chat.completions.create(
185
  model=self.model,
186
  messages=[
187
  {"role": "system", "content": system_prompt},
@@ -316,7 +320,7 @@ Database Schema:
316
 
317
  Return the corrected SQL (keep it simple):"""
318
 
319
- completion = self.client.chat.completions.create(
320
  model=self.model,
321
  messages=[
322
  {"role": "system", "content": system_prompt},
@@ -419,7 +423,7 @@ Database Schema:
419
  Please answer these clarification questions:
420
  {chr(10).join(f"{i + 1}. {q}" for i, q in enumerate(questions))}"""
421
 
422
- completion = self.client.chat.completions.create(
423
  model=self.model,
424
  messages=[
425
  {"role": "system", "content": system_prompt},
 
43
  """Return metadata of the last LLM call (tokens, cost, sql_length, kind)."""
44
  return dict(self._last_usage)
45
 
46
+ def _create_chat_completion(self, **kwargs):
47
+ """OpenAI SDK seam for stable unit testing."""
48
+ return self.client.chat.completions.create(**kwargs)
49
+
50
  def __init__(self) -> None:
51
  """Initialize OpenAI client with config from environment."""
52
  api_key, base_url, model = _resolve_api_config()
 
88
 
89
  Create a step-by-step plan to answer this question with SQL."""
90
 
91
+ completion = self._create_chat_completion(
92
  model=self.model,
93
  messages=[
94
  {"role": "system", "content": system_prompt},
 
185
  if clarify_answers:
186
  user_prompt += f"\n\nAdditional context: {clarify_answers}"
187
 
188
+ completion = self._create_chat_completion(
189
  model=self.model,
190
  messages=[
191
  {"role": "system", "content": system_prompt},
 
320
 
321
  Return the corrected SQL (keep it simple):"""
322
 
323
+ completion = self._create_chat_completion(
324
  model=self.model,
325
  messages=[
326
  {"role": "system", "content": system_prompt},
 
423
  Please answer these clarification questions:
424
  {chr(10).join(f"{i + 1}. {q}" for i, q in enumerate(questions))}"""
425
 
426
+ completion = self._create_chat_completion(
427
  model=self.model,
428
  messages=[
429
  {"role": "system", "content": system_prompt},
nl2sql/generator.py CHANGED
@@ -1,8 +1,11 @@
1
  from __future__ import annotations
 
2
  import time
3
  from typing import Optional, Dict, Any
4
- from nl2sql.types import StageResult, StageTrace
5
  from adapters.llm.base import LLMProvider
 
 
6
 
7
 
8
  class Generator:
@@ -20,6 +23,7 @@ class Generator:
20
  clarify_answers: Optional[Dict[str, Any]] = None,
21
  ) -> StageResult:
22
  t0 = time.perf_counter()
 
23
  try:
24
  res = self.llm.generate_sql(
25
  user_query=user_query,
@@ -28,15 +32,23 @@ class Generator:
28
  clarify_answers=clarify_answers or {},
29
  )
30
  except Exception as e:
31
- return StageResult(ok=False, error=[f"Generator failed: {e}"])
 
 
 
 
 
 
32
 
33
- # Expect a 5-tuple
34
  if not isinstance(res, tuple) or len(res) != 5:
35
  return StageResult(
36
  ok=False,
37
  error=[
38
  "Generator contract violation: expected 5-tuple (sql, rationale, t_in, t_out, cost)"
39
  ],
 
 
40
  )
41
 
42
  sql, rationale, t_in, t_out, cost = res
@@ -44,12 +56,23 @@ class Generator:
44
  # Type/shape checks
45
  if not isinstance(sql, str) or not sql.strip():
46
  return StageResult(
47
- ok=False, error=["Generator produced empty or non-string SQL"]
 
 
 
48
  )
 
 
49
  if not sql.lower().lstrip().startswith("select"):
50
- return StageResult(ok=False, error=[f"Generated non-SELECT SQL: {sql}"])
 
 
 
 
 
51
 
52
- rationale = rationale or "" # safe length
 
53
  trace = StageTrace(
54
  stage=self.name,
55
  duration_ms=(time.perf_counter() - t0) * 1000.0,
@@ -60,5 +83,9 @@ class Generator:
60
  )
61
 
62
  return StageResult(
63
- ok=True, data={"sql": sql, "rationale": rationale}, trace=trace
 
 
 
 
64
  )
 
1
  from __future__ import annotations
2
+
3
  import time
4
  from typing import Optional, Dict, Any
5
+
6
  from adapters.llm.base import LLMProvider
7
+ from nl2sql.errors.codes import ErrorCode
8
+ from nl2sql.types import StageResult, StageTrace
9
 
10
 
11
  class Generator:
 
23
  clarify_answers: Optional[Dict[str, Any]] = None,
24
  ) -> StageResult:
25
  t0 = time.perf_counter()
26
+
27
  try:
28
  res = self.llm.generate_sql(
29
  user_query=user_query,
 
32
  clarify_answers=clarify_answers or {},
33
  )
34
  except Exception as e:
35
+ # Provider/transport errors or unexpected runtime issues.
36
+ return StageResult(
37
+ ok=False,
38
+ error=[f"Generator failed: {e}"],
39
+ error_code=ErrorCode.LLM_BAD_OUTPUT,
40
+ trace=None,
41
+ )
42
 
43
+ # Contract: expect a 5-tuple (sql, rationale, token_in, token_out, cost_usd)
44
  if not isinstance(res, tuple) or len(res) != 5:
45
  return StageResult(
46
  ok=False,
47
  error=[
48
  "Generator contract violation: expected 5-tuple (sql, rationale, t_in, t_out, cost)"
49
  ],
50
+ error_code=ErrorCode.LLM_BAD_OUTPUT,
51
+ trace=None,
52
  )
53
 
54
  sql, rationale, t_in, t_out, cost = res
 
56
  # Type/shape checks
57
  if not isinstance(sql, str) or not sql.strip():
58
  return StageResult(
59
+ ok=False,
60
+ error=["Generator produced empty or non-string SQL"],
61
+ error_code=ErrorCode.LLM_BAD_OUTPUT,
62
+ trace=None,
63
  )
64
+
65
+ # Enforce SELECT-only at the boundary (fast fail before hitting later stages).
66
  if not sql.lower().lstrip().startswith("select"):
67
+ return StageResult(
68
+ ok=False,
69
+ error=[f"Generated non-SELECT SQL: {sql}"],
70
+ error_code=ErrorCode.SAFETY_NON_SELECT,
71
+ trace=None,
72
+ )
73
 
74
+ # Normalize rationale to a string
75
+ rationale = rationale or ""
76
  trace = StageTrace(
77
  stage=self.name,
78
  duration_ms=(time.perf_counter() - t0) * 1000.0,
 
83
  )
84
 
85
  return StageResult(
86
+ ok=True,
87
+ data={"sql": sql, "rationale": rationale},
88
+ trace=trace,
89
+ error_code=None,
90
+ retryable=None,
91
  )
nl2sql/pipeline.py CHANGED
@@ -3,6 +3,7 @@ import traceback
3
  from dataclasses import dataclass
4
  from typing import Dict, Any, Optional, List
5
  import time
 
6
 
7
  from nl2sql.types import StageResult
8
  from nl2sql.ambiguity_detector import AmbiguityDetector
@@ -239,6 +240,25 @@ class Pipeline:
239
  "schema_preview": kwargs.get("schema_preview", ""),
240
  }
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  def run(
243
  self,
244
  *,
@@ -315,6 +335,7 @@ class Pipeline:
315
  ambiguous=False,
316
  error=True,
317
  details=r_plan.error,
 
318
  questions=None,
319
  sql=None,
320
  rationale=None,
@@ -347,6 +368,7 @@ class Pipeline:
347
  ambiguous=False,
348
  error=True,
349
  details=r_gen.error,
 
350
  questions=None,
351
  sql=None,
352
  rationale=None,
@@ -368,6 +390,7 @@ class Pipeline:
368
  ambiguous=False,
369
  error=True,
370
  details=["empty_sql"],
 
371
  questions=None,
372
  sql=None,
373
  rationale=rationale,
@@ -397,6 +420,7 @@ class Pipeline:
397
  ambiguous=False,
398
  error=True,
399
  details=r_safe.error,
 
400
  questions=None,
401
  sql=sql,
402
  rationale=rationale,
@@ -431,12 +455,11 @@ class Pipeline:
431
  t0 = time.perf_counter()
432
  r_ver = self._run_with_repair(
433
  "verifier",
434
- self.verifier.run,
435
  repair_input_builder=self._sql_repair_input_builder,
436
  max_attempts=1,
437
  sql=sql,
438
  exec_result=(r_exec.data or {}),
439
- adapter=getattr(self.executor, "adapter", None),
440
  traces=traces,
441
  )
442
  dt = (time.perf_counter() - t0) * 1000.0
@@ -522,10 +545,9 @@ class Pipeline:
522
  # verifier again
523
  t0 = time.perf_counter()
524
  r_ver2 = self._safe_stage(
525
- self.verifier.run,
526
  sql=sql,
527
  exec_result=(r_exec2.data or {}),
528
- adapter=getattr(self.executor, "adapter", None),
529
  )
530
  dt2 = (time.perf_counter() - t0) * 1000.0
531
  stage_duration_ms.labels("verifier").observe(dt2)
 
3
  from dataclasses import dataclass
4
  from typing import Dict, Any, Optional, List
5
  import time
6
+ import inspect
7
 
8
  from nl2sql.types import StageResult
9
  from nl2sql.ambiguity_detector import AmbiguityDetector
 
240
  "schema_preview": kwargs.get("schema_preview", ""),
241
  }
242
 
243
+ def _call_verifier(self, *, sql: str, exec_result: Dict[str, Any]) -> StageResult:
244
+ """
245
+ Call verifier with a backward-compatible signature.
246
+ Some verifiers accept `adapter=...`, some don't.
247
+ """
248
+ kwargs: Dict[str, Any] = {"sql": sql, "exec_result": exec_result}
249
+
250
+ adapter = getattr(self.executor, "adapter", None)
251
+ if adapter is not None:
252
+ try:
253
+ params = inspect.signature(self.verifier.run).parameters
254
+ if "adapter" in params:
255
+ kwargs["adapter"] = adapter
256
+ except (TypeError, ValueError):
257
+ # If signature introspection fails, fall back to the minimal call.
258
+ pass
259
+
260
+ return self.verifier.run(**kwargs)
261
+
262
  def run(
263
  self,
264
  *,
 
335
  ambiguous=False,
336
  error=True,
337
  details=r_plan.error,
338
+ error_code=ErrorCode.PIPELINE_CRASH,
339
  questions=None,
340
  sql=None,
341
  rationale=None,
 
368
  ambiguous=False,
369
  error=True,
370
  details=r_gen.error,
371
+ error_code=ErrorCode.LLM_BAD_OUTPUT,
372
  questions=None,
373
  sql=None,
374
  rationale=None,
 
390
  ambiguous=False,
391
  error=True,
392
  details=["empty_sql"],
393
+ error_code=ErrorCode.LLM_BAD_OUTPUT,
394
  questions=None,
395
  sql=None,
396
  rationale=rationale,
 
420
  ambiguous=False,
421
  error=True,
422
  details=r_safe.error,
423
+ error_code=r_safe.error_code,
424
  questions=None,
425
  sql=sql,
426
  rationale=rationale,
 
455
  t0 = time.perf_counter()
456
  r_ver = self._run_with_repair(
457
  "verifier",
458
+ self._call_verifier,
459
  repair_input_builder=self._sql_repair_input_builder,
460
  max_attempts=1,
461
  sql=sql,
462
  exec_result=(r_exec.data or {}),
 
463
  traces=traces,
464
  )
465
  dt = (time.perf_counter() - t0) * 1000.0
 
545
  # verifier again
546
  t0 = time.perf_counter()
547
  r_ver2 = self._safe_stage(
548
+ self._call_verifier,
549
  sql=sql,
550
  exec_result=(r_exec2.data or {}),
 
551
  )
552
  dt2 = (time.perf_counter() - t0) * 1000.0
553
  stage_duration_ms.labels("verifier").observe(dt2)