Spaces:
Sleeping
Sleeping
Upload engine.py
Browse files
engine.py
CHANGED
|
@@ -2,7 +2,7 @@ import json
|
|
| 2 |
import os
|
| 3 |
from functools import lru_cache
|
| 4 |
from openai import OpenAI
|
| 5 |
-
from datetime import datetime
|
| 6 |
import re
|
| 7 |
|
| 8 |
# =========================
|
|
@@ -37,10 +37,19 @@ def load_metadata():
|
|
| 37 |
}
|
| 38 |
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
|
|
|
| 44 |
OPERATOR_ALIASES = {
|
| 45 |
"=": "equals",
|
| 46 |
"==": "equals",
|
|
@@ -52,11 +61,24 @@ def resolve_operator(op, value):
|
|
| 52 |
">=": "greater_or_equal",
|
| 53 |
"<=": "less_or_equal",
|
| 54 |
"greater than": "greater_than",
|
| 55 |
-
"less than": "less_than"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
}
|
| 57 |
|
| 58 |
op = OPERATOR_ALIASES.get(op, op)
|
| 59 |
|
|
|
|
| 60 |
mapping = {
|
| 61 |
"equals": "=",
|
| 62 |
"not_equals": "!=",
|
|
@@ -65,10 +87,32 @@ def resolve_operator(op, value):
|
|
| 65 |
"greater_or_equal": ">=",
|
| 66 |
"less_or_equal": "<=",
|
| 67 |
"contains": "LIKE",
|
|
|
|
| 68 |
"starts_with": "LIKE",
|
| 69 |
"ends_with": "LIKE",
|
| 70 |
"in": "IN",
|
| 71 |
-
"not_in": "NOT IN"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
}
|
| 73 |
|
| 74 |
if op not in mapping:
|
|
@@ -76,50 +120,179 @@ def resolve_operator(op, value):
|
|
| 76 |
|
| 77 |
sql_op = mapping[op]
|
| 78 |
|
| 79 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
def sql_escape(val):
|
|
|
|
|
|
|
| 81 |
return str(val).replace("'", "''")
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
if op == "contains":
|
| 84 |
return sql_op, f"'%{sql_escape(value)}%'"
|
| 85 |
|
|
|
|
|
|
|
|
|
|
| 86 |
if op == "starts_with":
|
| 87 |
return sql_op, f"'{sql_escape(value)}%'"
|
| 88 |
|
| 89 |
if op == "ends_with":
|
| 90 |
return sql_op, f"'%{sql_escape(value)}'"
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
if op in ("in", "not_in"):
|
| 93 |
if not isinstance(value, list):
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
|
| 102 |
# =========================
|
| 103 |
-
# JOIN RESOLUTION
|
| 104 |
# =========================
|
| 105 |
|
| 106 |
def resolve_join_path(start_table, end_table):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
joins = load_metadata()["joins"]
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
if path["start_table"] == start_table and path["end_table"] == end_table:
|
| 111 |
-
return path
|
| 112 |
|
| 113 |
raise ValueError(
|
| 114 |
f"No join path found from {start_table} to {end_table}"
|
| 115 |
)
|
| 116 |
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
FIELD_ALIASES = {
|
| 119 |
-
"join_date": "
|
| 120 |
-
"joining_date": "
|
| 121 |
-
"joined": "
|
| 122 |
-
"hire_date": "
|
|
|
|
|
|
|
|
|
|
| 123 |
}
|
| 124 |
|
| 125 |
def resolve_field(field_name, module):
|
|
@@ -127,7 +300,7 @@ def resolve_field(field_name, module):
|
|
| 127 |
fields = meta["fields"]
|
| 128 |
|
| 129 |
# 🔹 Normalize field name
|
| 130 |
-
field_name = field_name.lower().strip()
|
| 131 |
field_name = FIELD_ALIASES.get(field_name, field_name)
|
| 132 |
|
| 133 |
# 🔹 Validate existence
|
|
@@ -151,21 +324,6 @@ def resolve_field(field_name, module):
|
|
| 151 |
return field
|
| 152 |
|
| 153 |
|
| 154 |
-
|
| 155 |
-
def build_join_sql(base_table, steps):
|
| 156 |
-
sql = []
|
| 157 |
-
prev_alias = base_table # alias == table name
|
| 158 |
-
|
| 159 |
-
for step in steps:
|
| 160 |
-
alias = step["alias"]
|
| 161 |
-
sql.append(
|
| 162 |
-
f"{step['join_type'].upper()} JOIN {step['table']} {alias} "
|
| 163 |
-
f"ON {prev_alias}.{step['base_column']} = {alias}.{step['foreign_column']}"
|
| 164 |
-
)
|
| 165 |
-
prev_alias = alias
|
| 166 |
-
|
| 167 |
-
return "\n".join(sql)
|
| 168 |
-
|
| 169 |
# =========================
|
| 170 |
# JSON SAFETY
|
| 171 |
# =========================
|
|
@@ -174,7 +332,12 @@ def safe_json_loads(text):
|
|
| 174 |
try:
|
| 175 |
return json.loads(text)
|
| 176 |
except json.JSONDecodeError:
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
if match:
|
| 179 |
return json.loads(match.group())
|
| 180 |
raise ValueError("LLM returned invalid JSON")
|
|
@@ -194,7 +357,7 @@ def parse_intent(question, retries=2):
|
|
| 194 |
if (fields := [
|
| 195 |
f for f in meta["fields"]
|
| 196 |
if meta["fields"][f]["module"] == module
|
| 197 |
-
])
|
| 198 |
])
|
| 199 |
|
| 200 |
prompt = f"""
|
|
@@ -230,7 +393,7 @@ User question:
|
|
| 230 |
for attempt in range(retries):
|
| 231 |
try:
|
| 232 |
res = client.chat.completions.create(
|
| 233 |
-
model="gpt-
|
| 234 |
messages=[
|
| 235 |
{
|
| 236 |
"role": "system",
|
|
@@ -241,7 +404,7 @@ User question:
|
|
| 241 |
temperature=0
|
| 242 |
)
|
| 243 |
|
| 244 |
-
content = res.choices[0].message.content
|
| 245 |
plan = safe_json_loads(content)
|
| 246 |
|
| 247 |
# ✅ NORMALIZE + STABILIZE INTENT SHAPE
|
|
@@ -253,12 +416,12 @@ User question:
|
|
| 253 |
|
| 254 |
return plan
|
| 255 |
|
| 256 |
-
except Exception:
|
| 257 |
if attempt == retries - 1:
|
| 258 |
-
raise ValueError("LLM failed to return valid JSON")
|
| 259 |
|
| 260 |
# =========================
|
| 261 |
-
# SQL GENERATOR
|
| 262 |
# =========================
|
| 263 |
|
| 264 |
def build_sql(plan):
|
|
@@ -273,7 +436,7 @@ def build_sql(plan):
|
|
| 273 |
base_table = meta["modules"][module]["base_table"]
|
| 274 |
|
| 275 |
joins = []
|
| 276 |
-
joined_tables =
|
| 277 |
where_clauses = []
|
| 278 |
|
| 279 |
# ---------- SELECT ----------
|
|
@@ -284,7 +447,7 @@ def build_sql(plan):
|
|
| 284 |
for f in select_fields:
|
| 285 |
field = resolve_field(f, module)
|
| 286 |
select_columns.append(
|
| 287 |
-
f"{field['table']}.{field['column']}"
|
| 288 |
)
|
| 289 |
select_sql = ", ".join(select_columns)
|
| 290 |
else:
|
|
@@ -296,47 +459,61 @@ def build_sql(plan):
|
|
| 296 |
|
| 297 |
table = field["table"]
|
| 298 |
column = field["column"]
|
|
|
|
| 299 |
|
|
|
|
| 300 |
if table != base_table and table not in joined_tables:
|
| 301 |
-
|
| 302 |
-
joins.append(build_join_sql(base_table,
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
# 🔴 FIX: safe WHERE clause
|
| 311 |
where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
|
| 312 |
|
| 313 |
# ---------- FINAL SQL ----------
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
return sql.strip()
|
| 323 |
|
| 324 |
|
| 325 |
-
|
| 326 |
-
|
| 327 |
# =========================
|
| 328 |
# VALIDATION
|
| 329 |
# =========================
|
| 330 |
|
| 331 |
def validate_sql(sql):
|
| 332 |
-
|
| 333 |
|
| 334 |
-
if not
|
| 335 |
raise ValueError("Only SELECT allowed")
|
| 336 |
|
| 337 |
-
forbidden = ["drop", "delete", "update", "insert", "truncate"]
|
| 338 |
-
|
| 339 |
-
|
|
|
|
| 340 |
|
| 341 |
return sql
|
| 342 |
|
|
@@ -365,3 +542,26 @@ def run(question):
|
|
| 365 |
"query_plan": plan,
|
| 366 |
"sql": sql
|
| 367 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import os
|
| 3 |
from functools import lru_cache
|
| 4 |
from openai import OpenAI
|
| 5 |
+
from datetime import datetime, date, timedelta
|
| 6 |
import re
|
| 7 |
|
| 8 |
# =========================
|
|
|
|
| 37 |
}
|
| 38 |
|
| 39 |
|
| 40 |
+
# =========================
|
| 41 |
+
# OPERATOR RESOLUTION (COMPLETE FIXED VERSION)
|
| 42 |
+
# =========================
|
| 43 |
+
|
| 44 |
+
def resolve_operator(op, value, field_type=None):
|
| 45 |
+
"""
|
| 46 |
+
Resolve operator and format value based on data type
|
| 47 |
+
FIXED: Properly handles numeric types without quotes
|
| 48 |
+
"""
|
| 49 |
+
# Normalize operator input
|
| 50 |
+
op = op.lower().strip().replace(" ", "_")
|
| 51 |
|
| 52 |
+
# Extended operator aliases for all your operators
|
| 53 |
OPERATOR_ALIASES = {
|
| 54 |
"=": "equals",
|
| 55 |
"==": "equals",
|
|
|
|
| 61 |
">=": "greater_or_equal",
|
| 62 |
"<=": "less_or_equal",
|
| 63 |
"greater than": "greater_than",
|
| 64 |
+
"less than": "less_than",
|
| 65 |
+
"greaterthan": "greater_than",
|
| 66 |
+
"lessthan": "less_than",
|
| 67 |
+
"greaterthanorequal": "greater_or_equal",
|
| 68 |
+
"lessthanorequal": "less_or_equal",
|
| 69 |
+
"does_not_contain": "not_contains",
|
| 70 |
+
"is_blank": "is_empty",
|
| 71 |
+
"is_not_blank": "is_not_empty",
|
| 72 |
+
"on": "equals",
|
| 73 |
+
"date_equals": "equals",
|
| 74 |
+
"date_between": "between",
|
| 75 |
+
"startswith": "starts_with",
|
| 76 |
+
"endswith": "ends_with"
|
| 77 |
}
|
| 78 |
|
| 79 |
op = OPERATOR_ALIASES.get(op, op)
|
| 80 |
|
| 81 |
+
# SQL operator mapping
|
| 82 |
mapping = {
|
| 83 |
"equals": "=",
|
| 84 |
"not_equals": "!=",
|
|
|
|
| 87 |
"greater_or_equal": ">=",
|
| 88 |
"less_or_equal": "<=",
|
| 89 |
"contains": "LIKE",
|
| 90 |
+
"not_contains": "NOT LIKE",
|
| 91 |
"starts_with": "LIKE",
|
| 92 |
"ends_with": "LIKE",
|
| 93 |
"in": "IN",
|
| 94 |
+
"not_in": "NOT IN",
|
| 95 |
+
"is_empty": "IS NULL",
|
| 96 |
+
"is_not_empty": "IS NOT NULL",
|
| 97 |
+
"between": "BETWEEN",
|
| 98 |
+
"not_between": "NOT BETWEEN",
|
| 99 |
+
"before": "<",
|
| 100 |
+
"after": ">",
|
| 101 |
+
# Date relative operators
|
| 102 |
+
"today": "=",
|
| 103 |
+
"yesterday": "=",
|
| 104 |
+
"tomorrow": "=",
|
| 105 |
+
"this_week": "BETWEEN",
|
| 106 |
+
"last_week": "BETWEEN",
|
| 107 |
+
"next_week": "BETWEEN",
|
| 108 |
+
"this_month": "BETWEEN",
|
| 109 |
+
"last_month": "BETWEEN",
|
| 110 |
+
"next_month": "BETWEEN",
|
| 111 |
+
"this_quarter": "BETWEEN",
|
| 112 |
+
"last_quarter": "BETWEEN",
|
| 113 |
+
"next_quarter": "BETWEEN",
|
| 114 |
+
"this_year": "BETWEEN",
|
| 115 |
+
"last_year": "BETWEEN"
|
| 116 |
}
|
| 117 |
|
| 118 |
if op not in mapping:
|
|
|
|
| 120 |
|
| 121 |
sql_op = mapping[op]
|
| 122 |
|
| 123 |
+
# ✅ Determine if field is numeric
|
| 124 |
+
is_numeric = field_type in ['integer', 'decimal', 'float', 'number', 'int', 'bigint']
|
| 125 |
+
is_date = field_type in ['date', 'datetime', 'timestamp']
|
| 126 |
+
is_boolean = field_type in ['boolean', 'bool']
|
| 127 |
+
|
| 128 |
+
# Escape string values safely
|
| 129 |
def sql_escape(val):
|
| 130 |
+
if val is None:
|
| 131 |
+
return 'NULL'
|
| 132 |
return str(val).replace("'", "''")
|
| 133 |
|
| 134 |
+
# Handle NULL operators
|
| 135 |
+
if op in ("is_empty", "is_not_empty"):
|
| 136 |
+
return sql_op, ""
|
| 137 |
+
|
| 138 |
+
# Handle date relative operators
|
| 139 |
+
if op in ("today", "yesterday", "tomorrow", "this_week", "last_week", "next_week",
|
| 140 |
+
"this_month", "last_month", "next_month", "this_quarter", "last_quarter",
|
| 141 |
+
"next_quarter", "this_year", "last_year"):
|
| 142 |
+
today = date.today()
|
| 143 |
+
|
| 144 |
+
if op == "today":
|
| 145 |
+
return "=", f"'{today}'"
|
| 146 |
+
elif op == "yesterday":
|
| 147 |
+
return "=", f"'{today - timedelta(days=1)}'"
|
| 148 |
+
elif op == "tomorrow":
|
| 149 |
+
return "=", f"'{today + timedelta(days=1)}'"
|
| 150 |
+
elif op == "this_week":
|
| 151 |
+
start = today - timedelta(days=today.weekday())
|
| 152 |
+
end = start + timedelta(days=6)
|
| 153 |
+
return "BETWEEN", f"'{start}' AND '{end}'"
|
| 154 |
+
elif op == "this_month":
|
| 155 |
+
start = today.replace(day=1)
|
| 156 |
+
if today.month == 12:
|
| 157 |
+
end = today.replace(day=31)
|
| 158 |
+
else:
|
| 159 |
+
end = (today.replace(month=today.month+1, day=1) - timedelta(days=1))
|
| 160 |
+
return "BETWEEN", f"'{start}' AND '{end}'"
|
| 161 |
+
elif op == "this_year":
|
| 162 |
+
start = today.replace(month=1, day=1)
|
| 163 |
+
end = today.replace(month=12, day=31)
|
| 164 |
+
return "BETWEEN", f"'{start}' AND '{end}'"
|
| 165 |
+
# Add more as needed
|
| 166 |
+
|
| 167 |
+
# Handle LIKE operators
|
| 168 |
if op == "contains":
|
| 169 |
return sql_op, f"'%{sql_escape(value)}%'"
|
| 170 |
|
| 171 |
+
if op == "not_contains":
|
| 172 |
+
return sql_op, f"'%{sql_escape(value)}%'"
|
| 173 |
+
|
| 174 |
if op == "starts_with":
|
| 175 |
return sql_op, f"'{sql_escape(value)}%'"
|
| 176 |
|
| 177 |
if op == "ends_with":
|
| 178 |
return sql_op, f"'%{sql_escape(value)}'"
|
| 179 |
|
| 180 |
+
# Handle BETWEEN operator
|
| 181 |
+
if op in ("between", "not_between"):
|
| 182 |
+
if not isinstance(value, (list, tuple)) or len(value) != 2:
|
| 183 |
+
raise ValueError("BETWEEN operator requires array of 2 values")
|
| 184 |
+
|
| 185 |
+
if is_numeric:
|
| 186 |
+
return sql_op, f"{value[0]} AND {value[1]}"
|
| 187 |
+
else:
|
| 188 |
+
return sql_op, f"'{sql_escape(value[0])}' AND '{sql_escape(value[1])}'"
|
| 189 |
+
|
| 190 |
+
# ✅ Handle IN operators with type checking
|
| 191 |
if op in ("in", "not_in"):
|
| 192 |
if not isinstance(value, list):
|
| 193 |
+
value = [value]
|
| 194 |
+
|
| 195 |
+
if is_numeric:
|
| 196 |
+
escaped = [str(v) for v in value] # β
No quotes for numbers
|
| 197 |
+
else:
|
| 198 |
+
escaped = [f"'{sql_escape(v)}'" for v in value]
|
| 199 |
+
|
| 200 |
+
return sql_op, f"({', '.join(escaped)})"
|
| 201 |
+
|
| 202 |
+
# ✅ Handle regular comparison operators with type awareness
|
| 203 |
+
if is_numeric:
|
| 204 |
+
return sql_op, str(value) # β
No quotes for numbers
|
| 205 |
+
elif is_boolean:
|
| 206 |
+
if isinstance(value, bool):
|
| 207 |
+
return sql_op, "1" if value else "0"
|
| 208 |
+
return sql_op, str(value)
|
| 209 |
+
elif is_date:
|
| 210 |
+
return sql_op, f"'{sql_escape(value)}'"
|
| 211 |
+
else:
|
| 212 |
+
return sql_op, f"'{sql_escape(value)}'"
|
| 213 |
|
| 214 |
|
| 215 |
# =========================
|
| 216 |
+
# JOIN RESOLUTION (FIXED)
|
| 217 |
# =========================
|
| 218 |
|
| 219 |
def resolve_join_path(start_table, end_table):
    """Return the join-path entry leading from ``start_table`` to ``end_table``.

    The ``"joins"`` section of the loaded metadata is used as a mapping of
    path key -> path entry.  Lookup is attempted in two stages:

    1. directly, under the key ``"<start_table>__<end_table>"``;
    2. by scanning every entry and comparing its own ``"start_table"`` and
       ``"end_table"`` values, which covers entries stored under a key that
       does not follow the double-underscore scheme.

    Args:
        start_table: table the join path starts from.
        end_table: table the join path must reach.

    Returns:
        The matching path entry exactly as stored in the metadata.

    Raises:
        ValueError: if neither lookup stage finds a matching entry.
    """
    joins = load_metadata()["joins"]

    # Stage 1: direct lookup with the double-underscore key convention.
    key = f"{start_table}__{end_table}"
    if key in joins:
        return joins[key]

    # Stage 2: the key itself is irrelevant here, so only the entries are
    # iterated and matched on their declared endpoints.
    for path in joins.values():
        if path["start_table"] == start_table and path["end_table"] == end_table:
            return path

    raise ValueError(
        f"No join path found from {start_table} to {end_table}"
    )
|
| 239 |
|
| 240 |
|
| 241 |
+
def build_join_sql(base_table, join_path):
    """Render the JOIN clauses described by a join path.

    Args:
        base_table: name placed on the left-hand side of the first join, and
            of every later join that is not flagged ``from_previous_step``.
        join_path: mapping with a ``"steps"`` list.  Every step must provide
            ``"alias"``, ``"table"``, ``"join_type"``, ``"base_column"`` and
            ``"foreign_column"``.  Optional keys are ``"step"`` (sort order,
            defaults to 0), ``"from_previous_step"`` (defaults to False) and
            ``"extra_conditions"`` (list of mappings with ``"column"``,
            ``"operator"`` and ``"value"``).

    Returns:
        One ``"<TYPE> JOIN <table> <alias> ON <conditions>"`` line per step,
        joined with newlines; an empty string when there are no steps.

    Raises:
        KeyError: if ``join_path`` or a step lacks a required key.
    """
    # Steps are applied in ascending "step" order; sorted() is stable, so
    # steps without a number keep their relative input order.
    ordered_steps = sorted(join_path["steps"], key=lambda s: s.get("step", 0))
    clauses = []

    for index, step in enumerate(ordered_steps):
        alias = step["alias"]
        table = step["table"]
        join_type = step["join_type"].upper()

        # The first join always hangs off the base table.  A later join chains
        # from the previous step's alias only when it asks for that explicitly.
        if index > 0 and step.get("from_previous_step", False):
            left_ref = ordered_steps[index - 1]["alias"]
        else:
            left_ref = base_table

        conditions = [
            f"{left_ref}.{step['base_column']} = {alias}.{step['foreign_column']}"
        ]

        # NOTE: extra-condition values are interpolated verbatim, so the
        # metadata must already hold SQL-ready literals (quoted if needed).
        for extra in step.get("extra_conditions") or ():
            conditions.append(
                f"{alias}.{extra['column']} {extra['operator']} {extra['value']}"
            )

        clauses.append(
            f"{join_type} JOIN {table} {alias} ON {' AND '.join(conditions)}"
        )

    return "\n".join(clauses)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
# =========================
|
| 285 |
+
# FIELD RESOLUTION
|
| 286 |
+
# =========================
|
| 287 |
+
|
| 288 |
# Alternative spellings of a field name mapped to the name used for lookup.
# resolve_field consults this table after lower-casing the input, stripping
# it and replacing spaces with underscores, so keys must be in that form.
FIELD_ALIASES = {
    "join_date": "date_of_joining",
    "joining_date": "date_of_joining",
    "joined": "date_of_joining",
    "hire_date": "date_of_joining",
    "emp_code": "employee_code",
    "emp_name": "full_name",
    "dept": "department",
}
|
| 297 |
|
| 298 |
def resolve_field(field_name, module):
|
|
|
|
| 300 |
fields = meta["fields"]
|
| 301 |
|
| 302 |
# 🔹 Normalize field name
|
| 303 |
+
field_name = field_name.lower().strip().replace(" ", "_")
|
| 304 |
field_name = FIELD_ALIASES.get(field_name, field_name)
|
| 305 |
|
| 306 |
# 🔹 Validate existence
|
|
|
|
| 324 |
return field
|
| 325 |
|
| 326 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
# =========================
|
| 328 |
# JSON SAFETY
|
| 329 |
# =========================
|
|
|
|
| 332 |
try:
|
| 333 |
return json.loads(text)
|
| 334 |
except json.JSONDecodeError:
|
| 335 |
+
# Try to extract JSON from markdown
|
| 336 |
+
match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
|
| 337 |
+
if match:
|
| 338 |
+
return json.loads(match.group(1))
|
| 339 |
+
|
| 340 |
+
match = re.search(r"\{.*\}", text, re.DOTALL)
|
| 341 |
if match:
|
| 342 |
return json.loads(match.group())
|
| 343 |
raise ValueError("LLM returned invalid JSON")
|
|
|
|
| 357 |
if (fields := [
|
| 358 |
f for f in meta["fields"]
|
| 359 |
if meta["fields"][f]["module"] == module
|
| 360 |
+
][:20]) # Limit to 20 fields per module for token efficiency
|
| 361 |
])
|
| 362 |
|
| 363 |
prompt = f"""
|
|
|
|
| 393 |
for attempt in range(retries):
|
| 394 |
try:
|
| 395 |
res = client.chat.completions.create(
|
| 396 |
+
model="gpt-4o-mini",
|
| 397 |
messages=[
|
| 398 |
{
|
| 399 |
"role": "system",
|
|
|
|
| 404 |
temperature=0
|
| 405 |
)
|
| 406 |
|
| 407 |
+
content = res.choices[0].message.content.strip()
|
| 408 |
plan = safe_json_loads(content)
|
| 409 |
|
| 410 |
# ✅ NORMALIZE + STABILIZE INTENT SHAPE
|
|
|
|
| 416 |
|
| 417 |
return plan
|
| 418 |
|
| 419 |
+
except Exception as e:
|
| 420 |
if attempt == retries - 1:
|
| 421 |
+
raise ValueError(f"LLM failed to return valid JSON: {str(e)}")
|
| 422 |
|
| 423 |
# =========================
|
| 424 |
+
# SQL GENERATOR (FIXED)
|
| 425 |
# =========================
|
| 426 |
|
| 427 |
def build_sql(plan):
|
|
|
|
| 436 |
base_table = meta["modules"][module]["base_table"]
|
| 437 |
|
| 438 |
joins = []
|
| 439 |
+
joined_tables = {base_table} # β
Track all joined tables
|
| 440 |
where_clauses = []
|
| 441 |
|
| 442 |
# ---------- SELECT ----------
|
|
|
|
| 447 |
for f in select_fields:
|
| 448 |
field = resolve_field(f, module)
|
| 449 |
select_columns.append(
|
| 450 |
+
f"{field['table']}.{field['column']} AS {f}"
|
| 451 |
)
|
| 452 |
select_sql = ", ".join(select_columns)
|
| 453 |
else:
|
|
|
|
| 459 |
|
| 460 |
table = field["table"]
|
| 461 |
column = field["column"]
|
| 462 |
+
field_type = field.get("type") # β
Get field type
|
| 463 |
|
| 464 |
+
# Add join if needed
|
| 465 |
if table != base_table and table not in joined_tables:
|
| 466 |
+
join_path = resolve_join_path(base_table, table)
|
| 467 |
+
joins.append(build_join_sql(base_table, join_path))
|
| 468 |
+
|
| 469 |
+
# ✅ Track all tables in join path
|
| 470 |
+
for step in join_path["steps"]:
|
| 471 |
+
joined_tables.add(step["table"])
|
| 472 |
+
|
| 473 |
+
# ✅ Pass field_type to resolve_operator
|
| 474 |
+
sql_op, sql_value = resolve_operator(f["operator"], f["value"], field_type)
|
| 475 |
+
|
| 476 |
+
if sql_value: # Has value
|
| 477 |
+
where_clauses.append(f"{table}.{column} {sql_op} {sql_value}")
|
| 478 |
+
else: # IS NULL / IS NOT NULL
|
| 479 |
+
where_clauses.append(f"{table}.{column} {sql_op}")
|
| 480 |
|
| 481 |
# 🔴 FIX: safe WHERE clause
|
| 482 |
where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
|
| 483 |
|
| 484 |
# ---------- FINAL SQL ----------
|
| 485 |
+
sql_parts = [
|
| 486 |
+
f"SELECT {select_sql}",
|
| 487 |
+
f"FROM {base_table}"
|
| 488 |
+
]
|
| 489 |
+
|
| 490 |
+
if joins:
|
| 491 |
+
sql_parts.extend(joins)
|
| 492 |
+
|
| 493 |
+
if where_sql:
|
| 494 |
+
sql_parts.append(where_sql)
|
| 495 |
+
|
| 496 |
+
sql_parts.append("LIMIT 100")
|
| 497 |
+
|
| 498 |
+
sql = "\n".join(sql_parts)
|
| 499 |
|
| 500 |
return sql.strip()
|
| 501 |
|
| 502 |
|
|
|
|
|
|
|
| 503 |
# =========================
|
| 504 |
# VALIDATION
|
| 505 |
# =========================
|
| 506 |
|
| 507 |
def validate_sql(sql):
    """Reject anything that is not a plain read-only SELECT statement.

    The check is textual and case-insensitive: the statement must begin
    (after optional whitespace) with the keyword ``select`` and must not
    contain any data- or schema-modifying keyword as a whole word.

    Args:
        sql: SQL text to check.

    Returns:
        ``sql`` unchanged when it passes both checks.

    Raises:
        ValueError: ``"Only SELECT allowed"`` when the statement does not
            start with SELECT, or ``"Unsafe SQL: '<keyword>' not allowed"``
            for the first forbidden keyword (in list order) that occurs.
    """
    sql_lower = sql.lower()

    # Require a word boundary after "select" so identifiers that merely start
    # with those letters (e.g. "selection", "select_x") are not accepted.
    if not re.match(r"\s*select\b", sql_lower):
        raise ValueError("Only SELECT allowed")

    forbidden = ("drop", "delete", "update", "insert", "truncate", "alter", "create")
    for keyword in forbidden:
        # Whole-word match: column names such as "updated_at" or "created"
        # must not trip the check.  The scan is purely textual, so a keyword
        # inside a quoted string literal is rejected as well.
        if re.search(rf"\b{keyword}\b", sql_lower):
            raise ValueError(f"Unsafe SQL: '{keyword}' not allowed")

    return sql
|
| 519 |
|
|
|
|
| 542 |
"query_plan": plan,
|
| 543 |
"sql": sql
|
| 544 |
}
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
# =========================
|
| 548 |
+
# TEST
|
| 549 |
+
# =========================
|
| 550 |
+
|
| 551 |
+
if __name__ == "__main__":
    # Manual smoke test: push a few sample questions through run() and print
    # the SQL produced for each one.
    test_queries = [
        "Show all employees",
        "Find departments with more than 50 employees",
        "Show employees in departments 1, 2, 3",
        "List employees who joined this month",
    ]

    separator = "=" * 80

    for q in test_queries:
        print(f"\n{separator}")
        print(f"Q: {q}")
        print(separator)
        try:
            result = run(q)
            print("SQL:", result["sql"])
        except Exception as e:
            # Top-level boundary of a demo script: report the failure and
            # carry on with the remaining questions.
            print("ERROR:", e)
|