sql-agent / src /rag /schema_extractor.py
DanielRegaladoCardoso's picture
Better SQL prompt: column-name mapping rules + few-shot + distinct value hints
b0a300f verified
"""
Schema extractor for an in-memory DuckDB connection.
Reads tables/columns/types via DuckDB introspection and produces both a
structured dict and a CREATE-TABLE-style text rendering used as context
for the SQL generator model.
"""
import logging
from typing import Any, Dict, List
import duckdb
# Logger named after this module (standard per-module logging convention).
logger: logging.Logger = logging.getLogger(name=__name__)
class SchemaExtractor:
"""Extract schema information from a DuckDB connection."""
def __init__(self, con: duckdb.DuckDBPyConnection) -> None:
self.con = con
def extract_full_schema(self) -> Dict[str, Any]:
tables = self._extract_tables()
return {
"tables": tables,
"summary": self._generate_summary(tables),
"create_statements": self._build_create_statements(tables),
}
def _extract_tables(self) -> List[Dict[str, Any]]:
names = [r[0] for r in self.con.execute("SHOW TABLES").fetchall()]
out: List[Dict[str, Any]] = []
for name in names:
cols = self._extract_columns(name)
row_count = self._get_row_count(name)
sample = self._get_sample(name)
out.append({
"name": name,
"columns": cols,
"row_count": row_count,
"sample": sample,
})
return out
def _extract_columns(self, table: str) -> List[Dict[str, Any]]:
rows = self.con.execute(f'DESCRIBE "{table}"').fetchall()
cols: List[Dict[str, Any]] = []
for r in rows:
name, dtype = r[0], r[1]
nullable = (r[2] != "NO") if len(r) > 2 and r[2] is not None else True
cols.append({
"name": name,
"type": dtype,
"nullable": nullable,
})
return cols
def _get_row_count(self, table: str) -> int:
try:
return self.con.execute(f'SELECT COUNT(*) FROM "{table}"').fetchone()[0]
except Exception:
return 0
def _get_sample(self, table: str, n: int = 3) -> List[Dict[str, Any]]:
try:
cur = self.con.execute(f'SELECT * FROM "{table}" LIMIT {n}')
rows = cur.fetchall()
cols = [d[0] for d in cur.description or []]
return [dict(zip(cols, row)) for row in rows]
except Exception:
return []
def _generate_summary(self, tables: List[Dict[str, Any]]) -> str:
lines = []
for t in tables:
lines.append(f"Table {t['name']} ({t['row_count']:,} rows):")
for c in t["columns"]:
lines.append(f" - {c['name']}: {c['type']}")
lines.append("")
return "\n".join(lines).rstrip()
def _build_create_statements(self, tables: List[Dict[str, Any]]) -> str:
"""Render schema as CREATE TABLE statements (best format for LLM context)."""
chunks = []
for t in tables:
cols = ",\n ".join(
f'"{c["name"]}" {c["type"]}' for c in t["columns"]
)
chunks.append(f'CREATE TABLE "{t["name"]}" (\n {cols}\n);')
return "\n\n".join(chunks)
def get_schema_text(self) -> str:
"""Schema description optimized for LLM SQL generation:
- CREATE TABLE statements (canonical column names + types)
- Sample rows (so model sees realistic values)
- For low-cardinality categorical columns: list of distinct values
(gives the model a vocabulary to match user wording against)
"""
info = self.extract_full_schema()
out = [info["create_statements"]]
for t in info["tables"]:
if t["sample"]:
out.append(f"\n-- Sample rows from {t['name']}:")
for row in t["sample"]:
out.append(f"-- {row}")
# Distinct-values hints for categorical columns
for col in t["columns"]:
hint = self._distinct_values_hint(t["name"], col)
if hint:
out.append(hint)
return "\n".join(out)
def _distinct_values_hint(
self, table: str, col: Dict[str, Any], threshold: int = 25
) -> str | None:
"""If a column is categorical with <=threshold distinct values,
list them so the LLM can map user wording to actual values."""
dtype = (col.get("type") or "").upper()
# Only worth doing for string-ish columns
if not any(t in dtype for t in ("VARCHAR", "STRING", "TEXT", "CHAR")):
return None
try:
rows = self.con.execute(
f'SELECT DISTINCT "{col["name"]}" FROM "{table}" '
f'WHERE "{col["name"]}" IS NOT NULL '
f'LIMIT {threshold + 1}'
).fetchall()
except Exception:
return None
if len(rows) > threshold:
return None # too many distinct values, not categorical
if len(rows) <= 1:
return None # not informative
values = ", ".join(repr(r[0]) for r in rows)
return f"-- {table}.{col['name']} distinct values: {values}"