Spaces:
Sleeping
Sleeping
Better SQL prompt: column-name mapping rules + few-shot + distinct value hints
Browse files
- src/rag/schema_extractor.py +38 -1
src/rag/schema_extractor.py
CHANGED
|
@@ -91,12 +91,49 @@ class SchemaExtractor:
|
|
| 91 |
return "\n\n".join(chunks)
|
| 92 |
|
| 93 |
def get_schema_text(self) -> str:
    """Render the extracted schema as a single text block.

    The CREATE statements come first. Every table that has sample rows then
    contributes a header comment followed by one SQL comment line per row.
    Tables without sample rows add nothing.
    """
    schema = self.extract_full_schema()
    parts = [schema["create_statements"]]
    for table in schema["tables"]:
        # Guard clause: nothing to show for a table with no sample rows.
        if not table["sample"]:
            continue
        parts.append(f"\n-- Sample rows from {table['name']}:")
        parts.extend(f"-- {row}" for row in table["sample"])
    return "\n".join(parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
return "\n\n".join(chunks)
|
| 92 |
|
| 93 |
def get_schema_text(self) -> str:
    """Build the schema description used for LLM SQL generation.

    The text contains, in order:
      - the CREATE TABLE statements (canonical column names + types)
      - per table, its sample rows as SQL comments (when it has any)
      - per table, one distinct-values comment for every column for which
        ``_distinct_values_hint`` returns a non-empty hint
    """
    schema = self.extract_full_schema()
    lines = [schema["create_statements"]]

    for table in schema["tables"]:
        if table["sample"]:
            lines.append(f"\n-- Sample rows from {table['name']}:")
            lines.extend(f"-- {row}" for row in table["sample"])

        # Distinct-values hints for categorical columns; falsy hints are dropped.
        hints = (
            self._distinct_values_hint(table["name"], column)
            for column in table["columns"]
        )
        lines.extend(hint for hint in hints if hint)

    return "\n".join(lines)
|
| 116 |
+
|
| 117 |
+
def _distinct_values_hint(
|
| 118 |
+
self, table: str, col: Dict[str, Any], threshold: int = 25
|
| 119 |
+
) -> str | None:
|
| 120 |
+
"""If a column is categorical with <=threshold distinct values,
|
| 121 |
+
list them so the LLM can map user wording to actual values."""
|
| 122 |
+
dtype = (col.get("type") or "").upper()
|
| 123 |
+
# Only worth doing for string-ish columns
|
| 124 |
+
if not any(t in dtype for t in ("VARCHAR", "STRING", "TEXT", "CHAR")):
|
| 125 |
+
return None
|
| 126 |
+
try:
|
| 127 |
+
rows = self.con.execute(
|
| 128 |
+
f'SELECT DISTINCT "{col["name"]}" FROM "{table}" '
|
| 129 |
+
f'WHERE "{col["name"]}" IS NOT NULL '
|
| 130 |
+
f'LIMIT {threshold + 1}'
|
| 131 |
+
).fetchall()
|
| 132 |
+
except Exception:
|
| 133 |
+
return None
|
| 134 |
+
if len(rows) > threshold:
|
| 135 |
+
return None # too many distinct values, not categorical
|
| 136 |
+
if len(rows) <= 1:
|
| 137 |
+
return None # not informative
|
| 138 |
+
values = ", ".join(repr(r[0]) for r in rows)
|
| 139 |
+
return f"-- {table}.{col['name']} distinct values: {values}"
|