DanielRegaladoCardoso commited on
Commit
b0a300f
·
verified ·
1 Parent(s): 9c93809

Better SQL prompt: column-name mapping rules + few-shot + distinct value hints

Browse files
Files changed (1) hide show
  1. src/rag/schema_extractor.py +38 -1
src/rag/schema_extractor.py CHANGED
@@ -91,12 +91,49 @@ class SchemaExtractor:
91
  return "\n\n".join(chunks)
92
 
93
  def get_schema_text(self) -> str:
94
- """One-shot helper: full schema as CREATE statements + sample rows."""
 
 
 
 
 
95
  info = self.extract_full_schema()
96
  out = [info["create_statements"]]
 
97
  for t in info["tables"]:
98
  if t["sample"]:
99
  out.append(f"\n-- Sample rows from {t['name']}:")
100
  for row in t["sample"]:
101
  out.append(f"-- {row}")
 
 
 
 
 
 
 
102
  return "\n".join(out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  return "\n\n".join(chunks)
92
 
93
  def get_schema_text(self) -> str:
94
+ """Schema description optimized for LLM SQL generation:
95
+ - CREATE TABLE statements (canonical column names + types)
96
+ - Sample rows (so model sees realistic values)
97
+ - For low-cardinality categorical columns: list of distinct values
98
+ (gives the model a vocabulary to match user wording against)
99
+ """
100
  info = self.extract_full_schema()
101
  out = [info["create_statements"]]
102
+
103
  for t in info["tables"]:
104
  if t["sample"]:
105
  out.append(f"\n-- Sample rows from {t['name']}:")
106
  for row in t["sample"]:
107
  out.append(f"-- {row}")
108
+
109
+ # Distinct-values hints for categorical columns
110
+ for col in t["columns"]:
111
+ hint = self._distinct_values_hint(t["name"], col)
112
+ if hint:
113
+ out.append(hint)
114
+
115
  return "\n".join(out)
116
+
117
+ def _distinct_values_hint(
118
+ self, table: str, col: Dict[str, Any], threshold: int = 25
119
+ ) -> str | None:
120
+ """If a column is categorical with <=threshold distinct values,
121
+ list them so the LLM can map user wording to actual values."""
122
+ dtype = (col.get("type") or "").upper()
123
+ # Only worth doing for string-ish columns
124
+ if not any(t in dtype for t in ("VARCHAR", "STRING", "TEXT", "CHAR")):
125
+ return None
126
+ try:
127
+ rows = self.con.execute(
128
+ f'SELECT DISTINCT "{col["name"]}" FROM "{table}" '
129
+ f'WHERE "{col["name"]}" IS NOT NULL '
130
+ f'LIMIT {threshold + 1}'
131
+ ).fetchall()
132
+ except Exception:
133
+ return None
134
+ if len(rows) > threshold:
135
+ return None # too many distinct values, not categorical
136
+ if len(rows) <= 1:
137
+ return None # not informative
138
+ values = ", ".join(repr(r[0]) for r in rows)
139
+ return f"-- {table}.{col['name']} distinct values: {values}"