"""
Schema extractor for an in-memory DuckDB connection.
Reads tables/columns/types via DuckDB introspection and produces both a
structured dict and a CREATE-TABLE-style text rendering used as context
for the SQL generator model.
"""
import logging
from typing import Any, Dict, List
import duckdb
logger = logging.getLogger(__name__)
class SchemaExtractor:
"""Extract schema information from a DuckDB connection."""
def __init__(self, con: duckdb.DuckDBPyConnection) -> None:
self.con = con
def extract_full_schema(self) -> Dict[str, Any]:
tables = self._extract_tables()
return {
"tables": tables,
"summary": self._generate_summary(tables),
"create_statements": self._build_create_statements(tables),
}
def _extract_tables(self) -> List[Dict[str, Any]]:
names = [r[0] for r in self.con.execute("SHOW TABLES").fetchall()]
out: List[Dict[str, Any]] = []
for name in names:
cols = self._extract_columns(name)
row_count = self._get_row_count(name)
sample = self._get_sample(name)
out.append({
"name": name,
"columns": cols,
"row_count": row_count,
"sample": sample,
})
return out
def _extract_columns(self, table: str) -> List[Dict[str, Any]]:
rows = self.con.execute(f'DESCRIBE "{table}"').fetchall()
cols: List[Dict[str, Any]] = []
for r in rows:
name, dtype = r[0], r[1]
nullable = (r[2] != "NO") if len(r) > 2 and r[2] is not None else True
cols.append({
"name": name,
"type": dtype,
"nullable": nullable,
})
return cols
def _get_row_count(self, table: str) -> int:
try:
return self.con.execute(f'SELECT COUNT(*) FROM "{table}"').fetchone()[0]
except Exception:
return 0
def _get_sample(self, table: str, n: int = 3) -> List[Dict[str, Any]]:
try:
cur = self.con.execute(f'SELECT * FROM "{table}" LIMIT {n}')
rows = cur.fetchall()
cols = [d[0] for d in cur.description or []]
return [dict(zip(cols, row)) for row in rows]
except Exception:
return []
def _generate_summary(self, tables: List[Dict[str, Any]]) -> str:
lines = []
for t in tables:
lines.append(f"Table {t['name']} ({t['row_count']:,} rows):")
for c in t["columns"]:
lines.append(f" - {c['name']}: {c['type']}")
lines.append("")
return "\n".join(lines).rstrip()
def _build_create_statements(self, tables: List[Dict[str, Any]]) -> str:
"""Render schema as CREATE TABLE statements (best format for LLM context)."""
chunks = []
for t in tables:
cols = ",\n ".join(
f'"{c["name"]}" {c["type"]}' for c in t["columns"]
)
chunks.append(f'CREATE TABLE "{t["name"]}" (\n {cols}\n);')
return "\n\n".join(chunks)
def get_schema_text(self) -> str:
"""Schema description optimized for LLM SQL generation:
- CREATE TABLE statements (canonical column names + types)
- Sample rows (so model sees realistic values)
- For low-cardinality categorical columns: list of distinct values
(gives the model a vocabulary to match user wording against)
"""
info = self.extract_full_schema()
out = [info["create_statements"]]
for t in info["tables"]:
if t["sample"]:
out.append(f"\n-- Sample rows from {t['name']}:")
for row in t["sample"]:
out.append(f"-- {row}")
# Distinct-values hints for categorical columns
for col in t["columns"]:
hint = self._distinct_values_hint(t["name"], col)
if hint:
out.append(hint)
return "\n".join(out)
def _distinct_values_hint(
self, table: str, col: Dict[str, Any], threshold: int = 25
) -> str | None:
"""If a column is categorical with <=threshold distinct values,
list them so the LLM can map user wording to actual values."""
dtype = (col.get("type") or "").upper()
# Only worth doing for string-ish columns
if not any(t in dtype for t in ("VARCHAR", "STRING", "TEXT", "CHAR")):
return None
try:
rows = self.con.execute(
f'SELECT DISTINCT "{col["name"]}" FROM "{table}" '
f'WHERE "{col["name"]}" IS NOT NULL '
f'LIMIT {threshold + 1}'
).fetchall()
except Exception:
return None
if len(rows) > threshold:
return None # too many distinct values, not categorical
if len(rows) <= 1:
return None # not informative
values = ", ".join(repr(r[0]) for r in rows)
return f"-- {table}.{col['name']} distinct values: {values}"
|