Spaces:
Running
Running
Melika Kheirieh
commited on
Commit
·
c1bc4eb
1
Parent(s):
646d80b
style: format code with ruff
Browse files- adapters/db/base.py +3 -1
- adapters/db/postgres_adapter.py +6 -2
- adapters/db/sqlite_adapter.py +1 -0
- adapters/llm/base.py +15 -4
- adapters/llm/openai_provider.py +33 -15
- app/main.py +7 -7
- app/routers/nl2sql.py +6 -3
- app/schemas.py +5 -0
- benchmarks/evaluate_spider.py +34 -14
- benchmarks/run.py +34 -15
- benchmarks/spider_loader.py +11 -8
- config.py +25 -7
- nl2sql/ambiguity_detector.py +3 -2
- nl2sql/executor.py +14 -5
- nl2sql/generator.py +22 -7
- nl2sql/pipeline.py +59 -25
- nl2sql/planner.py +13 -3
- nl2sql/repair.py +16 -6
- nl2sql/safety.py +17 -5
- nl2sql/stubs.py +9 -3
- nl2sql/types.py +2 -0
- nl2sql/verifier.py +22 -10
- tests/conftest.py +1 -1
- tests/test_ambiguity.py +7 -2
- tests/test_executor.py +2 -0
- tests/test_generator.py +7 -4
- tests/test_nl2sql_router.py +5 -1
- tests/test_openai_provider.py +15 -13
- tests/test_pipeline_integration.py +21 -18
- tests/test_safety.py +38 -20
- tests/test_stage_types.py +3 -0
- ui/benchmark_app.py +16 -6
adapters/db/base.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
from typing import Tuple, List, Dict, Any, Protocol
|
| 2 |
from typing import List, Tuple, Any
|
| 3 |
|
|
|
|
| 4 |
class DBAdapter(Protocol):
|
| 5 |
"""Abstract database adapter for read-only queries."""
|
|
|
|
| 6 |
name: str
|
| 7 |
dialect: str
|
| 8 |
|
|
@@ -10,4 +12,4 @@ class DBAdapter(Protocol):
|
|
| 10 |
"""Generate a readable summary of the database schema with optional sample rows per table."""
|
| 11 |
|
| 12 |
def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
|
| 13 |
-
"""Execute a SELECT query and return (rows, columns)."""
|
|
|
|
| 1 |
from typing import Tuple, List, Dict, Any, Protocol
|
| 2 |
from typing import List, Tuple, Any
|
| 3 |
|
| 4 |
+
|
| 5 |
class DBAdapter(Protocol):
|
| 6 |
"""Abstract database adapter for read-only queries."""
|
| 7 |
+
|
| 8 |
name: str
|
| 9 |
dialect: str
|
| 10 |
|
|
|
|
| 12 |
"""Generate a readable summary of the database schema with optional sample rows per table."""
|
| 13 |
|
| 14 |
def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
|
| 15 |
+
"""Execute a SELECT query and return (rows, columns)."""
|
adapters/db/postgres_adapter.py
CHANGED
|
@@ -2,6 +2,7 @@ import psycopg
|
|
| 2 |
from typing import Any, List, Tuple
|
| 3 |
from adapters.db.base import DBAdapter
|
| 4 |
|
|
|
|
| 5 |
class PostgresAdapter(DBAdapter):
|
| 6 |
name = "postgres"
|
| 7 |
dialect = "postgres"
|
|
@@ -24,11 +25,14 @@ class PostgresAdapter(DBAdapter):
|
|
| 24 |
tables = [t[0] for t in cur.fetchall()]
|
| 25 |
lines = []
|
| 26 |
for t in tables:
|
| 27 |
-
cur.execute(
|
|
|
|
| 28 |
SELECT column_name, data_type
|
| 29 |
FROM information_schema.columns
|
| 30 |
WHERE table_name = %s;
|
| 31 |
-
""",
|
|
|
|
|
|
|
| 32 |
cols = [f"{c[0]}:{c[1]}" for c in cur.fetchall()]
|
| 33 |
lines.append(f"- {t} ({', '.join(cols)})")
|
| 34 |
return "\n".join(lines)
|
|
|
|
| 2 |
from typing import Any, List, Tuple
|
| 3 |
from adapters.db.base import DBAdapter
|
| 4 |
|
| 5 |
+
|
| 6 |
class PostgresAdapter(DBAdapter):
|
| 7 |
name = "postgres"
|
| 8 |
dialect = "postgres"
|
|
|
|
| 25 |
tables = [t[0] for t in cur.fetchall()]
|
| 26 |
lines = []
|
| 27 |
for t in tables:
|
| 28 |
+
cur.execute(
|
| 29 |
+
f"""
|
| 30 |
SELECT column_name, data_type
|
| 31 |
FROM information_schema.columns
|
| 32 |
WHERE table_name = %s;
|
| 33 |
+
""",
|
| 34 |
+
(t,),
|
| 35 |
+
)
|
| 36 |
cols = [f"{c[0]}:{c[1]}" for c in cur.fetchall()]
|
| 37 |
lines.append(f"- {t} ({', '.join(cols)})")
|
| 38 |
return "\n".join(lines)
|
adapters/db/sqlite_adapter.py
CHANGED
|
@@ -2,6 +2,7 @@ import sqlite3
|
|
| 2 |
from typing import List, Tuple, Any
|
| 3 |
from adapters.db.base import DBAdapter
|
| 4 |
|
|
|
|
| 5 |
class SQLiteAdapter(DBAdapter):
|
| 6 |
name = "sqlite"
|
| 7 |
dialect = "sqlite"
|
|
|
|
| 2 |
from typing import List, Tuple, Any
|
| 3 |
from adapters.db.base import DBAdapter
|
| 4 |
|
| 5 |
+
|
| 6 |
class SQLiteAdapter(DBAdapter):
|
| 7 |
name = "sqlite"
|
| 8 |
dialect = "sqlite"
|
adapters/llm/base.py
CHANGED
|
@@ -2,15 +2,26 @@
|
|
| 2 |
from __future__ import annotations
|
| 3 |
from typing import Tuple, List, Dict, Any, Protocol
|
| 4 |
|
|
|
|
| 5 |
class LLMProvider(Protocol):
|
| 6 |
provider_id: str
|
| 7 |
|
| 8 |
-
def plan(
|
|
|
|
|
|
|
| 9 |
"""Return (plan_text, token_in, token_out, cost_usd)."""
|
| 10 |
|
| 11 |
-
def generate_sql(
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
"""Return (sql, rationale, token_in, token_out, cost_usd)."""
|
| 14 |
|
| 15 |
-
def repair(
|
|
|
|
|
|
|
| 16 |
"""Return (patched_sql, token_in, token_out, cost_usd)."""
|
|
|
|
| 2 |
from __future__ import annotations
|
| 3 |
from typing import Tuple, List, Dict, Any, Protocol
|
| 4 |
|
| 5 |
+
|
| 6 |
class LLMProvider(Protocol):
|
| 7 |
provider_id: str
|
| 8 |
|
| 9 |
+
def plan(
|
| 10 |
+
self, *, user_query: str, schema_preview: str
|
| 11 |
+
) -> Tuple[str, int, int, float]:
|
| 12 |
"""Return (plan_text, token_in, token_out, cost_usd)."""
|
| 13 |
|
| 14 |
+
def generate_sql(
|
| 15 |
+
self,
|
| 16 |
+
*,
|
| 17 |
+
user_query: str,
|
| 18 |
+
schema_preview: str,
|
| 19 |
+
plan_text: str,
|
| 20 |
+
clarify_answers: Dict[str, Any] | None = None,
|
| 21 |
+
) -> Tuple[str, str, int, int, float]:
|
| 22 |
"""Return (sql, rationale, token_in, token_out, cost_usd)."""
|
| 23 |
|
| 24 |
+
def repair(
|
| 25 |
+
self, *, sql: str, error_msg: str, schema_preview: str
|
| 26 |
+
) -> Tuple[str, int, int, float]:
|
| 27 |
"""Return (patched_sql, token_in, token_out, cost_usd)."""
|
adapters/llm/openai_provider.py
CHANGED
|
@@ -11,14 +11,13 @@ from openai import OpenAI
|
|
| 11 |
# - OPENAI_MODEL_ID (e.g., "gpt-4o-mini")
|
| 12 |
|
| 13 |
|
| 14 |
-
|
| 15 |
class OpenAIProvider(LLMProvider):
|
| 16 |
provider_id = "openai"
|
| 17 |
|
| 18 |
def __init__(self) -> None:
|
| 19 |
self.client = OpenAI(
|
| 20 |
api_key=os.environ["OPENAI_API_KEY"],
|
| 21 |
-
base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
| 22 |
)
|
| 23 |
self.model = os.getenv("OPENAI_MODEL_ID", "gpt-4o-mini")
|
| 24 |
|
|
@@ -27,16 +26,25 @@ class OpenAIProvider(LLMProvider):
|
|
| 27 |
model=self.model,
|
| 28 |
messages=[
|
| 29 |
{"role": "system", "content": "You create SQL query plans."},
|
| 30 |
-
{
|
|
|
|
|
|
|
|
|
|
| 31 |
],
|
| 32 |
-
temperature=0
|
| 33 |
)
|
| 34 |
msg = completion.choices[0].message.content
|
| 35 |
usage = completion.usage
|
| 36 |
-
return
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
def generate_sql(
|
|
|
|
|
|
|
| 40 |
prompt = f"""
|
| 41 |
You are a precise SQL generator.
|
| 42 |
Return ONLY valid JSON with two keys: "sql" and "rationale".
|
|
@@ -60,9 +68,9 @@ class OpenAIProvider(LLMProvider):
|
|
| 60 |
model=self.model,
|
| 61 |
messages=[
|
| 62 |
{"role": "system", "content": "You convert natural language to SQL."},
|
| 63 |
-
{"role": "user", "content": prompt}
|
| 64 |
],
|
| 65 |
-
temperature=0
|
| 66 |
)
|
| 67 |
content = completion.choices[0].message.content.strip()
|
| 68 |
usage = completion.usage # ← لازم داریم
|
|
@@ -78,7 +86,7 @@ class OpenAIProvider(LLMProvider):
|
|
| 78 |
end = content.rfind("}")
|
| 79 |
if start != -1 and end != -1:
|
| 80 |
try:
|
| 81 |
-
parsed = json.loads(content[start:end + 1])
|
| 82 |
except Exception:
|
| 83 |
raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
|
| 84 |
else:
|
|
@@ -93,19 +101,29 @@ class OpenAIProvider(LLMProvider):
|
|
| 93 |
# IMPORTANT: return the expected 5-tuple
|
| 94 |
return sql, rationale, t_in, t_out, cost
|
| 95 |
|
| 96 |
-
|
| 97 |
def repair(self, *, sql, error_msg, schema_preview):
|
| 98 |
completion = self.client.chat.completions.create(
|
| 99 |
model=self.model,
|
| 100 |
messages=[
|
| 101 |
-
{
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
],
|
| 104 |
-
temperature=0
|
| 105 |
)
|
| 106 |
msg = completion.choices[0].message.content
|
| 107 |
usage = completion.usage
|
| 108 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
def _estimate_cost(self, usage):
|
| 111 |
# Rough estimation example — can be refined with official token pricing
|
|
|
|
| 11 |
# - OPENAI_MODEL_ID (e.g., "gpt-4o-mini")
|
| 12 |
|
| 13 |
|
|
|
|
| 14 |
class OpenAIProvider(LLMProvider):
|
| 15 |
provider_id = "openai"
|
| 16 |
|
| 17 |
def __init__(self) -> None:
|
| 18 |
self.client = OpenAI(
|
| 19 |
api_key=os.environ["OPENAI_API_KEY"],
|
| 20 |
+
base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
|
| 21 |
)
|
| 22 |
self.model = os.getenv("OPENAI_MODEL_ID", "gpt-4o-mini")
|
| 23 |
|
|
|
|
| 26 |
model=self.model,
|
| 27 |
messages=[
|
| 28 |
{"role": "system", "content": "You create SQL query plans."},
|
| 29 |
+
{
|
| 30 |
+
"role": "user",
|
| 31 |
+
"content": f"Query: {user_query}\nSchema:\n{schema_preview}",
|
| 32 |
+
},
|
| 33 |
],
|
| 34 |
+
temperature=0,
|
| 35 |
)
|
| 36 |
msg = completion.choices[0].message.content
|
| 37 |
usage = completion.usage
|
| 38 |
+
return (
|
| 39 |
+
msg,
|
| 40 |
+
usage.prompt_tokens,
|
| 41 |
+
usage.completion_tokens,
|
| 42 |
+
self._estimate_cost(usage),
|
| 43 |
+
)
|
| 44 |
|
| 45 |
+
def generate_sql(
|
| 46 |
+
self, *, user_query, schema_preview, plan_text, clarify_answers=None
|
| 47 |
+
):
|
| 48 |
prompt = f"""
|
| 49 |
You are a precise SQL generator.
|
| 50 |
Return ONLY valid JSON with two keys: "sql" and "rationale".
|
|
|
|
| 68 |
model=self.model,
|
| 69 |
messages=[
|
| 70 |
{"role": "system", "content": "You convert natural language to SQL."},
|
| 71 |
+
{"role": "user", "content": prompt},
|
| 72 |
],
|
| 73 |
+
temperature=0,
|
| 74 |
)
|
| 75 |
content = completion.choices[0].message.content.strip()
|
| 76 |
usage = completion.usage # ← لازم داریم
|
|
|
|
| 86 |
end = content.rfind("}")
|
| 87 |
if start != -1 and end != -1:
|
| 88 |
try:
|
| 89 |
+
parsed = json.loads(content[start : end + 1])
|
| 90 |
except Exception:
|
| 91 |
raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
|
| 92 |
else:
|
|
|
|
| 101 |
# IMPORTANT: return the expected 5-tuple
|
| 102 |
return sql, rationale, t_in, t_out, cost
|
| 103 |
|
|
|
|
| 104 |
def repair(self, *, sql, error_msg, schema_preview):
|
| 105 |
completion = self.client.chat.completions.create(
|
| 106 |
model=self.model,
|
| 107 |
messages=[
|
| 108 |
+
{
|
| 109 |
+
"role": "system",
|
| 110 |
+
"content": "You fix SQL queries keeping them SELECT-only.",
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"role": "user",
|
| 114 |
+
"content": f"SQL:\n{sql}\nError:\n{error_msg}\nSchema:\n{schema_preview}",
|
| 115 |
+
},
|
| 116 |
],
|
| 117 |
+
temperature=0,
|
| 118 |
)
|
| 119 |
msg = completion.choices[0].message.content
|
| 120 |
usage = completion.usage
|
| 121 |
+
return (
|
| 122 |
+
msg,
|
| 123 |
+
usage.prompt_tokens,
|
| 124 |
+
usage.completion_tokens,
|
| 125 |
+
self._estimate_cost(usage),
|
| 126 |
+
)
|
| 127 |
|
| 128 |
def _estimate_cost(self, usage):
|
| 129 |
# Rough estimation example — can be refined with official token pricing
|
app/main.py
CHANGED
|
@@ -1,29 +1,29 @@
|
|
| 1 |
from dotenv import load_dotenv
|
|
|
|
| 2 |
load_dotenv()
|
| 3 |
|
| 4 |
from fastapi import FastAPI
|
| 5 |
from app.routers import nl2sql
|
|
|
|
| 6 |
app = FastAPI(
|
| 7 |
title="NL2SQL Copilot Prototype",
|
| 8 |
version="0.1.0",
|
| 9 |
-
description="Natural Language -> SQL Copilot API"
|
| 10 |
)
|
| 11 |
|
| 12 |
app.include_router(nl2sql.router, prefix="/api/v1")
|
| 13 |
|
|
|
|
| 14 |
@app.get("/healthz")
|
| 15 |
def health_check():
|
| 16 |
return {"status": "ok"}
|
| 17 |
|
|
|
|
| 18 |
@app.get("/")
|
| 19 |
def root():
|
| 20 |
return {"status": "ok", "message": "NL2SQL Copilot API is running"}
|
| 21 |
|
|
|
|
| 22 |
@app.get("/health")
|
| 23 |
def health():
|
| 24 |
-
return {
|
| 25 |
-
"status": "ok",
|
| 26 |
-
"db": "connected",
|
| 27 |
-
"llm": "reachable",
|
| 28 |
-
"uptime_sec": 123.4
|
| 29 |
-
}
|
|
|
|
| 1 |
from dotenv import load_dotenv
|
| 2 |
+
|
| 3 |
load_dotenv()
|
| 4 |
|
| 5 |
from fastapi import FastAPI
|
| 6 |
from app.routers import nl2sql
|
| 7 |
+
|
| 8 |
app = FastAPI(
|
| 9 |
title="NL2SQL Copilot Prototype",
|
| 10 |
version="0.1.0",
|
| 11 |
+
description="Natural Language -> SQL Copilot API",
|
| 12 |
)
|
| 13 |
|
| 14 |
app.include_router(nl2sql.router, prefix="/api/v1")
|
| 15 |
|
| 16 |
+
|
| 17 |
@app.get("/healthz")
|
| 18 |
def health_check():
|
| 19 |
return {"status": "ok"}
|
| 20 |
|
| 21 |
+
|
| 22 |
@app.get("/")
|
| 23 |
def root():
|
| 24 |
return {"status": "ok", "message": "NL2SQL Copilot API is running"}
|
| 25 |
|
| 26 |
+
|
| 27 |
@app.get("/health")
|
| 28 |
def health():
|
| 29 |
+
return {"status": "ok", "db": "connected", "llm": "reachable", "uptime_sec": 123.4}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/routers/nl2sql.py
CHANGED
|
@@ -19,7 +19,6 @@ import os
|
|
| 19 |
router = APIRouter(prefix="/nl2sql")
|
| 20 |
|
| 21 |
|
| 22 |
-
|
| 23 |
if os.getenv("DB_MODE", "sqlite") == "postgres":
|
| 24 |
_db = PostgresAdapter(os.environ["POSTGRES_DSN"])
|
| 25 |
else:
|
|
@@ -40,7 +39,7 @@ _pipeline = Pipeline(
|
|
| 40 |
safety=Safety(),
|
| 41 |
executor=_executor,
|
| 42 |
verifier=_verifier,
|
| 43 |
-
repair=_repair
|
| 44 |
)
|
| 45 |
|
| 46 |
|
|
@@ -48,6 +47,7 @@ def _to_dict(obj):
|
|
| 48 |
"""Helper: safely convert dataclass → dict."""
|
| 49 |
return asdict(obj) if is_dataclass(obj) else obj
|
| 50 |
|
|
|
|
| 51 |
def _round_trace(t: dict) -> dict:
|
| 52 |
if t.get("cost_usd") is not None:
|
| 53 |
t["cost_usd"] = round(t["cost_usd"], 6)
|
|
@@ -55,9 +55,12 @@ def _round_trace(t: dict) -> dict:
|
|
| 55 |
t["duration_ms"] = round(t["duration_ms"], 2)
|
| 56 |
return t
|
| 57 |
|
|
|
|
| 58 |
@router.post("", name="nl2sql_handler")
|
| 59 |
def nl2sql_handler(request: NL2SQLRequest):
|
| 60 |
-
result = _pipeline.run(
|
|
|
|
|
|
|
| 61 |
|
| 62 |
# --- Ensure result type ---
|
| 63 |
if not isinstance(result, StageResult):
|
|
|
|
| 19 |
router = APIRouter(prefix="/nl2sql")
|
| 20 |
|
| 21 |
|
|
|
|
| 22 |
if os.getenv("DB_MODE", "sqlite") == "postgres":
|
| 23 |
_db = PostgresAdapter(os.environ["POSTGRES_DSN"])
|
| 24 |
else:
|
|
|
|
| 39 |
safety=Safety(),
|
| 40 |
executor=_executor,
|
| 41 |
verifier=_verifier,
|
| 42 |
+
repair=_repair,
|
| 43 |
)
|
| 44 |
|
| 45 |
|
|
|
|
| 47 |
"""Helper: safely convert dataclass → dict."""
|
| 48 |
return asdict(obj) if is_dataclass(obj) else obj
|
| 49 |
|
| 50 |
+
|
| 51 |
def _round_trace(t: dict) -> dict:
|
| 52 |
if t.get("cost_usd") is not None:
|
| 53 |
t["cost_usd"] = round(t["cost_usd"], 6)
|
|
|
|
| 55 |
t["duration_ms"] = round(t["duration_ms"], 2)
|
| 56 |
return t
|
| 57 |
|
| 58 |
+
|
| 59 |
@router.post("", name="nl2sql_handler")
|
| 60 |
def nl2sql_handler(request: NL2SQLRequest):
|
| 61 |
+
result = _pipeline.run(
|
| 62 |
+
user_query=request.query, schema_preview=request.schema_preview
|
| 63 |
+
)
|
| 64 |
|
| 65 |
# --- Ensure result type ---
|
| 66 |
if not isinstance(result, StageResult):
|
app/schemas.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
| 1 |
from pydantic import BaseModel
|
| 2 |
from typing import List, Optional, Any, Dict
|
| 3 |
|
|
|
|
| 4 |
class NL2SQLRequest(BaseModel):
|
| 5 |
query: str
|
| 6 |
schema_preview: str
|
| 7 |
db_name: Optional[str] = "default"
|
| 8 |
|
|
|
|
| 9 |
class TraceModel(BaseModel):
|
| 10 |
stage: str
|
| 11 |
duration_ms: float
|
|
@@ -14,16 +16,19 @@ class TraceModel(BaseModel):
|
|
| 14 |
cost_usd: float | None = 0
|
| 15 |
notes: Dict[str, Any] | None = None
|
| 16 |
|
|
|
|
| 17 |
class NL2SQLResponse(BaseModel):
|
| 18 |
ambiguous: bool = False
|
| 19 |
sql: str
|
| 20 |
rationale: Optional[str] = None
|
| 21 |
traces: List[TraceModel] = []
|
| 22 |
|
|
|
|
| 23 |
class ClarifyResponse(BaseModel):
|
| 24 |
ambiguous: bool = True
|
| 25 |
questions: List[str]
|
| 26 |
|
|
|
|
| 27 |
class ErrorResponse(BaseModel):
|
| 28 |
error: str
|
| 29 |
details: List[str] | None = None
|
|
|
|
| 1 |
from pydantic import BaseModel
|
| 2 |
from typing import List, Optional, Any, Dict
|
| 3 |
|
| 4 |
+
|
| 5 |
class NL2SQLRequest(BaseModel):
|
| 6 |
query: str
|
| 7 |
schema_preview: str
|
| 8 |
db_name: Optional[str] = "default"
|
| 9 |
|
| 10 |
+
|
| 11 |
class TraceModel(BaseModel):
|
| 12 |
stage: str
|
| 13 |
duration_ms: float
|
|
|
|
| 16 |
cost_usd: float | None = 0
|
| 17 |
notes: Dict[str, Any] | None = None
|
| 18 |
|
| 19 |
+
|
| 20 |
class NL2SQLResponse(BaseModel):
|
| 21 |
ambiguous: bool = False
|
| 22 |
sql: str
|
| 23 |
rationale: Optional[str] = None
|
| 24 |
traces: List[TraceModel] = []
|
| 25 |
|
| 26 |
+
|
| 27 |
class ClarifyResponse(BaseModel):
|
| 28 |
ambiguous: bool = True
|
| 29 |
questions: List[str]
|
| 30 |
|
| 31 |
+
|
| 32 |
class ErrorResponse(BaseModel):
|
| 33 |
error: str
|
| 34 |
details: List[str] | None = None
|
benchmarks/evaluate_spider.py
CHANGED
|
@@ -13,16 +13,19 @@ from sqlglot.errors import ParseError
|
|
| 13 |
LOG_DIR = Path("logs/spider_eval")
|
| 14 |
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
| 15 |
|
|
|
|
| 16 |
def normalize_sql(sql: str) -> str:
|
| 17 |
# نسخه ساده؛ میتونی قویترش کنی با پارس + بازسازی
|
| 18 |
return " ".join(sql.lower().strip().split())
|
| 19 |
|
|
|
|
| 20 |
def compare_results(pred_rows, gold_rows):
|
| 21 |
if pred_rows is None or gold_rows is None:
|
| 22 |
return False
|
| 23 |
# اگر ترتیب مهم نیست
|
| 24 |
return set(pred_rows) == set(gold_rows)
|
| 25 |
|
|
|
|
| 26 |
def try_execute_sql(sql_db, sql, timeout: float = None):
|
| 27 |
start = time.time()
|
| 28 |
try:
|
|
@@ -31,6 +34,7 @@ def try_execute_sql(sql_db, sql, timeout: float = None):
|
|
| 31 |
except Exception as e:
|
| 32 |
return None, time.time() - start, str(e)
|
| 33 |
|
|
|
|
| 34 |
def exact_match_structural(sql_pred: str, sql_gold: str) -> bool:
|
| 35 |
try:
|
| 36 |
ast_pred = parse_one(sql_pred)
|
|
@@ -54,13 +58,19 @@ def exact_match_structural(sql_pred: str, sql_gold: str) -> bool:
|
|
| 54 |
norm_gold = normalize_ast(ast_gold)
|
| 55 |
return norm_prd == norm_gold
|
| 56 |
|
|
|
|
| 57 |
def get_git_commit_hash() -> str:
|
| 58 |
try:
|
| 59 |
-
out =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
return out
|
| 61 |
except Exception:
|
| 62 |
return "UNKNOWN"
|
| 63 |
|
|
|
|
| 64 |
FORBIDDEN_NODES = (
|
| 65 |
exp.Insert,
|
| 66 |
exp.Delete,
|
|
@@ -72,6 +82,7 @@ FORBIDDEN_NODES = (
|
|
| 72 |
exp.Create,
|
| 73 |
)
|
| 74 |
|
|
|
|
| 75 |
def is_safe_sql(sql: str, dialect: str | None = None) -> bool:
|
| 76 |
try:
|
| 77 |
ast = parse_one(sql, read=dialect)
|
|
@@ -84,6 +95,7 @@ def is_safe_sql(sql: str, dialect: str | None = None) -> bool:
|
|
| 84 |
return False
|
| 85 |
return True
|
| 86 |
|
|
|
|
| 87 |
def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
|
| 88 |
data = load_spider_sqlite(split)
|
| 89 |
if len(data) < limit:
|
|
@@ -94,8 +106,8 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
|
|
| 94 |
commit_hash = get_git_commit_hash()
|
| 95 |
start_ts = int(time.time())
|
| 96 |
|
| 97 |
-
pred_txt
|
| 98 |
-
gold_txt
|
| 99 |
results_fn = LOG_DIR / f"{split}_results_{start_ts}.jsonl"
|
| 100 |
metrics_fn = LOG_DIR / f"{split}_metrics_{start_ts}.json"
|
| 101 |
|
|
@@ -112,10 +124,11 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
|
|
| 112 |
pass
|
| 113 |
|
| 114 |
write_header = not results_fn.exists()
|
| 115 |
-
with
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
| 119 |
if write_header:
|
| 120 |
header = {
|
| 121 |
"commit_hash": commit_hash,
|
|
@@ -228,21 +241,28 @@ def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
|
|
| 228 |
if sleep_time > 0:
|
| 229 |
time.sleep(sleep_time)
|
| 230 |
|
| 231 |
-
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
| 233 |
total_valid = len(valid)
|
| 234 |
total_all = len(agg)
|
| 235 |
if total_valid == 0:
|
| 236 |
print("No valid examples to compute metrics")
|
| 237 |
return
|
| 238 |
|
| 239 |
-
em_count
|
| 240 |
em_struct_count = sum(1 for r in valid if r["exact_match_structural"])
|
| 241 |
-
exec_acc_count
|
| 242 |
-
error_count
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
safe_fail_count = sum(1 for r in agg if r.get("safe_check_failed", False))
|
| 244 |
-
avg_gen_time
|
| 245 |
-
avg_exec_time
|
| 246 |
|
| 247 |
metrics = {
|
| 248 |
"commit_hash": commit_hash,
|
|
|
|
| 13 |
LOG_DIR = Path("logs/spider_eval")
|
| 14 |
LOG_DIR.mkdir(parents=True, exist_ok=True)
|
| 15 |
|
| 16 |
+
|
| 17 |
def normalize_sql(sql: str) -> str:
|
| 18 |
# نسخه ساده؛ میتونی قویترش کنی با پارس + بازسازی
|
| 19 |
return " ".join(sql.lower().strip().split())
|
| 20 |
|
| 21 |
+
|
| 22 |
def compare_results(pred_rows, gold_rows):
|
| 23 |
if pred_rows is None or gold_rows is None:
|
| 24 |
return False
|
| 25 |
# اگر ترتیب مهم نیست
|
| 26 |
return set(pred_rows) == set(gold_rows)
|
| 27 |
|
| 28 |
+
|
| 29 |
def try_execute_sql(sql_db, sql, timeout: float = None):
|
| 30 |
start = time.time()
|
| 31 |
try:
|
|
|
|
| 34 |
except Exception as e:
|
| 35 |
return None, time.time() - start, str(e)
|
| 36 |
|
| 37 |
+
|
| 38 |
def exact_match_structural(sql_pred: str, sql_gold: str) -> bool:
|
| 39 |
try:
|
| 40 |
ast_pred = parse_one(sql_pred)
|
|
|
|
| 58 |
norm_gold = normalize_ast(ast_gold)
|
| 59 |
return norm_prd == norm_gold
|
| 60 |
|
| 61 |
+
|
| 62 |
def get_git_commit_hash() -> str:
|
| 63 |
try:
|
| 64 |
+
out = (
|
| 65 |
+
subprocess.check_output(["git", "rev-parse", "HEAD"])
|
| 66 |
+
.strip()
|
| 67 |
+
.decode("ascii")
|
| 68 |
+
)
|
| 69 |
return out
|
| 70 |
except Exception:
|
| 71 |
return "UNKNOWN"
|
| 72 |
|
| 73 |
+
|
| 74 |
FORBIDDEN_NODES = (
|
| 75 |
exp.Insert,
|
| 76 |
exp.Delete,
|
|
|
|
| 82 |
exp.Create,
|
| 83 |
)
|
| 84 |
|
| 85 |
+
|
| 86 |
def is_safe_sql(sql: str, dialect: str | None = None) -> bool:
|
| 87 |
try:
|
| 88 |
ast = parse_one(sql, read=dialect)
|
|
|
|
| 95 |
return False
|
| 96 |
return True
|
| 97 |
|
| 98 |
+
|
| 99 |
def run_eval(split="dev", limit=100, resume=True, sleep_time: float = 0.01):
|
| 100 |
data = load_spider_sqlite(split)
|
| 101 |
if len(data) < limit:
|
|
|
|
| 106 |
commit_hash = get_git_commit_hash()
|
| 107 |
start_ts = int(time.time())
|
| 108 |
|
| 109 |
+
pred_txt = LOG_DIR / f"{split}_pred_{start_ts}.txt"
|
| 110 |
+
gold_txt = LOG_DIR / f"{split}_gold_{start_ts}.txt"
|
| 111 |
results_fn = LOG_DIR / f"{split}_results_{start_ts}.jsonl"
|
| 112 |
metrics_fn = LOG_DIR / f"{split}_metrics_{start_ts}.json"
|
| 113 |
|
|
|
|
| 124 |
pass
|
| 125 |
|
| 126 |
write_header = not results_fn.exists()
|
| 127 |
+
with (
|
| 128 |
+
results_fn.open("a", encoding="utf-8") as fout,
|
| 129 |
+
pred_txt.open("a", encoding="utf-8") as fpred,
|
| 130 |
+
gold_txt.open("a", encoding="utf-8") as fgold,
|
| 131 |
+
):
|
| 132 |
if write_header:
|
| 133 |
header = {
|
| 134 |
"commit_hash": commit_hash,
|
|
|
|
| 241 |
if sleep_time > 0:
|
| 242 |
time.sleep(sleep_time)
|
| 243 |
|
| 244 |
+
valid = [
|
| 245 |
+
r
|
| 246 |
+
for r in agg
|
| 247 |
+
if (not r.get("safe_check_failed", False)) and r.get("gold_error") is None
|
| 248 |
+
]
|
| 249 |
total_valid = len(valid)
|
| 250 |
total_all = len(agg)
|
| 251 |
if total_valid == 0:
|
| 252 |
print("No valid examples to compute metrics")
|
| 253 |
return
|
| 254 |
|
| 255 |
+
em_count = sum(1 for r in valid if r["exact_match"])
|
| 256 |
em_struct_count = sum(1 for r in valid if r["exact_match_structural"])
|
| 257 |
+
exec_acc_count = sum(1 for r in valid if r["execution_accuracy"])
|
| 258 |
+
error_count = sum(
|
| 259 |
+
1
|
| 260 |
+
for r in agg
|
| 261 |
+
if (r.get("error") is not None) and (not r.get("safe_check_failed", False))
|
| 262 |
+
)
|
| 263 |
safe_fail_count = sum(1 for r in agg if r.get("safe_check_failed", False))
|
| 264 |
+
avg_gen_time = sum(r["gen_time"] for r in valid) / total_valid
|
| 265 |
+
avg_exec_time = sum(r["exec_time"] for r in valid) / total_valid
|
| 266 |
|
| 267 |
metrics = {
|
| 268 |
"commit_hash": commit_hash,
|
benchmarks/run.py
CHANGED
|
@@ -20,6 +20,7 @@ from nl2sql.repair import Repair
|
|
| 20 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 21 |
from adapters.llm.openai_provider import OpenAIProvider
|
| 22 |
|
|
|
|
| 23 |
# ---- fallbacks: Dummy LLM (so it runs without API keys)
|
| 24 |
class DummyLLM:
|
| 25 |
provider_id = "dummy-llm"
|
|
@@ -28,7 +29,14 @@ class DummyLLM:
|
|
| 28 |
text = f"- understand question: {user_query}\n- identify tables\n- join if needed\n- filter\n- order/limit"
|
| 29 |
return text, 0, 0, 0.0
|
| 30 |
|
| 31 |
-
def generate_sql(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
# naive demo SQL (so pipeline flows end-to-end)
|
| 33 |
sql = "SELECT 1 AS one;"
|
| 34 |
rationale = "Demo SQL from DummyLLM"
|
|
@@ -43,12 +51,15 @@ def ensure_demo_db(path: Path) -> None:
|
|
| 43 |
if path.exists():
|
| 44 |
return
|
| 45 |
import sqlite3
|
|
|
|
| 46 |
path.parent.mkdir(parents=True, exist_ok=True)
|
| 47 |
con = sqlite3.connect(path)
|
| 48 |
cur = con.cursor()
|
| 49 |
cur.execute("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, spend REAL);")
|
| 50 |
-
cur.executemany(
|
| 51 |
-
|
|
|
|
|
|
|
| 52 |
con.commit()
|
| 53 |
con.close()
|
| 54 |
|
|
@@ -86,7 +97,7 @@ def run_benchmark(queries, schema_preview, pipeline: Pipeline, outfile: Path):
|
|
| 86 |
for q in queries:
|
| 87 |
t0 = time.perf_counter()
|
| 88 |
r = pipeline.run(user_query=q, schema_preview=schema_preview)
|
| 89 |
-
latency_ms = (time.perf_counter()-t0)*1000
|
| 90 |
ok = (not r.get("ambiguous")) and ("error" not in r)
|
| 91 |
|
| 92 |
traces = r.get("traces", [])
|
|
@@ -97,15 +108,19 @@ def run_benchmark(queries, schema_preview, pipeline: Pipeline, outfile: Path):
|
|
| 97 |
except Exception:
|
| 98 |
pass
|
| 99 |
|
| 100 |
-
results.append(
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
outfile.parent.mkdir(parents=True, exist_ok=True)
|
| 111 |
with open(outfile, "w") as f:
|
|
@@ -118,10 +133,14 @@ def main():
|
|
| 118 |
parser = argparse.ArgumentParser()
|
| 119 |
parser.add_argument("--outfile", default="benchmarks/results/demo.jsonl")
|
| 120 |
parser.add_argument("--db", default="data/bench_demo.db")
|
| 121 |
-
parser.add_argument(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
args = parser.parse_args()
|
| 123 |
|
| 124 |
-
ROOT = Path(__file__).resolve().parents[1]
|
| 125 |
outfile = (ROOT / args.outfile).resolve()
|
| 126 |
db_path = (ROOT / args.db).resolve()
|
| 127 |
|
|
|
|
| 20 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 21 |
from adapters.llm.openai_provider import OpenAIProvider
|
| 22 |
|
| 23 |
+
|
| 24 |
# ---- fallbacks: Dummy LLM (so it runs without API keys)
|
| 25 |
class DummyLLM:
|
| 26 |
provider_id = "dummy-llm"
|
|
|
|
| 29 |
text = f"- understand question: {user_query}\n- identify tables\n- join if needed\n- filter\n- order/limit"
|
| 30 |
return text, 0, 0, 0.0
|
| 31 |
|
| 32 |
+
def generate_sql(
|
| 33 |
+
self,
|
| 34 |
+
*,
|
| 35 |
+
user_query: str,
|
| 36 |
+
schema_preview: str,
|
| 37 |
+
plan_text: str,
|
| 38 |
+
clarify_answers=None,
|
| 39 |
+
):
|
| 40 |
# naive demo SQL (so pipeline flows end-to-end)
|
| 41 |
sql = "SELECT 1 AS one;"
|
| 42 |
rationale = "Demo SQL from DummyLLM"
|
|
|
|
| 51 |
if path.exists():
|
| 52 |
return
|
| 53 |
import sqlite3
|
| 54 |
+
|
| 55 |
path.parent.mkdir(parents=True, exist_ok=True)
|
| 56 |
con = sqlite3.connect(path)
|
| 57 |
cur = con.cursor()
|
| 58 |
cur.execute("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, spend REAL);")
|
| 59 |
+
cur.executemany(
|
| 60 |
+
"INSERT INTO users(id,name,spend) VALUES(?,?,?)",
|
| 61 |
+
[(1, "Alice", 120.5), (2, "Bob", 80.0), (3, "Carol", 155.0)],
|
| 62 |
+
)
|
| 63 |
con.commit()
|
| 64 |
con.close()
|
| 65 |
|
|
|
|
| 97 |
for q in queries:
|
| 98 |
t0 = time.perf_counter()
|
| 99 |
r = pipeline.run(user_query=q, schema_preview=schema_preview)
|
| 100 |
+
latency_ms = (time.perf_counter() - t0) * 1000
|
| 101 |
ok = (not r.get("ambiguous")) and ("error" not in r)
|
| 102 |
|
| 103 |
traces = r.get("traces", [])
|
|
|
|
| 108 |
except Exception:
|
| 109 |
pass
|
| 110 |
|
| 111 |
+
results.append(
|
| 112 |
+
{
|
| 113 |
+
"query": q,
|
| 114 |
+
"exec_acc": 1.0 if ok else 0.0,
|
| 115 |
+
"safe_fail": 0.0 if ok else 1.0 if "unsafe" in str(r).lower() else 0.0,
|
| 116 |
+
"latency_ms": latency_ms,
|
| 117 |
+
"cost_usd": cost_sum,
|
| 118 |
+
"repair_attempts": sum(1 for t in traces if t.get("stage") == "repair"),
|
| 119 |
+
"provider": pipeline.generator.llm.provider_id
|
| 120 |
+
if hasattr(pipeline.generator, "llm")
|
| 121 |
+
else "unknown",
|
| 122 |
+
}
|
| 123 |
+
)
|
| 124 |
|
| 125 |
outfile.parent.mkdir(parents=True, exist_ok=True)
|
| 126 |
with open(outfile, "w") as f:
|
|
|
|
| 133 |
parser = argparse.ArgumentParser()
|
| 134 |
parser.add_argument("--outfile", default="benchmarks/results/demo.jsonl")
|
| 135 |
parser.add_argument("--db", default="data/bench_demo.db")
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--use-openai",
|
| 138 |
+
action="store_true",
|
| 139 |
+
help="Use OpenAI provider if API key present",
|
| 140 |
+
)
|
| 141 |
args = parser.parse_args()
|
| 142 |
|
| 143 |
+
ROOT = Path(__file__).resolve().parents[1] # project root
|
| 144 |
outfile = (ROOT / args.outfile).resolve()
|
| 145 |
db_path = (ROOT / args.db).resolve()
|
| 146 |
|
benchmarks/spider_loader.py
CHANGED
|
@@ -4,9 +4,8 @@ from dataclasses import dataclass
|
|
| 4 |
from typing import List, Optional
|
| 5 |
import os
|
| 6 |
|
| 7 |
-
SPIDER_ROOT = pathlib.Path(
|
| 8 |
-
|
| 9 |
-
)
|
| 10 |
|
| 11 |
@dataclass
|
| 12 |
class SpiderItem:
|
|
@@ -15,7 +14,10 @@ class SpiderItem:
|
|
| 15 |
gold_sql: str
|
| 16 |
db_path: pathlib.Path
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
| 19 |
fn = {"dev": "dev.json", "train": "train_spider.json"}[split]
|
| 20 |
json_path = SPIDER_ROOT / fn
|
| 21 |
try:
|
|
@@ -23,7 +25,6 @@ def load_spider_sqlite(split: str = "dev", limit: Optional[int] = None) -> List[
|
|
| 23 |
except Exception as e:
|
| 24 |
raise RuntimeError(f"Failed to read Spider split file: {json_path} ({e})")
|
| 25 |
|
| 26 |
-
|
| 27 |
out: list[SpiderItem] = []
|
| 28 |
for ex in items[: (limit or len(items))]:
|
| 29 |
db_id = ex["db_id"]
|
|
@@ -35,14 +36,16 @@ def load_spider_sqlite(split: str = "dev", limit: Optional[int] = None) -> List[
|
|
| 35 |
db_id=db_id,
|
| 36 |
question=ex["question"],
|
| 37 |
gold_sql=ex["query"],
|
| 38 |
-
db_path=db_path
|
| 39 |
)
|
| 40 |
)
|
| 41 |
return out
|
| 42 |
|
| 43 |
|
| 44 |
-
def open_readonly_connection(
|
|
|
|
|
|
|
| 45 |
uri = f"file:{db_path}?mode=ro&uri=true"
|
| 46 |
conn = sqlite3.connect(uri, uri=True, timeout=timeout)
|
| 47 |
conn.row_factory = sqlite3.Row
|
| 48 |
-
return conn
|
|
|
|
| 4 |
from typing import List, Optional
|
| 5 |
import os
|
| 6 |
|
| 7 |
+
SPIDER_ROOT = pathlib.Path(os.getenv("SPIDER_ROOT", "data/spider"))
|
| 8 |
+
|
|
|
|
| 9 |
|
| 10 |
@dataclass
|
| 11 |
class SpiderItem:
|
|
|
|
| 14 |
gold_sql: str
|
| 15 |
db_path: pathlib.Path
|
| 16 |
|
| 17 |
+
|
| 18 |
+
def load_spider_sqlite(
|
| 19 |
+
split: str = "dev", limit: Optional[int] = None
|
| 20 |
+
) -> List[SpiderItem]:
|
| 21 |
fn = {"dev": "dev.json", "train": "train_spider.json"}[split]
|
| 22 |
json_path = SPIDER_ROOT / fn
|
| 23 |
try:
|
|
|
|
| 25 |
except Exception as e:
|
| 26 |
raise RuntimeError(f"Failed to read Spider split file: {json_path} ({e})")
|
| 27 |
|
|
|
|
| 28 |
out: list[SpiderItem] = []
|
| 29 |
for ex in items[: (limit or len(items))]:
|
| 30 |
db_id = ex["db_id"]
|
|
|
|
| 36 |
db_id=db_id,
|
| 37 |
question=ex["question"],
|
| 38 |
gold_sql=ex["query"],
|
| 39 |
+
db_path=db_path,
|
| 40 |
)
|
| 41 |
)
|
| 42 |
return out
|
| 43 |
|
| 44 |
|
| 45 |
+
def open_readonly_connection(
|
| 46 |
+
db_path: pathlib.Path, timeout: float = 5.0
|
| 47 |
+
) -> sqlite3.Connection:
|
| 48 |
uri = f"file:{db_path}?mode=ro&uri=true"
|
| 49 |
conn = sqlite3.connect(uri, uri=True, timeout=timeout)
|
| 50 |
conn.row_factory = sqlite3.Row
|
| 51 |
+
return conn
|
config.py
CHANGED
|
@@ -4,12 +4,15 @@ from dotenv import load_dotenv
|
|
| 4 |
load_dotenv()
|
| 5 |
|
| 6 |
|
| 7 |
-
def get_env_var(
|
|
|
|
|
|
|
| 8 |
val = os.getenv(name, default)
|
| 9 |
if required and not val:
|
| 10 |
raise ValueError(f"Missing required environment variable: {name}")
|
| 11 |
return val
|
| 12 |
|
|
|
|
| 13 |
proxy_key = os.getenv("PROXY_API_KEY")
|
| 14 |
proxy_base = os.getenv("PROXY_BASE_URL")
|
| 15 |
openai_key = os.getenv("OPENAI_API_KEY")
|
|
@@ -17,7 +20,9 @@ openai_base = os.getenv("OPENAI_BASE_URL")
|
|
| 17 |
|
| 18 |
api_key = proxy_key or openai_key
|
| 19 |
if not api_key:
|
| 20 |
-
raise ValueError(
|
|
|
|
|
|
|
| 21 |
|
| 22 |
base_url = proxy_base or openai_base or "https://api.openai.com/v1"
|
| 23 |
|
|
@@ -33,11 +38,24 @@ LLM_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") # or gpt-4o, gpt-4o-mini
|
|
| 33 |
LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0"))
|
| 34 |
|
| 35 |
FORBIDDEN_KEYWORDS = {
|
| 36 |
-
"ATTACH",
|
| 37 |
-
"
|
| 38 |
-
"
|
| 39 |
-
"
|
| 40 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
"DETACH",
|
| 42 |
}
|
| 43 |
FORBIDDEN_TABLES = {"sqlite_master", "sqlite_temp_master"}
|
|
|
|
| 4 |
load_dotenv()
|
| 5 |
|
| 6 |
|
| 7 |
+
def get_env_var(
|
| 8 |
+
name: str, required: bool = True, default: str | None = None
|
| 9 |
+
) -> str | None:
|
| 10 |
val = os.getenv(name, default)
|
| 11 |
if required and not val:
|
| 12 |
raise ValueError(f"Missing required environment variable: {name}")
|
| 13 |
return val
|
| 14 |
|
| 15 |
+
|
| 16 |
proxy_key = os.getenv("PROXY_API_KEY")
|
| 17 |
proxy_base = os.getenv("PROXY_BASE_URL")
|
| 18 |
openai_key = os.getenv("OPENAI_API_KEY")
|
|
|
|
| 20 |
|
| 21 |
api_key = proxy_key or openai_key
|
| 22 |
if not api_key:
|
| 23 |
+
raise ValueError(
|
| 24 |
+
"Missing API key: set PROXY_API_KEY or OPENAI_API_KEY in environment/secrets."
|
| 25 |
+
)
|
| 26 |
|
| 27 |
base_url = proxy_base or openai_base or "https://api.openai.com/v1"
|
| 28 |
|
|
|
|
| 38 |
LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0"))
|
| 39 |
|
| 40 |
FORBIDDEN_KEYWORDS = {
|
| 41 |
+
"ATTACH",
|
| 42 |
+
"PRAGMA",
|
| 43 |
+
"CREATE",
|
| 44 |
+
"DROP",
|
| 45 |
+
"ALTER",
|
| 46 |
+
"VACUUM",
|
| 47 |
+
"REINDEX",
|
| 48 |
+
"TRIGGER",
|
| 49 |
+
"INSERT",
|
| 50 |
+
"UPDATE",
|
| 51 |
+
"DELETE",
|
| 52 |
+
"REPLACE",
|
| 53 |
+
"GRANT",
|
| 54 |
+
"REVOKE",
|
| 55 |
+
"BEGIN",
|
| 56 |
+
"END",
|
| 57 |
+
"COMMIT",
|
| 58 |
+
"ROLLBACK",
|
| 59 |
"DETACH",
|
| 60 |
}
|
| 61 |
FORBIDDEN_TABLES = {"sqlite_master", "sqlite_temp_master"}
|
nl2sql/ambiguity_detector.py
CHANGED
|
@@ -1,16 +1,17 @@
|
|
| 1 |
import re
|
| 2 |
from typing import List
|
| 3 |
|
|
|
|
| 4 |
class AmbiguityDetector:
|
| 5 |
"""Lightweight AmbiSQL-style ambiguity detection."""
|
| 6 |
|
| 7 |
AMBIGUOUS_TERMS = ["recent", "top", "name", "rank", "latest"]
|
| 8 |
|
| 9 |
-
def detect(self, query:str, schema_preview: str) -> list[str]:
|
| 10 |
hits = []
|
| 11 |
q_lower = query.lower()
|
| 12 |
for term in self.AMBIGUOUS_TERMS:
|
| 13 |
if re.search(rf"\b{term}\b", q_lower):
|
| 14 |
hits.append(f"The term '{term}' is ambiguous in this query.'")
|
| 15 |
|
| 16 |
-
return hits
|
|
|
|
| 1 |
import re
|
| 2 |
from typing import List
|
| 3 |
|
| 4 |
+
|
| 5 |
class AmbiguityDetector:
|
| 6 |
"""Lightweight AmbiSQL-style ambiguity detection."""
|
| 7 |
|
| 8 |
AMBIGUOUS_TERMS = ["recent", "top", "name", "rank", "latest"]
|
| 9 |
|
| 10 |
+
def detect(self, query: str, schema_preview: str) -> list[str]:
|
| 11 |
hits = []
|
| 12 |
q_lower = query.lower()
|
| 13 |
for term in self.AMBIGUOUS_TERMS:
|
| 14 |
if re.search(rf"\b{term}\b", q_lower):
|
| 15 |
hits.append(f"The term '{term}' is ambiguous in this query.'")
|
| 16 |
|
| 17 |
+
return hits
|
nl2sql/executor.py
CHANGED
|
@@ -2,6 +2,7 @@ import time
|
|
| 2 |
from nl2sql.types import StageResult, StageTrace
|
| 3 |
from adapters.db.base import DBAdapter
|
| 4 |
|
|
|
|
| 5 |
class Executor:
|
| 6 |
name = "executor"
|
| 7 |
|
|
@@ -12,10 +13,18 @@ class Executor:
|
|
| 12 |
t0 = time.perf_counter()
|
| 13 |
try:
|
| 14 |
rows, cols = self.db.execute(sql)
|
| 15 |
-
trace = StageTrace(
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
except Exception as e:
|
| 19 |
-
trace = StageTrace(
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
| 21 |
return StageResult(ok=False, data=None, trace=trace, error=[str(e)])
|
|
|
|
| 2 |
from nl2sql.types import StageResult, StageTrace
|
| 3 |
from adapters.db.base import DBAdapter
|
| 4 |
|
| 5 |
+
|
| 6 |
class Executor:
|
| 7 |
name = "executor"
|
| 8 |
|
|
|
|
| 13 |
t0 = time.perf_counter()
|
| 14 |
try:
|
| 15 |
rows, cols = self.db.execute(sql)
|
| 16 |
+
trace = StageTrace(
|
| 17 |
+
stage=self.name,
|
| 18 |
+
duration_ms=(time.perf_counter() - t0) * 1000,
|
| 19 |
+
notes={"row_count": len(rows), "col_count": len(cols)},
|
| 20 |
+
)
|
| 21 |
+
return StageResult(
|
| 22 |
+
ok=True, data={"rows": rows, "columns": cols}, trace=trace
|
| 23 |
+
)
|
| 24 |
except Exception as e:
|
| 25 |
+
trace = StageTrace(
|
| 26 |
+
stage=self.name,
|
| 27 |
+
duration_ms=(time.perf_counter() - t0) * 1000,
|
| 28 |
+
notes={"error": str(e)},
|
| 29 |
+
)
|
| 30 |
return StageResult(ok=False, data=None, trace=trace, error=[str(e)])
|
nl2sql/generator.py
CHANGED
|
@@ -4,34 +4,48 @@ from typing import Optional, Dict, Any
|
|
| 4 |
from nl2sql.types import StageResult, StageTrace
|
| 5 |
from adapters.llm.base import LLMProvider
|
| 6 |
|
|
|
|
| 7 |
class Generator:
|
| 8 |
name = "generator"
|
| 9 |
|
| 10 |
def __init__(self, llm: LLMProvider) -> None:
|
| 11 |
self.llm = llm
|
| 12 |
|
| 13 |
-
def run(
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
t0 = time.perf_counter()
|
| 16 |
try:
|
| 17 |
res = self.llm.generate_sql(
|
| 18 |
user_query=user_query,
|
| 19 |
schema_preview=schema_preview,
|
| 20 |
plan_text=plan_text,
|
| 21 |
-
clarify_answers=clarify_answers or {}
|
| 22 |
)
|
| 23 |
except Exception as e:
|
| 24 |
return StageResult(ok=False, error=[f"Generator failed: {e}"])
|
| 25 |
|
| 26 |
# Expect a 5-tuple
|
| 27 |
if not isinstance(res, tuple) or len(res) != 5:
|
| 28 |
-
return StageResult(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
sql, rationale, t_in, t_out, cost = res
|
| 31 |
|
| 32 |
# Type/shape checks
|
| 33 |
if not isinstance(sql, str) or not sql.strip():
|
| 34 |
-
return StageResult(
|
|
|
|
|
|
|
| 35 |
if not sql.lower().lstrip().startswith("select"):
|
| 36 |
return StageResult(ok=False, error=[f"Generated non-SELECT SQL: {sql}"])
|
| 37 |
|
|
@@ -45,5 +59,6 @@ class Generator:
|
|
| 45 |
notes={"rationale_len": len(rationale)},
|
| 46 |
)
|
| 47 |
|
| 48 |
-
return StageResult(
|
| 49 |
-
|
|
|
|
|
|
| 4 |
from nl2sql.types import StageResult, StageTrace
|
| 5 |
from adapters.llm.base import LLMProvider
|
| 6 |
|
| 7 |
+
|
| 8 |
class Generator:
|
| 9 |
name = "generator"
|
| 10 |
|
| 11 |
def __init__(self, llm: LLMProvider) -> None:
|
| 12 |
self.llm = llm
|
| 13 |
|
| 14 |
+
def run(
|
| 15 |
+
self,
|
| 16 |
+
*,
|
| 17 |
+
user_query: str,
|
| 18 |
+
schema_preview: str,
|
| 19 |
+
plan_text: str,
|
| 20 |
+
clarify_answers: Optional[Dict[str, Any]] = None,
|
| 21 |
+
) -> StageResult:
|
| 22 |
t0 = time.perf_counter()
|
| 23 |
try:
|
| 24 |
res = self.llm.generate_sql(
|
| 25 |
user_query=user_query,
|
| 26 |
schema_preview=schema_preview,
|
| 27 |
plan_text=plan_text,
|
| 28 |
+
clarify_answers=clarify_answers or {},
|
| 29 |
)
|
| 30 |
except Exception as e:
|
| 31 |
return StageResult(ok=False, error=[f"Generator failed: {e}"])
|
| 32 |
|
| 33 |
# Expect a 5-tuple
|
| 34 |
if not isinstance(res, tuple) or len(res) != 5:
|
| 35 |
+
return StageResult(
|
| 36 |
+
ok=False,
|
| 37 |
+
error=[
|
| 38 |
+
"Generator contract violation: expected 5-tuple (sql, rationale, t_in, t_out, cost)"
|
| 39 |
+
],
|
| 40 |
+
)
|
| 41 |
|
| 42 |
sql, rationale, t_in, t_out, cost = res
|
| 43 |
|
| 44 |
# Type/shape checks
|
| 45 |
if not isinstance(sql, str) or not sql.strip():
|
| 46 |
+
return StageResult(
|
| 47 |
+
ok=False, error=["Generator produced empty or non-string SQL"]
|
| 48 |
+
)
|
| 49 |
if not sql.lower().lstrip().startswith("select"):
|
| 50 |
return StageResult(ok=False, error=[f"Generated non-SELECT SQL: {sql}"])
|
| 51 |
|
|
|
|
| 59 |
notes={"rationale_len": len(rationale)},
|
| 60 |
)
|
| 61 |
|
| 62 |
+
return StageResult(
|
| 63 |
+
ok=True, data={"sql": sql, "rationale": rationale}, trace=trace
|
| 64 |
+
)
|
nl2sql/pipeline.py
CHANGED
|
@@ -17,14 +17,17 @@ class Pipeline:
|
|
| 17 |
All stages return structured traces and errors but final result is JSON-safe dict.
|
| 18 |
"""
|
| 19 |
|
| 20 |
-
def __init__(
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
| 28 |
self.detector = detector
|
| 29 |
self.planner = planner
|
| 30 |
self.generator = generator
|
|
@@ -59,8 +62,13 @@ class Pipeline:
|
|
| 59 |
return StageResult(ok=False, data=None, trace=None, errors=[f"{e}", tb])
|
| 60 |
|
| 61 |
# ------------------------------------------------------------
|
| 62 |
-
def run(
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
"""
|
| 65 |
Always returns:
|
| 66 |
{
|
|
@@ -86,26 +94,45 @@ class Pipeline:
|
|
| 86 |
"error": False,
|
| 87 |
"details": [f"Ambiguities found: {len(questions)}"],
|
| 88 |
"questions": questions,
|
| 89 |
-
"traces": []
|
| 90 |
}
|
| 91 |
except Exception as e:
|
| 92 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
# --- 2) planner
|
| 95 |
-
r_plan = self._safe_stage(
|
|
|
|
|
|
|
| 96 |
traces.extend(self._trace_list(r_plan))
|
| 97 |
if not r_plan.ok:
|
| 98 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
# --- 3) generator
|
| 101 |
-
r_gen = self._safe_stage(
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
| 106 |
traces.extend(self._trace_list(r_gen))
|
| 107 |
if not r_gen.ok:
|
| 108 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
sql = r_gen.data.get("sql")
|
| 110 |
rationale = r_gen.data.get("rationale")
|
| 111 |
|
|
@@ -113,7 +140,12 @@ class Pipeline:
|
|
| 113 |
r_safe = self._safe_stage(self.safety.check, sql=sql)
|
| 114 |
traces.extend(self._trace_list(r_safe))
|
| 115 |
if not r_safe.ok:
|
| 116 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
# --- 5) executor
|
| 119 |
r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
|
|
@@ -129,10 +161,12 @@ class Pipeline:
|
|
| 129 |
# --- 7) repair loop if verification failed
|
| 130 |
if not verified:
|
| 131 |
for attempt in range(2):
|
| 132 |
-
r_fix = self._safe_stage(
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
| 136 |
traces.extend(self._trace_list(r_fix))
|
| 137 |
if not r_fix.ok:
|
| 138 |
break
|
|
|
|
| 17 |
All stages return structured traces and errors but final result is JSON-safe dict.
|
| 18 |
"""
|
| 19 |
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
*,
|
| 23 |
+
detector: AmbiguityDetector,
|
| 24 |
+
planner: Planner,
|
| 25 |
+
generator: Generator,
|
| 26 |
+
safety: Safety,
|
| 27 |
+
executor: Executor,
|
| 28 |
+
verifier: Verifier,
|
| 29 |
+
repair: Repair,
|
| 30 |
+
):
|
| 31 |
self.detector = detector
|
| 32 |
self.planner = planner
|
| 33 |
self.generator = generator
|
|
|
|
| 62 |
return StageResult(ok=False, data=None, trace=None, errors=[f"{e}", tb])
|
| 63 |
|
| 64 |
# ------------------------------------------------------------
|
| 65 |
+
def run(
|
| 66 |
+
self,
|
| 67 |
+
*,
|
| 68 |
+
user_query: str,
|
| 69 |
+
schema_preview: str,
|
| 70 |
+
clarify_answers: Optional[Dict[str, Any]] = None,
|
| 71 |
+
) -> Dict[str, Any]:
|
| 72 |
"""
|
| 73 |
Always returns:
|
| 74 |
{
|
|
|
|
| 94 |
"error": False,
|
| 95 |
"details": [f"Ambiguities found: {len(questions)}"],
|
| 96 |
"questions": questions,
|
| 97 |
+
"traces": [],
|
| 98 |
}
|
| 99 |
except Exception as e:
|
| 100 |
+
return {
|
| 101 |
+
"ambiguous": True,
|
| 102 |
+
"error": True,
|
| 103 |
+
"details": [f"Detector failed: {e}"],
|
| 104 |
+
"traces": [],
|
| 105 |
+
}
|
| 106 |
|
| 107 |
# --- 2) planner
|
| 108 |
+
r_plan = self._safe_stage(
|
| 109 |
+
self.planner.run, user_query=user_query, schema_preview=schema_preview
|
| 110 |
+
)
|
| 111 |
traces.extend(self._trace_list(r_plan))
|
| 112 |
if not r_plan.ok:
|
| 113 |
+
return {
|
| 114 |
+
"ambiguous": False,
|
| 115 |
+
"error": True,
|
| 116 |
+
"details": r_plan.errors,
|
| 117 |
+
"traces": traces,
|
| 118 |
+
}
|
| 119 |
|
| 120 |
# --- 3) generator
|
| 121 |
+
r_gen = self._safe_stage(
|
| 122 |
+
self.generator.run,
|
| 123 |
+
user_query=user_query,
|
| 124 |
+
schema_preview=schema_preview,
|
| 125 |
+
plan_text=r_plan.data.get("plan"),
|
| 126 |
+
clarify_answers=clarify_answers or {},
|
| 127 |
+
)
|
| 128 |
traces.extend(self._trace_list(r_gen))
|
| 129 |
if not r_gen.ok:
|
| 130 |
+
return {
|
| 131 |
+
"ambiguous": False,
|
| 132 |
+
"error": True,
|
| 133 |
+
"details": r_gen.errors,
|
| 134 |
+
"traces": traces,
|
| 135 |
+
}
|
| 136 |
sql = r_gen.data.get("sql")
|
| 137 |
rationale = r_gen.data.get("rationale")
|
| 138 |
|
|
|
|
| 140 |
r_safe = self._safe_stage(self.safety.check, sql=sql)
|
| 141 |
traces.extend(self._trace_list(r_safe))
|
| 142 |
if not r_safe.ok:
|
| 143 |
+
return {
|
| 144 |
+
"ambiguous": False,
|
| 145 |
+
"error": True,
|
| 146 |
+
"details": r_safe.errors,
|
| 147 |
+
"traces": traces,
|
| 148 |
+
}
|
| 149 |
|
| 150 |
# --- 5) executor
|
| 151 |
r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
|
|
|
|
| 161 |
# --- 7) repair loop if verification failed
|
| 162 |
if not verified:
|
| 163 |
for attempt in range(2):
|
| 164 |
+
r_fix = self._safe_stage(
|
| 165 |
+
self.repair.run,
|
| 166 |
+
sql=sql,
|
| 167 |
+
error_msg="; ".join(details or ["unknown"]),
|
| 168 |
+
schema_preview=schema_preview,
|
| 169 |
+
)
|
| 170 |
traces.extend(self._trace_list(r_fix))
|
| 171 |
if not r_fix.ok:
|
| 172 |
break
|
nl2sql/planner.py
CHANGED
|
@@ -3,14 +3,24 @@ import time
|
|
| 3 |
from nl2sql.types import StageResult, StageTrace
|
| 4 |
from adapters.llm.base import LLMProvider
|
| 5 |
|
|
|
|
| 6 |
class Planner:
|
| 7 |
name = "planner"
|
|
|
|
| 8 |
def __init__(self, llm: LLMProvider) -> None:
|
| 9 |
self.llm = llm
|
| 10 |
|
| 11 |
def run(self, *, user_query: str, schema_preview: str) -> StageResult:
|
| 12 |
t0 = time.perf_counter()
|
| 13 |
-
plan_text, t_in, t_out, cost = self.llm.plan(
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
return StageResult(ok=True, data={"plan": plan_text}, trace=trace)
|
|
|
|
| 3 |
from nl2sql.types import StageResult, StageTrace
|
| 4 |
from adapters.llm.base import LLMProvider
|
| 5 |
|
| 6 |
+
|
| 7 |
class Planner:
|
| 8 |
name = "planner"
|
| 9 |
+
|
| 10 |
def __init__(self, llm: LLMProvider) -> None:
|
| 11 |
self.llm = llm
|
| 12 |
|
| 13 |
def run(self, *, user_query: str, schema_preview: str) -> StageResult:
|
| 14 |
t0 = time.perf_counter()
|
| 15 |
+
plan_text, t_in, t_out, cost = self.llm.plan(
|
| 16 |
+
user_query=user_query, schema_preview=schema_preview
|
| 17 |
+
)
|
| 18 |
+
trace = StageTrace(
|
| 19 |
+
stage=self.name,
|
| 20 |
+
duration_ms=(time.perf_counter() - t0) * 1000,
|
| 21 |
+
token_in=t_in,
|
| 22 |
+
token_out=t_out,
|
| 23 |
+
cost_usd=cost,
|
| 24 |
+
notes={"len_plan": len(plan_text)},
|
| 25 |
+
)
|
| 26 |
return StageResult(ok=True, data={"plan": plan_text}, trace=trace)
|
nl2sql/repair.py
CHANGED
|
@@ -14,16 +14,26 @@ When repairing:
|
|
| 14 |
Return only the corrected SQL.
|
| 15 |
"""
|
| 16 |
|
|
|
|
| 17 |
class Repair:
|
| 18 |
name = "repair"
|
|
|
|
| 19 |
def __init__(self, llm: LLMProvider):
|
| 20 |
self.llm = llm
|
| 21 |
|
| 22 |
-
def run(self, sql:str, error_msg: str, schema_preview: str) -> StageResult:
|
| 23 |
t0 = time.perf_counter()
|
| 24 |
-
fixed_sql, t_in, t_out, cost = self.llm.repair(
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
return StageResult(ok=True, data={"sql": fixed_sql}, trace=trace)
|
|
|
|
| 14 |
Return only the corrected SQL.
|
| 15 |
"""
|
| 16 |
|
| 17 |
+
|
| 18 |
class Repair:
|
| 19 |
name = "repair"
|
| 20 |
+
|
| 21 |
def __init__(self, llm: LLMProvider):
|
| 22 |
self.llm = llm
|
| 23 |
|
| 24 |
+
def run(self, sql: str, error_msg: str, schema_preview: str) -> StageResult:
|
| 25 |
t0 = time.perf_counter()
|
| 26 |
+
fixed_sql, t_in, t_out, cost = self.llm.repair(
|
| 27 |
+
sql=sql,
|
| 28 |
+
error_msg=f"{GUIDELINES}\n\n{error_msg}",
|
| 29 |
+
schema_preview=schema_preview,
|
| 30 |
+
)
|
| 31 |
+
trace = StageTrace(
|
| 32 |
+
stage=self.name,
|
| 33 |
+
duration_ms=(time.perf_counter() - t0) * 1000,
|
| 34 |
+
token_in=t_in,
|
| 35 |
+
token_out=t_out,
|
| 36 |
+
cost_usd=cost,
|
| 37 |
+
notes={"old_sql_len": len(sql), "new_sql_len": len(fixed_sql)},
|
| 38 |
+
)
|
| 39 |
return StageResult(ok=True, data={"sql": fixed_sql}, trace=trace)
|
nl2sql/safety.py
CHANGED
|
@@ -4,7 +4,7 @@ from nl2sql.types import StageResult, StageTrace
|
|
| 4 |
|
| 5 |
# --- Regex utils ---
|
| 6 |
_COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
|
| 7 |
-
_COMMENT_LINE
|
| 8 |
# string literals (single & double quotes), allow escaped quotes
|
| 9 |
_STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
|
| 10 |
_STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)
|
|
@@ -18,20 +18,24 @@ _FORBIDDEN = re.compile(
|
|
| 18 |
# allow: SELECT ... or WITH <cte...> SELECT ...
|
| 19 |
_ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL)
|
| 20 |
|
|
|
|
| 21 |
def _strip_comments(s: str) -> str:
|
| 22 |
s = _COMMENT_BLOCK.sub(" ", s)
|
| 23 |
s = _COMMENT_LINE.sub(" ", s)
|
| 24 |
return s
|
| 25 |
|
|
|
|
| 26 |
def _mask_strings(s: str) -> str:
|
| 27 |
s = _STRING_SINGLE.sub("'X'", s)
|
| 28 |
s = _STRING_DOUBLE.sub('"X"', s)
|
| 29 |
return s
|
| 30 |
|
|
|
|
| 31 |
def _split_statements(s: str) -> list[str]:
|
| 32 |
parts = [p.strip() for p in s.split(";")]
|
| 33 |
return [p for p in parts if p]
|
| 34 |
|
|
|
|
| 35 |
class Safety:
|
| 36 |
name = "safety"
|
| 37 |
|
|
@@ -46,7 +50,9 @@ class Safety:
|
|
| 46 |
return StageResult(
|
| 47 |
ok=False,
|
| 48 |
error=["Multiple statements detected"],
|
| 49 |
-
trace=StageTrace(
|
|
|
|
|
|
|
| 50 |
)
|
| 51 |
|
| 52 |
body = stmts[0]
|
|
@@ -55,14 +61,18 @@ class Safety:
|
|
| 55 |
return StageResult(
|
| 56 |
ok=False,
|
| 57 |
error=["Forbidden keyword detected"],
|
| 58 |
-
trace=StageTrace(
|
|
|
|
|
|
|
| 59 |
)
|
| 60 |
|
| 61 |
if not _ALLOW_SELECT.match(body):
|
| 62 |
return StageResult(
|
| 63 |
ok=False,
|
| 64 |
error=["Non-SELECT statement"],
|
| 65 |
-
trace=StageTrace(
|
|
|
|
|
|
|
| 66 |
)
|
| 67 |
|
| 68 |
return StageResult(
|
|
@@ -71,5 +81,7 @@ class Safety:
|
|
| 71 |
"sql": sql.strip(),
|
| 72 |
"rationale": "Statement validated as SELECT-only (strings/comments ignored).",
|
| 73 |
},
|
| 74 |
-
trace=StageTrace(
|
|
|
|
|
|
|
| 75 |
)
|
|
|
|
| 4 |
|
| 5 |
# --- Regex utils ---
|
| 6 |
_COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
|
| 7 |
+
_COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)
|
| 8 |
# string literals (single & double quotes), allow escaped quotes
|
| 9 |
_STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
|
| 10 |
_STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)
|
|
|
|
| 18 |
# allow: SELECT ... or WITH <cte...> SELECT ...
|
| 19 |
_ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL)
|
| 20 |
|
| 21 |
+
|
| 22 |
def _strip_comments(s: str) -> str:
|
| 23 |
s = _COMMENT_BLOCK.sub(" ", s)
|
| 24 |
s = _COMMENT_LINE.sub(" ", s)
|
| 25 |
return s
|
| 26 |
|
| 27 |
+
|
| 28 |
def _mask_strings(s: str) -> str:
|
| 29 |
s = _STRING_SINGLE.sub("'X'", s)
|
| 30 |
s = _STRING_DOUBLE.sub('"X"', s)
|
| 31 |
return s
|
| 32 |
|
| 33 |
+
|
| 34 |
def _split_statements(s: str) -> list[str]:
|
| 35 |
parts = [p.strip() for p in s.split(";")]
|
| 36 |
return [p for p in parts if p]
|
| 37 |
|
| 38 |
+
|
| 39 |
class Safety:
|
| 40 |
name = "safety"
|
| 41 |
|
|
|
|
| 50 |
return StageResult(
|
| 51 |
ok=False,
|
| 52 |
error=["Multiple statements detected"],
|
| 53 |
+
trace=StageTrace(
|
| 54 |
+
stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
|
| 55 |
+
),
|
| 56 |
)
|
| 57 |
|
| 58 |
body = stmts[0]
|
|
|
|
| 61 |
return StageResult(
|
| 62 |
ok=False,
|
| 63 |
error=["Forbidden keyword detected"],
|
| 64 |
+
trace=StageTrace(
|
| 65 |
+
stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
|
| 66 |
+
),
|
| 67 |
)
|
| 68 |
|
| 69 |
if not _ALLOW_SELECT.match(body):
|
| 70 |
return StageResult(
|
| 71 |
ok=False,
|
| 72 |
error=["Non-SELECT statement"],
|
| 73 |
+
trace=StageTrace(
|
| 74 |
+
stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
|
| 75 |
+
),
|
| 76 |
)
|
| 77 |
|
| 78 |
return StageResult(
|
|
|
|
| 81 |
"sql": sql.strip(),
|
| 82 |
"rationale": "Statement validated as SELECT-only (strings/comments ignored).",
|
| 83 |
},
|
| 84 |
+
trace=StageTrace(
|
| 85 |
+
stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
|
| 86 |
+
),
|
| 87 |
)
|
nl2sql/stubs.py
CHANGED
|
@@ -1,31 +1,37 @@
|
|
| 1 |
from nl2sql.types import StageResult, StageTrace
|
| 2 |
|
|
|
|
| 3 |
class NoOpExecutor:
|
| 4 |
name = "executor"
|
|
|
|
| 5 |
def run(self, sql: str) -> StageResult:
|
| 6 |
# pretend success, return empty result set
|
| 7 |
return StageResult(
|
| 8 |
ok=True,
|
| 9 |
data={"rows": [], "columns": []},
|
| 10 |
-
trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True})
|
| 11 |
)
|
| 12 |
|
|
|
|
| 13 |
class NoOpVerifier:
|
| 14 |
name = "verifier"
|
|
|
|
| 15 |
def run(self, sql: str, exec_result: StageResult) -> StageResult:
|
| 16 |
# always verified for legacy tests
|
| 17 |
return StageResult(
|
| 18 |
ok=True,
|
| 19 |
data={"verified": True},
|
| 20 |
-
trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True})
|
| 21 |
)
|
| 22 |
|
|
|
|
| 23 |
class NoOpRepair:
|
| 24 |
name = "repair"
|
|
|
|
| 25 |
def run(self, sql: str, error_msg: str, schema_preview: str) -> StageResult:
|
| 26 |
# return original SQL unchanged
|
| 27 |
return StageResult(
|
| 28 |
ok=True,
|
| 29 |
data={"sql": sql},
|
| 30 |
-
trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True})
|
| 31 |
)
|
|
|
|
| 1 |
from nl2sql.types import StageResult, StageTrace
|
| 2 |
|
| 3 |
+
|
| 4 |
class NoOpExecutor:
|
| 5 |
name = "executor"
|
| 6 |
+
|
| 7 |
def run(self, sql: str) -> StageResult:
|
| 8 |
# pretend success, return empty result set
|
| 9 |
return StageResult(
|
| 10 |
ok=True,
|
| 11 |
data={"rows": [], "columns": []},
|
| 12 |
+
trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True}),
|
| 13 |
)
|
| 14 |
|
| 15 |
+
|
| 16 |
class NoOpVerifier:
|
| 17 |
name = "verifier"
|
| 18 |
+
|
| 19 |
def run(self, sql: str, exec_result: StageResult) -> StageResult:
|
| 20 |
# always verified for legacy tests
|
| 21 |
return StageResult(
|
| 22 |
ok=True,
|
| 23 |
data={"verified": True},
|
| 24 |
+
trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True}),
|
| 25 |
)
|
| 26 |
|
| 27 |
+
|
| 28 |
class NoOpRepair:
|
| 29 |
name = "repair"
|
| 30 |
+
|
| 31 |
def run(self, sql: str, error_msg: str, schema_preview: str) -> StageResult:
|
| 32 |
# return original SQL unchanged
|
| 33 |
return StageResult(
|
| 34 |
ok=True,
|
| 35 |
data={"sql": sql},
|
| 36 |
+
trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True}),
|
| 37 |
)
|
nl2sql/types.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
from typing import Any, Dict, Optional, List
|
| 3 |
|
|
|
|
| 4 |
@dataclass(frozen=True)
|
| 5 |
class StageTrace:
|
| 6 |
stage: str
|
|
@@ -10,6 +11,7 @@ class StageTrace:
|
|
| 10 |
token_out: Optional[int] = None
|
| 11 |
cost_usd: Optional[float] = None
|
| 12 |
|
|
|
|
| 13 |
@dataclass(frozen=True)
|
| 14 |
class StageResult:
|
| 15 |
ok: bool
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
from typing import Any, Dict, Optional, List
|
| 3 |
|
| 4 |
+
|
| 5 |
@dataclass(frozen=True)
|
| 6 |
class StageTrace:
|
| 7 |
stage: str
|
|
|
|
| 11 |
token_out: Optional[int] = None
|
| 12 |
cost_usd: Optional[float] = None
|
| 13 |
|
| 14 |
+
|
| 15 |
@dataclass(frozen=True)
|
| 16 |
class StageResult:
|
| 17 |
ok: bool
|
nl2sql/verifier.py
CHANGED
|
@@ -2,15 +2,20 @@ import sqlglot
|
|
| 2 |
from sqlglot import expressions as exp
|
| 3 |
from nl2sql.types import StageResult, StageTrace
|
| 4 |
|
|
|
|
| 5 |
class Verifier:
|
| 6 |
name = "verifier"
|
| 7 |
|
| 8 |
def run(self, sql: str, exec_result: StageResult) -> StageResult:
|
| 9 |
if not exec_result.ok:
|
| 10 |
-
return StageResult(
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Rule 1: check SELECT / GROUP consistency
|
| 16 |
issues = []
|
|
@@ -25,9 +30,16 @@ class Verifier:
|
|
| 25 |
issues.append(f"Parse error during verification: {e}")
|
| 26 |
|
| 27 |
if issues:
|
| 28 |
-
return StageResult(
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from sqlglot import expressions as exp
|
| 3 |
from nl2sql.types import StageResult, StageTrace
|
| 4 |
|
| 5 |
+
|
| 6 |
class Verifier:
|
| 7 |
name = "verifier"
|
| 8 |
|
| 9 |
def run(self, sql: str, exec_result: StageResult) -> StageResult:
|
| 10 |
if not exec_result.ok:
|
| 11 |
+
return StageResult(
|
| 12 |
+
ok=False,
|
| 13 |
+
data=None,
|
| 14 |
+
trace=StageTrace(
|
| 15 |
+
stage=self.name, duration_ms=0, notes={"reason": "execution_error"}
|
| 16 |
+
),
|
| 17 |
+
error=exec_result.errors,
|
| 18 |
+
)
|
| 19 |
|
| 20 |
# Rule 1: check SELECT / GROUP consistency
|
| 21 |
issues = []
|
|
|
|
| 30 |
issues.append(f"Parse error during verification: {e}")
|
| 31 |
|
| 32 |
if issues:
|
| 33 |
+
return StageResult(
|
| 34 |
+
ok=False,
|
| 35 |
+
data=None,
|
| 36 |
+
trace=StageTrace(
|
| 37 |
+
stage=self.name, duration_ms=0, notes={"issues": issues}
|
| 38 |
+
),
|
| 39 |
+
error=issues,
|
| 40 |
+
)
|
| 41 |
+
return StageResult(
|
| 42 |
+
ok=True,
|
| 43 |
+
data={"verified": True},
|
| 44 |
+
trace=StageTrace(stage=self.name, duration_ms=0),
|
| 45 |
+
)
|
tests/conftest.py
CHANGED
|
@@ -4,4 +4,4 @@ from dotenv import load_dotenv
|
|
| 4 |
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 5 |
ENV_PATH = os.path.join(ROOT_DIR, ".env")
|
| 6 |
|
| 7 |
-
load_dotenv(dotenv_path=ENV_PATH)
|
|
|
|
| 4 |
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 5 |
ENV_PATH = os.path.join(ROOT_DIR, ".env")
|
| 6 |
|
| 7 |
+
load_dotenv(dotenv_path=ENV_PATH)
|
tests/test_ambiguity.py
CHANGED
|
@@ -2,18 +2,23 @@ from nl2sql.ambiguity_detector import AmbiguityDetector
|
|
| 2 |
from nl2sql.types import StageResult
|
| 3 |
from app.routers import nl2sql
|
| 4 |
|
|
|
|
| 5 |
def test_detects_ambiguous_terms():
|
| 6 |
det = AmbiguityDetector()
|
| 7 |
res = det.detect("Show me recent top singers", "table: singer(id,name,age)")
|
| 8 |
assert len(res) >= 1
|
| 9 |
assert "recent" in res[0].lower()
|
| 10 |
|
|
|
|
| 11 |
def test_not_false_positive():
|
| 12 |
det = AmbiguityDetector()
|
| 13 |
res = det.detect("List all singers older than 30", "table: singer(id, name, age)")
|
| 14 |
assert res == []
|
| 15 |
|
|
|
|
| 16 |
def test_ambiguity_response():
|
| 17 |
-
fake_result = StageResult(
|
|
|
|
|
|
|
| 18 |
response = nl2sql._to_dict(fake_result.data)
|
| 19 |
-
assert response["ambiguous"] is True
|
|
|
|
| 2 |
from nl2sql.types import StageResult
|
| 3 |
from app.routers import nl2sql
|
| 4 |
|
| 5 |
+
|
| 6 |
def test_detects_ambiguous_terms():
|
| 7 |
det = AmbiguityDetector()
|
| 8 |
res = det.detect("Show me recent top singers", "table: singer(id,name,age)")
|
| 9 |
assert len(res) >= 1
|
| 10 |
assert "recent" in res[0].lower()
|
| 11 |
|
| 12 |
+
|
| 13 |
def test_not_false_positive():
|
| 14 |
det = AmbiguityDetector()
|
| 15 |
res = det.detect("List all singers older than 30", "table: singer(id, name, age)")
|
| 16 |
assert res == []
|
| 17 |
|
| 18 |
+
|
| 19 |
def test_ambiguity_response():
|
| 20 |
+
fake_result = StageResult(
|
| 21 |
+
ok=True, data={"ambiguous": True, "questions": ["Clarify column?"]}
|
| 22 |
+
)
|
| 23 |
response = nl2sql._to_dict(fake_result.data)
|
| 24 |
+
assert response["ambiguous"] is True
|
tests/test_executor.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
| 1 |
from nl2sql.executor import Executor
|
| 2 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 3 |
|
|
|
|
| 4 |
def test_executor_runs_select(tmp_path):
|
| 5 |
db_path = tmp_path / "test.db"
|
| 6 |
import sqlite3
|
|
|
|
| 7 |
conn = sqlite3.connect(db_path)
|
| 8 |
conn.execute("CREATE TABLE users(id INT, name TEXT);")
|
| 9 |
conn.execute("INSERT INTO users VALUES (1, 'Alice');")
|
|
|
|
| 1 |
from nl2sql.executor import Executor
|
| 2 |
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 3 |
|
| 4 |
+
|
| 5 |
def test_executor_runs_select(tmp_path):
|
| 6 |
db_path = tmp_path / "test.db"
|
| 7 |
import sqlite3
|
| 8 |
+
|
| 9 |
conn = sqlite3.connect(db_path)
|
| 10 |
conn.execute("CREATE TABLE users(id INT, name TEXT);")
|
| 11 |
conn.execute("INSERT INTO users VALUES (1, 'Alice');")
|
tests/test_generator.py
CHANGED
|
@@ -5,6 +5,7 @@ from nl2sql.types import StageResult
|
|
| 5 |
|
| 6 |
# --- Dummy LLMs (respect the 5-tuple contract) --------------------------------
|
| 7 |
|
|
|
|
| 8 |
class LLM_OK:
|
| 9 |
def generate_sql(self, **kwargs):
|
| 10 |
# contract: (sql, rationale, t_in, t_out, cost)
|
|
@@ -37,11 +38,12 @@ class LLM_CONTRACT_SHORT:
|
|
| 37 |
|
| 38 |
# --- Parametrized negative cases ----------------------------------------------
|
| 39 |
|
|
|
|
| 40 |
@pytest.mark.parametrize(
|
| 41 |
"llm, err_keyword",
|
| 42 |
[
|
| 43 |
-
(LLM_EMPTY_SQL(), "empty"),
|
| 44 |
-
(LLM_NON_SELECT(), "non-select"),
|
| 45 |
(LLM_CONTRACT_NONE(), "contract violation"),
|
| 46 |
(LLM_CONTRACT_SHORT(), "contract violation"),
|
| 47 |
],
|
|
@@ -52,7 +54,7 @@ def test_generator_errors_do_not_create_trace(llm, err_keyword):
|
|
| 52 |
user_query="show all singers",
|
| 53 |
schema_preview="CREATE TABLE singer(id int, name text);",
|
| 54 |
plan_text="-- plan --",
|
| 55 |
-
clarify_answers={}
|
| 56 |
)
|
| 57 |
assert isinstance(r, StageResult)
|
| 58 |
assert r.ok is False
|
|
@@ -65,13 +67,14 @@ def test_generator_errors_do_not_create_trace(llm, err_keyword):
|
|
| 65 |
|
| 66 |
# --- Positive case (success) ---------------------------------------------------
|
| 67 |
|
|
|
|
| 68 |
def test_generator_success_has_valid_trace_and_data():
|
| 69 |
gen = Generator(llm=LLM_OK())
|
| 70 |
r = gen.run(
|
| 71 |
user_query="show all singers",
|
| 72 |
schema_preview="CREATE TABLE singer(id int, name text);",
|
| 73 |
plan_text="-- plan --",
|
| 74 |
-
clarify_answers={}
|
| 75 |
)
|
| 76 |
|
| 77 |
# Basic success checks
|
|
|
|
| 5 |
|
| 6 |
# --- Dummy LLMs (respect the 5-tuple contract) --------------------------------
|
| 7 |
|
| 8 |
+
|
| 9 |
class LLM_OK:
|
| 10 |
def generate_sql(self, **kwargs):
|
| 11 |
# contract: (sql, rationale, t_in, t_out, cost)
|
|
|
|
| 38 |
|
| 39 |
# --- Parametrized negative cases ----------------------------------------------
|
| 40 |
|
| 41 |
+
|
| 42 |
@pytest.mark.parametrize(
|
| 43 |
"llm, err_keyword",
|
| 44 |
[
|
| 45 |
+
(LLM_EMPTY_SQL(), "empty"), # empty or non-string sql
|
| 46 |
+
(LLM_NON_SELECT(), "non-select"), # generated non-SELECT
|
| 47 |
(LLM_CONTRACT_NONE(), "contract violation"),
|
| 48 |
(LLM_CONTRACT_SHORT(), "contract violation"),
|
| 49 |
],
|
|
|
|
| 54 |
user_query="show all singers",
|
| 55 |
schema_preview="CREATE TABLE singer(id int, name text);",
|
| 56 |
plan_text="-- plan --",
|
| 57 |
+
clarify_answers={},
|
| 58 |
)
|
| 59 |
assert isinstance(r, StageResult)
|
| 60 |
assert r.ok is False
|
|
|
|
| 67 |
|
| 68 |
# --- Positive case (success) ---------------------------------------------------
|
| 69 |
|
| 70 |
+
|
| 71 |
def test_generator_success_has_valid_trace_and_data():
|
| 72 |
gen = Generator(llm=LLM_OK())
|
| 73 |
r = gen.run(
|
| 74 |
user_query="show all singers",
|
| 75 |
schema_preview="CREATE TABLE singer(id int, name text);",
|
| 76 |
plan_text="-- plan --",
|
| 77 |
+
clarify_answers={},
|
| 78 |
)
|
| 79 |
|
| 80 |
# Basic success checks
|
tests/test_nl2sql_router.py
CHANGED
|
@@ -9,8 +9,10 @@ client = TestClient(app)
|
|
| 9 |
def fake_trace(stage: str):
|
| 10 |
return StageTrace(stage=stage, duration_ms=10.0)
|
| 11 |
|
|
|
|
| 12 |
path = app.url_path_for("nl2sql_handler")
|
| 13 |
|
|
|
|
| 14 |
# --- 1) Clarify / ambiguity case ---------------------------------------------
|
| 15 |
def test_ambiguity_route(monkeypatch):
|
| 16 |
from app.routers import nl2sql
|
|
@@ -47,7 +49,9 @@ def test_error_route(monkeypatch):
|
|
| 47 |
from app.routers import nl2sql
|
| 48 |
|
| 49 |
def fake_run(*args, **kwargs):
|
| 50 |
-
return StageResult(
|
|
|
|
|
|
|
| 51 |
|
| 52 |
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
|
| 53 |
|
|
|
|
| 9 |
def fake_trace(stage: str):
|
| 10 |
return StageTrace(stage=stage, duration_ms=10.0)
|
| 11 |
|
| 12 |
+
|
| 13 |
path = app.url_path_for("nl2sql_handler")
|
| 14 |
|
| 15 |
+
|
| 16 |
# --- 1) Clarify / ambiguity case ---------------------------------------------
|
| 17 |
def test_ambiguity_route(monkeypatch):
|
| 18 |
from app.routers import nl2sql
|
|
|
|
| 49 |
from app.routers import nl2sql
|
| 50 |
|
| 51 |
def fake_run(*args, **kwargs):
|
| 52 |
+
return StageResult(
|
| 53 |
+
ok=False, error=["Bad SQL"], data={"traces": [fake_trace("safety")]}
|
| 54 |
+
)
|
| 55 |
|
| 56 |
monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)
|
| 57 |
|
tests/test_openai_provider.py
CHANGED
|
@@ -6,21 +6,23 @@ from adapters.llm.openai_provider import OpenAIProvider
|
|
| 6 |
# Helper class to fake the completion object returned by OpenAI SDK
|
| 7 |
class FakeCompletion:
|
| 8 |
def __init__(self, content: str, prompt_tokens=5, completion_tokens=7):
|
| 9 |
-
self.choices = [
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
# --- Case 1: clean valid JSON --------------------------------------------------
|
| 17 |
def test_generate_sql_valid_json(monkeypatch):
|
| 18 |
provider = OpenAIProvider()
|
| 19 |
|
| 20 |
-
fake_content = json.dumps(
|
| 21 |
-
"sql": "SELECT * FROM singer;",
|
| 22 |
-
|
| 23 |
-
})
|
| 24 |
fake_completion = FakeCompletion(fake_content)
|
| 25 |
|
| 26 |
# Monkeypatch client.chat.completions.create
|
|
@@ -33,7 +35,7 @@ def test_generate_sql_valid_json(monkeypatch):
|
|
| 33 |
user_query="show all singers",
|
| 34 |
schema_preview="CREATE TABLE singer(id int, name text);",
|
| 35 |
plan_text="-- plan --",
|
| 36 |
-
clarify_answers={}
|
| 37 |
)
|
| 38 |
|
| 39 |
assert sql.strip().lower().startswith("select")
|
|
@@ -48,7 +50,7 @@ def test_generate_sql_recover_from_partial_json(monkeypatch):
|
|
| 48 |
provider = OpenAIProvider()
|
| 49 |
|
| 50 |
# invalid JSON with text around it
|
| 51 |
-
fake_content =
|
| 52 |
fake_completion = FakeCompletion(fake_content)
|
| 53 |
|
| 54 |
def fake_create(*args, **kwargs):
|
|
@@ -59,7 +61,7 @@ def test_generate_sql_recover_from_partial_json(monkeypatch):
|
|
| 59 |
sql, rationale, *_ = provider.generate_sql(
|
| 60 |
user_query="show all users",
|
| 61 |
schema_preview="CREATE TABLE users(id int, name text);",
|
| 62 |
-
plan_text="-- plan --"
|
| 63 |
)
|
| 64 |
|
| 65 |
assert sql.lower().startswith("select")
|
|
@@ -83,5 +85,5 @@ def test_generate_sql_invalid_json(monkeypatch):
|
|
| 83 |
provider.generate_sql(
|
| 84 |
user_query="show X",
|
| 85 |
schema_preview="CREATE TABLE t(id int);",
|
| 86 |
-
plan_text="-- plan --"
|
| 87 |
)
|
|
|
|
| 6 |
# Helper class to fake the completion object returned by OpenAI SDK
|
| 7 |
class FakeCompletion:
|
| 8 |
def __init__(self, content: str, prompt_tokens=5, completion_tokens=7):
|
| 9 |
+
self.choices = [
|
| 10 |
+
type("Choice", (), {"message": type("Msg", (), {"content": content})})
|
| 11 |
+
]
|
| 12 |
+
self.usage = type(
|
| 13 |
+
"Usage",
|
| 14 |
+
(),
|
| 15 |
+
{"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens},
|
| 16 |
+
)
|
| 17 |
|
| 18 |
|
| 19 |
# --- Case 1: clean valid JSON --------------------------------------------------
|
| 20 |
def test_generate_sql_valid_json(monkeypatch):
|
| 21 |
provider = OpenAIProvider()
|
| 22 |
|
| 23 |
+
fake_content = json.dumps(
|
| 24 |
+
{"sql": "SELECT * FROM singer;", "rationale": "List all singers."}
|
| 25 |
+
)
|
|
|
|
| 26 |
fake_completion = FakeCompletion(fake_content)
|
| 27 |
|
| 28 |
# Monkeypatch client.chat.completions.create
|
|
|
|
| 35 |
user_query="show all singers",
|
| 36 |
schema_preview="CREATE TABLE singer(id int, name text);",
|
| 37 |
plan_text="-- plan --",
|
| 38 |
+
clarify_answers={},
|
| 39 |
)
|
| 40 |
|
| 41 |
assert sql.strip().lower().startswith("select")
|
|
|
|
| 50 |
provider = OpenAIProvider()
|
| 51 |
|
| 52 |
# invalid JSON with text around it
|
| 53 |
+
fake_content = 'Here is the result:\n{ "sql": "SELECT * FROM users;", "rationale": "list users" }\nThanks!'
|
| 54 |
fake_completion = FakeCompletion(fake_content)
|
| 55 |
|
| 56 |
def fake_create(*args, **kwargs):
|
|
|
|
| 61 |
sql, rationale, *_ = provider.generate_sql(
|
| 62 |
user_query="show all users",
|
| 63 |
schema_preview="CREATE TABLE users(id int, name text);",
|
| 64 |
+
plan_text="-- plan --",
|
| 65 |
)
|
| 66 |
|
| 67 |
assert sql.lower().startswith("select")
|
|
|
|
| 85 |
provider.generate_sql(
|
| 86 |
user_query="show X",
|
| 87 |
schema_preview="CREATE TABLE t(id int);",
|
| 88 |
+
plan_text="-- plan --",
|
| 89 |
)
|
tests/test_pipeline_integration.py
CHANGED
|
@@ -5,8 +5,10 @@ from nl2sql.types import StageResult, StageTrace
|
|
| 5 |
|
| 6 |
# --- Dummy stages to isolate pipeline -----------------------------------------
|
| 7 |
|
|
|
|
| 8 |
class DummyDetector:
|
| 9 |
"""Simulates ambiguity detector stage."""
|
|
|
|
| 10 |
def __init__(self, ambiguous=False):
|
| 11 |
self.ambiguous = ambiguous
|
| 12 |
|
|
@@ -17,6 +19,7 @@ class DummyDetector:
|
|
| 17 |
|
| 18 |
class DummyPlanner:
|
| 19 |
"""Simulates planner stage."""
|
|
|
|
| 20 |
def run(self, *, user_query, schema_preview):
|
| 21 |
trace = StageTrace(stage="planner", duration_ms=1.0)
|
| 22 |
if "fail_plan" in user_query:
|
|
@@ -26,17 +29,21 @@ class DummyPlanner:
|
|
| 26 |
|
| 27 |
class DummyGenerator:
|
| 28 |
"""Simulates generator stage."""
|
|
|
|
| 29 |
def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
|
| 30 |
trace = StageTrace(stage="generator", duration_ms=1.0)
|
| 31 |
if "fail_gen" in user_query:
|
| 32 |
return StageResult(ok=False, error=["Generator failed"], trace=trace)
|
| 33 |
sql = "SELECT * FROM singer;"
|
| 34 |
rationale = "List all singers."
|
| 35 |
-
return StageResult(
|
|
|
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
class DummySafety:
|
| 39 |
"""Simulates safety stage."""
|
|
|
|
| 40 |
def check(self, sql):
|
| 41 |
trace = StageTrace(stage="safety", duration_ms=1.0)
|
| 42 |
if "DROP" in sql.upper():
|
|
@@ -50,12 +57,12 @@ def test_pipeline_success():
|
|
| 50 |
detector=DummyDetector(ambiguous=False),
|
| 51 |
planner=DummyPlanner(),
|
| 52 |
generator=DummyGenerator(),
|
| 53 |
-
safety=DummySafety()
|
| 54 |
)
|
| 55 |
|
| 56 |
r = pipeline.run(
|
| 57 |
user_query="show all singers",
|
| 58 |
-
schema_preview="CREATE TABLE singer(id int, name text);"
|
| 59 |
)
|
| 60 |
|
| 61 |
assert isinstance(r, StageResult)
|
|
@@ -73,13 +80,10 @@ def test_pipeline_ambiguity():
|
|
| 73 |
detector=DummyDetector(ambiguous=True),
|
| 74 |
planner=DummyPlanner(),
|
| 75 |
generator=DummyGenerator(),
|
| 76 |
-
safety=DummySafety()
|
| 77 |
)
|
| 78 |
|
| 79 |
-
r = pipeline.run(
|
| 80 |
-
user_query="show data",
|
| 81 |
-
schema_preview="CREATE TABLE x(id int);"
|
| 82 |
-
)
|
| 83 |
|
| 84 |
assert isinstance(r, StageResult)
|
| 85 |
assert r.ok is True
|
|
@@ -93,11 +97,10 @@ def test_pipeline_plan_fail():
|
|
| 93 |
detector=DummyDetector(),
|
| 94 |
planner=DummyPlanner(),
|
| 95 |
generator=DummyGenerator(),
|
| 96 |
-
safety=DummySafety()
|
| 97 |
)
|
| 98 |
r = pipeline.run(
|
| 99 |
-
user_query="fail_plan",
|
| 100 |
-
schema_preview="CREATE TABLE singer(id int);"
|
| 101 |
)
|
| 102 |
assert isinstance(r, StageResult)
|
| 103 |
assert r.ok is False
|
|
@@ -110,11 +113,10 @@ def test_pipeline_gen_fail():
|
|
| 110 |
detector=DummyDetector(),
|
| 111 |
planner=DummyPlanner(),
|
| 112 |
generator=DummyGenerator(),
|
| 113 |
-
safety=DummySafety()
|
| 114 |
)
|
| 115 |
r = pipeline.run(
|
| 116 |
-
user_query="fail_gen",
|
| 117 |
-
schema_preview="CREATE TABLE singer(id int);"
|
| 118 |
)
|
| 119 |
assert r.ok is False
|
| 120 |
assert "Generator failed" in " ".join(r.error or [])
|
|
@@ -126,17 +128,18 @@ def test_pipeline_safety_fail():
|
|
| 126 |
def run(self, **kw):
|
| 127 |
trace = StageTrace(stage="generator", duration_ms=1.0)
|
| 128 |
# Generate a DROP TABLE → unsafe
|
| 129 |
-
return StageResult(
|
|
|
|
|
|
|
| 130 |
|
| 131 |
pipeline = Pipeline(
|
| 132 |
detector=DummyDetector(),
|
| 133 |
planner=DummyPlanner(),
|
| 134 |
generator=UnsafeGen(),
|
| 135 |
-
safety=DummySafety()
|
| 136 |
)
|
| 137 |
r = pipeline.run(
|
| 138 |
-
user_query="drop something",
|
| 139 |
-
schema_preview="CREATE TABLE x(id int);"
|
| 140 |
)
|
| 141 |
assert r.ok is False
|
| 142 |
assert "unsafe" in " ".join(r.error or []).lower()
|
|
|
|
| 5 |
|
| 6 |
# --- Dummy stages to isolate pipeline -----------------------------------------
|
| 7 |
|
| 8 |
+
|
| 9 |
class DummyDetector:
|
| 10 |
"""Simulates ambiguity detector stage."""
|
| 11 |
+
|
| 12 |
def __init__(self, ambiguous=False):
|
| 13 |
self.ambiguous = ambiguous
|
| 14 |
|
|
|
|
| 19 |
|
| 20 |
class DummyPlanner:
|
| 21 |
"""Simulates planner stage."""
|
| 22 |
+
|
| 23 |
def run(self, *, user_query, schema_preview):
|
| 24 |
trace = StageTrace(stage="planner", duration_ms=1.0)
|
| 25 |
if "fail_plan" in user_query:
|
|
|
|
| 29 |
|
| 30 |
class DummyGenerator:
|
| 31 |
"""Simulates generator stage."""
|
| 32 |
+
|
| 33 |
def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
|
| 34 |
trace = StageTrace(stage="generator", duration_ms=1.0)
|
| 35 |
if "fail_gen" in user_query:
|
| 36 |
return StageResult(ok=False, error=["Generator failed"], trace=trace)
|
| 37 |
sql = "SELECT * FROM singer;"
|
| 38 |
rationale = "List all singers."
|
| 39 |
+
return StageResult(
|
| 40 |
+
ok=True, data={"sql": sql, "rationale": rationale}, trace=trace
|
| 41 |
+
)
|
| 42 |
|
| 43 |
|
| 44 |
class DummySafety:
|
| 45 |
"""Simulates safety stage."""
|
| 46 |
+
|
| 47 |
def check(self, sql):
|
| 48 |
trace = StageTrace(stage="safety", duration_ms=1.0)
|
| 49 |
if "DROP" in sql.upper():
|
|
|
|
| 57 |
detector=DummyDetector(ambiguous=False),
|
| 58 |
planner=DummyPlanner(),
|
| 59 |
generator=DummyGenerator(),
|
| 60 |
+
safety=DummySafety(),
|
| 61 |
)
|
| 62 |
|
| 63 |
r = pipeline.run(
|
| 64 |
user_query="show all singers",
|
| 65 |
+
schema_preview="CREATE TABLE singer(id int, name text);",
|
| 66 |
)
|
| 67 |
|
| 68 |
assert isinstance(r, StageResult)
|
|
|
|
| 80 |
detector=DummyDetector(ambiguous=True),
|
| 81 |
planner=DummyPlanner(),
|
| 82 |
generator=DummyGenerator(),
|
| 83 |
+
safety=DummySafety(),
|
| 84 |
)
|
| 85 |
|
| 86 |
+
r = pipeline.run(user_query="show data", schema_preview="CREATE TABLE x(id int);")
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
assert isinstance(r, StageResult)
|
| 89 |
assert r.ok is True
|
|
|
|
| 97 |
detector=DummyDetector(),
|
| 98 |
planner=DummyPlanner(),
|
| 99 |
generator=DummyGenerator(),
|
| 100 |
+
safety=DummySafety(),
|
| 101 |
)
|
| 102 |
r = pipeline.run(
|
| 103 |
+
user_query="fail_plan", schema_preview="CREATE TABLE singer(id int);"
|
|
|
|
| 104 |
)
|
| 105 |
assert isinstance(r, StageResult)
|
| 106 |
assert r.ok is False
|
|
|
|
| 113 |
detector=DummyDetector(),
|
| 114 |
planner=DummyPlanner(),
|
| 115 |
generator=DummyGenerator(),
|
| 116 |
+
safety=DummySafety(),
|
| 117 |
)
|
| 118 |
r = pipeline.run(
|
| 119 |
+
user_query="fail_gen", schema_preview="CREATE TABLE singer(id int);"
|
|
|
|
| 120 |
)
|
| 121 |
assert r.ok is False
|
| 122 |
assert "Generator failed" in " ".join(r.error or [])
|
|
|
|
| 128 |
def run(self, **kw):
|
| 129 |
trace = StageTrace(stage="generator", duration_ms=1.0)
|
| 130 |
# Generate a DROP TABLE → unsafe
|
| 131 |
+
return StageResult(
|
| 132 |
+
ok=True, data={"sql": "DROP TABLE x;", "rationale": "oops"}, trace=trace
|
| 133 |
+
)
|
| 134 |
|
| 135 |
pipeline = Pipeline(
|
| 136 |
detector=DummyDetector(),
|
| 137 |
planner=DummyPlanner(),
|
| 138 |
generator=UnsafeGen(),
|
| 139 |
+
safety=DummySafety(),
|
| 140 |
)
|
| 141 |
r = pipeline.run(
|
| 142 |
+
user_query="drop something", schema_preview="CREATE TABLE x(id int);"
|
|
|
|
| 143 |
)
|
| 144 |
assert r.ok is False
|
| 145 |
assert "unsafe" in " ".join(r.error or []).lower()
|
tests/test_safety.py
CHANGED
|
@@ -2,7 +2,6 @@ from nl2sql.safety import Safety
|
|
| 2 |
import pytest
|
| 3 |
|
| 4 |
|
| 5 |
-
|
| 6 |
def test_safety_allows_select():
|
| 7 |
s = Safety()
|
| 8 |
result = s.check("SELECT * FROM users;")
|
|
@@ -10,6 +9,7 @@ def test_safety_allows_select():
|
|
| 10 |
assert "sql" in result.data
|
| 11 |
assert result.trace.stage == "safety"
|
| 12 |
|
|
|
|
| 13 |
def test_safety_allows_with_select_cte():
|
| 14 |
s = Safety()
|
| 15 |
sql = """
|
|
@@ -21,12 +21,14 @@ def test_safety_allows_with_select_cte():
|
|
| 21 |
r = s.check(sql)
|
| 22 |
assert r.ok
|
| 23 |
|
|
|
|
| 24 |
def test_safety_allows_select_with_comments_and_newlines():
|
| 25 |
s = Safety()
|
| 26 |
sql = "/* head */ \n -- inline\n SELECT 1; -- tail"
|
| 27 |
r = s.check(sql)
|
| 28 |
assert r.ok
|
| 29 |
|
|
|
|
| 30 |
def test_safety_allows_keywords_inside_string_literals():
|
| 31 |
s = Safety()
|
| 32 |
sql = "SELECT 'DROP TABLE x' as note, 'delete from y' as text;"
|
|
@@ -40,32 +42,39 @@ def test_safety_blocks_delete():
|
|
| 40 |
assert not result.ok
|
| 41 |
assert any("Forbidden" in e or "Non-SELECT" in e for e in (result.error or []))
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
"
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def test_safety_blocks_forbidden_statements(sql):
|
| 53 |
s = Safety()
|
| 54 |
res = s.check(sql)
|
| 55 |
assert not res.ok
|
| 56 |
|
|
|
|
| 57 |
def test_safety_blocks_stacked_delete_after_select():
|
| 58 |
s = Safety()
|
| 59 |
sql = "SELECT * FROM users; DELETE FROM users;"
|
| 60 |
r = s.check(sql)
|
| 61 |
assert not r.ok
|
| 62 |
|
|
|
|
| 63 |
def test_safety_blocks_stacked_delete_with_spaces():
|
| 64 |
s = Safety()
|
| 65 |
sql = "SELECT * FROM users ; \n DELETE users;"
|
| 66 |
r = s.check(sql)
|
| 67 |
assert not r.ok
|
| 68 |
|
|
|
|
| 69 |
def test_safety_blocks_delete_inside_cte():
|
| 70 |
s = Safety()
|
| 71 |
sql = """
|
|
@@ -75,26 +84,35 @@ def test_safety_blocks_delete_inside_cte():
|
|
| 75 |
r = s.check(sql)
|
| 76 |
assert not r.ok
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
"
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
def test_safety_blocks_comment_obfuscation(sql):
|
| 84 |
s = Safety()
|
| 85 |
r = s.check(sql)
|
| 86 |
assert not r.ok
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
"
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
def test_safety_blocks_forbidden_case_and_spacing(sql):
|
| 94 |
s = Safety()
|
| 95 |
r = s.check(sql)
|
| 96 |
assert not r.ok
|
| 97 |
|
|
|
|
| 98 |
def test_safety_blocks_multiple_nonempty_statements_even_if_second_is_comment():
|
| 99 |
s = Safety()
|
| 100 |
sql = "SELECT 1; -- now do something bad\n"
|
|
|
|
| 2 |
import pytest
|
| 3 |
|
| 4 |
|
|
|
|
| 5 |
def test_safety_allows_select():
|
| 6 |
s = Safety()
|
| 7 |
result = s.check("SELECT * FROM users;")
|
|
|
|
| 9 |
assert "sql" in result.data
|
| 10 |
assert result.trace.stage == "safety"
|
| 11 |
|
| 12 |
+
|
| 13 |
def test_safety_allows_with_select_cte():
|
| 14 |
s = Safety()
|
| 15 |
sql = """
|
|
|
|
| 21 |
r = s.check(sql)
|
| 22 |
assert r.ok
|
| 23 |
|
| 24 |
+
|
| 25 |
def test_safety_allows_select_with_comments_and_newlines():
|
| 26 |
s = Safety()
|
| 27 |
sql = "/* head */ \n -- inline\n SELECT 1; -- tail"
|
| 28 |
r = s.check(sql)
|
| 29 |
assert r.ok
|
| 30 |
|
| 31 |
+
|
| 32 |
def test_safety_allows_keywords_inside_string_literals():
|
| 33 |
s = Safety()
|
| 34 |
sql = "SELECT 'DROP TABLE x' as note, 'delete from y' as text;"
|
|
|
|
| 42 |
assert not result.ok
|
| 43 |
assert any("Forbidden" in e or "Non-SELECT" in e for e in (result.error or []))
|
| 44 |
|
| 45 |
+
|
| 46 |
+
@pytest.mark.parametrize(
|
| 47 |
+
"sql",
|
| 48 |
+
[
|
| 49 |
+
"UPDATE users SET name='X' WHERE id=1;",
|
| 50 |
+
"INSERT INTO users(id) VALUES (1);",
|
| 51 |
+
"DROP TABLE users;",
|
| 52 |
+
"CREATE TABLE x(id INT);",
|
| 53 |
+
"ALTER TABLE users ADD COLUMN x INT;",
|
| 54 |
+
"ATTACH DATABASE 'hack.db' AS h;",
|
| 55 |
+
"PRAGMA journal_mode=WAL;",
|
| 56 |
+
],
|
| 57 |
+
)
|
| 58 |
def test_safety_blocks_forbidden_statements(sql):
|
| 59 |
s = Safety()
|
| 60 |
res = s.check(sql)
|
| 61 |
assert not res.ok
|
| 62 |
|
| 63 |
+
|
| 64 |
def test_safety_blocks_stacked_delete_after_select():
|
| 65 |
s = Safety()
|
| 66 |
sql = "SELECT * FROM users; DELETE FROM users;"
|
| 67 |
r = s.check(sql)
|
| 68 |
assert not r.ok
|
| 69 |
|
| 70 |
+
|
| 71 |
def test_safety_blocks_stacked_delete_with_spaces():
|
| 72 |
s = Safety()
|
| 73 |
sql = "SELECT * FROM users ; \n DELETE users;"
|
| 74 |
r = s.check(sql)
|
| 75 |
assert not r.ok
|
| 76 |
|
| 77 |
+
|
| 78 |
def test_safety_blocks_delete_inside_cte():
|
| 79 |
s = Safety()
|
| 80 |
sql = """
|
|
|
|
| 84 |
r = s.check(sql)
|
| 85 |
assert not r.ok
|
| 86 |
|
| 87 |
+
|
| 88 |
+
@pytest.mark.parametrize(
|
| 89 |
+
"sql",
|
| 90 |
+
[
|
| 91 |
+
"/*D*/ROP TABLE users;",
|
| 92 |
+
"PR/*x*/AGMA journal_mode=WAL;",
|
| 93 |
+
"AL/* comment */TER TABLE x ADD COLUMN y INT;",
|
| 94 |
+
],
|
| 95 |
+
)
|
| 96 |
def test_safety_blocks_comment_obfuscation(sql):
|
| 97 |
s = Safety()
|
| 98 |
r = s.check(sql)
|
| 99 |
assert not r.ok
|
| 100 |
|
| 101 |
+
|
| 102 |
+
@pytest.mark.parametrize(
|
| 103 |
+
"sql",
|
| 104 |
+
[
|
| 105 |
+
"pragma journal_mode=WAL;", # lower-case
|
| 106 |
+
" PRAGMA user_version = 5 ; ",
|
| 107 |
+
"\nATTACH DATABASE 'hack.db' AS h;",
|
| 108 |
+
],
|
| 109 |
+
)
|
| 110 |
def test_safety_blocks_forbidden_case_and_spacing(sql):
|
| 111 |
s = Safety()
|
| 112 |
r = s.check(sql)
|
| 113 |
assert not r.ok
|
| 114 |
|
| 115 |
+
|
| 116 |
def test_safety_blocks_multiple_nonempty_statements_even_if_second_is_comment():
|
| 117 |
s = Safety()
|
| 118 |
sql = "SELECT 1; -- now do something bad\n"
|
tests/test_stage_types.py
CHANGED
|
@@ -1,16 +1,19 @@
|
|
| 1 |
from nl2sql.types import StageResult, StageTrace
|
| 2 |
|
|
|
|
| 3 |
def test_error_response():
|
| 4 |
r = StageResult(ok=False, error=["Syntax error"])
|
| 5 |
assert not r.ok
|
| 6 |
assert r.error == ["Syntax error"]
|
| 7 |
|
|
|
|
| 8 |
def test_trace_dataclass_structure():
|
| 9 |
t = StageTrace(stage="planner", duration_ms=12.5, token_in=10, token_out=20)
|
| 10 |
assert t.stage == "planner"
|
| 11 |
assert isinstance(t.duration_ms, float)
|
| 12 |
assert t.token_out == 20
|
| 13 |
|
|
|
|
| 14 |
def test_stage_result_defaults():
|
| 15 |
r = StageResult(ok=True)
|
| 16 |
assert r.ok
|
|
|
|
| 1 |
from nl2sql.types import StageResult, StageTrace
|
| 2 |
|
| 3 |
+
|
| 4 |
def test_error_response():
|
| 5 |
r = StageResult(ok=False, error=["Syntax error"])
|
| 6 |
assert not r.ok
|
| 7 |
assert r.error == ["Syntax error"]
|
| 8 |
|
| 9 |
+
|
| 10 |
def test_trace_dataclass_structure():
|
| 11 |
t = StageTrace(stage="planner", duration_ms=12.5, token_in=10, token_out=20)
|
| 12 |
assert t.stage == "planner"
|
| 13 |
assert isinstance(t.duration_ms, float)
|
| 14 |
assert t.token_out == 20
|
| 15 |
|
| 16 |
+
|
| 17 |
def test_stage_result_defaults():
|
| 18 |
r = StageResult(ok=True)
|
| 19 |
assert r.ok
|
ui/benchmark_app.py
CHANGED
|
@@ -22,8 +22,8 @@ df = pd.DataFrame(rows)
|
|
| 22 |
st.subheader("Aggregate Metrics")
|
| 23 |
col1, col2, col3, col4 = st.columns(4)
|
| 24 |
col1.metric("Total Queries", len(df))
|
| 25 |
-
col2.metric("Execution Accuracy", f"{df['exec_acc'].mean()*100:.1f}%")
|
| 26 |
-
col3.metric("Safety Violations", f"{df['safe_fail'].mean()*100:.1f}%")
|
| 27 |
col4.metric("Average Latency (ms)", f"{df['latency_ms'].mean():.0f}")
|
| 28 |
|
| 29 |
# 3. Latency Distribution
|
|
@@ -33,13 +33,23 @@ st.plotly_chart(fig1, use_container_width=True)
|
|
| 33 |
|
| 34 |
# 4. Cost vs Accuracy
|
| 35 |
st.subheader("Cost vs Execution Accuracy")
|
| 36 |
-
fig2 = px.scatter(
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
st.plotly_chart(fig2, use_container_width=True)
|
| 39 |
|
| 40 |
# 5. Repair Stats
|
| 41 |
if "repair_attempts" in df.columns:
|
| 42 |
st.subheader("Repair Attempts")
|
| 43 |
-
fig3 = px.bar(
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
st.plotly_chart(fig3, use_container_width=True)
|
|
|
|
| 22 |
st.subheader("Aggregate Metrics")
|
| 23 |
col1, col2, col3, col4 = st.columns(4)
|
| 24 |
col1.metric("Total Queries", len(df))
|
| 25 |
+
col2.metric("Execution Accuracy", f"{df['exec_acc'].mean() * 100:.1f}%")
|
| 26 |
+
col3.metric("Safety Violations", f"{df['safe_fail'].mean() * 100:.1f}%")
|
| 27 |
col4.metric("Average Latency (ms)", f"{df['latency_ms'].mean():.0f}")
|
| 28 |
|
| 29 |
# 3. Latency Distribution
|
|
|
|
| 33 |
|
| 34 |
# 4. Cost vs Accuracy
|
| 35 |
st.subheader("Cost vs Execution Accuracy")
|
| 36 |
+
fig2 = px.scatter(
|
| 37 |
+
df,
|
| 38 |
+
x="cost_usd",
|
| 39 |
+
y="exec_acc",
|
| 40 |
+
color="provider",
|
| 41 |
+
title="Trade-off: Cost vs Accuracy",
|
| 42 |
+
hover_data=["query"],
|
| 43 |
+
)
|
| 44 |
st.plotly_chart(fig2, use_container_width=True)
|
| 45 |
|
| 46 |
# 5. Repair Stats
|
| 47 |
if "repair_attempts" in df.columns:
|
| 48 |
st.subheader("Repair Attempts")
|
| 49 |
+
fig3 = px.bar(
|
| 50 |
+
df.groupby("repair_attempts").size().reset_index(name="count"),
|
| 51 |
+
x="repair_attempts",
|
| 52 |
+
y="count",
|
| 53 |
+
title="Number of Repair Attempts per Query",
|
| 54 |
+
)
|
| 55 |
st.plotly_chart(fig3, use_container_width=True)
|