bhavika24 commited on
Commit
a3d2949
Β·
verified Β·
1 Parent(s): 8c6bb96

Upload engine.py

Browse files
Files changed (1) hide show
  1. engine.py +268 -68
engine.py CHANGED
@@ -2,7 +2,7 @@ import json
2
  import os
3
  from functools import lru_cache
4
  from openai import OpenAI
5
- from datetime import datetime
6
  import re
7
 
8
  # =========================
@@ -37,10 +37,19 @@ def load_metadata():
37
  }
38
 
39
 
40
- def resolve_operator(op, value):
41
- # πŸ”΄ Normalize operator input
42
- op = op.lower().strip()
 
 
 
 
 
 
 
 
43
 
 
44
  OPERATOR_ALIASES = {
45
  "=": "equals",
46
  "==": "equals",
@@ -52,11 +61,24 @@ def resolve_operator(op, value):
52
  ">=": "greater_or_equal",
53
  "<=": "less_or_equal",
54
  "greater than": "greater_than",
55
- "less than": "less_than"
 
 
 
 
 
 
 
 
 
 
 
 
56
  }
57
 
58
  op = OPERATOR_ALIASES.get(op, op)
59
 
 
60
  mapping = {
61
  "equals": "=",
62
  "not_equals": "!=",
@@ -65,10 +87,32 @@ def resolve_operator(op, value):
65
  "greater_or_equal": ">=",
66
  "less_or_equal": "<=",
67
  "contains": "LIKE",
 
68
  "starts_with": "LIKE",
69
  "ends_with": "LIKE",
70
  "in": "IN",
71
- "not_in": "NOT IN"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  }
73
 
74
  if op not in mapping:
@@ -76,50 +120,179 @@ def resolve_operator(op, value):
76
 
77
  sql_op = mapping[op]
78
 
79
- # πŸ”΄ Escape string values safely
 
 
 
 
 
80
  def sql_escape(val):
 
 
81
  return str(val).replace("'", "''")
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  if op == "contains":
84
  return sql_op, f"'%{sql_escape(value)}%'"
85
 
 
 
 
86
  if op == "starts_with":
87
  return sql_op, f"'{sql_escape(value)}%'"
88
 
89
  if op == "ends_with":
90
  return sql_op, f"'%{sql_escape(value)}'"
91
 
 
 
 
 
 
 
 
 
 
 
 
92
  if op in ("in", "not_in"):
93
  if not isinstance(value, list):
94
- raise ValueError("IN operator requires list")
95
- escaped = [f"'{sql_escape(v)}'" for v in value]
96
- return sql_op, f"({','.join(escaped)})"
97
-
98
- return sql_op, f"'{sql_escape(value)}'"
99
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
 
102
  # =========================
103
- # JOIN RESOLUTION
104
  # =========================
105
 
106
  def resolve_join_path(start_table, end_table):
 
 
 
 
107
  joins = load_metadata()["joins"]
108
-
109
- for path in joins.values():
 
 
 
 
 
 
110
  if path["start_table"] == start_table and path["end_table"] == end_table:
111
- return path["steps"]
112
 
113
  raise ValueError(
114
  f"No join path found from {start_table} to {end_table}"
115
  )
116
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  FIELD_ALIASES = {
119
- "join_date": "start_date",
120
- "joining_date": "start_date",
121
- "joined": "start_date",
122
- "hire_date": "start_date"
 
 
 
123
  }
124
 
125
  def resolve_field(field_name, module):
@@ -127,7 +300,7 @@ def resolve_field(field_name, module):
127
  fields = meta["fields"]
128
 
129
  # πŸ”Ή Normalize field name
130
- field_name = field_name.lower().strip()
131
  field_name = FIELD_ALIASES.get(field_name, field_name)
132
 
133
  # πŸ”Ή Validate existence
@@ -151,21 +324,6 @@ def resolve_field(field_name, module):
151
  return field
152
 
153
 
154
-
155
- def build_join_sql(base_table, steps):
156
- sql = []
157
- prev_alias = base_table # alias == table name
158
-
159
- for step in steps:
160
- alias = step["alias"]
161
- sql.append(
162
- f"{step['join_type'].upper()} JOIN {step['table']} {alias} "
163
- f"ON {prev_alias}.{step['base_column']} = {alias}.{step['foreign_column']}"
164
- )
165
- prev_alias = alias
166
-
167
- return "\n".join(sql)
168
-
169
  # =========================
170
  # JSON SAFETY
171
  # =========================
@@ -174,7 +332,12 @@ def safe_json_loads(text):
174
  try:
175
  return json.loads(text)
176
  except json.JSONDecodeError:
177
- match = re.search(r"\{.*\}", text, re.S)
 
 
 
 
 
178
  if match:
179
  return json.loads(match.group())
180
  raise ValueError("LLM returned invalid JSON")
@@ -194,7 +357,7 @@ def parse_intent(question, retries=2):
194
  if (fields := [
195
  f for f in meta["fields"]
196
  if meta["fields"][f]["module"] == module
197
- ])
198
  ])
199
 
200
  prompt = f"""
@@ -230,7 +393,7 @@ User question:
230
  for attempt in range(retries):
231
  try:
232
  res = client.chat.completions.create(
233
- model="gpt-4.1-mini",
234
  messages=[
235
  {
236
  "role": "system",
@@ -241,7 +404,7 @@ User question:
241
  temperature=0
242
  )
243
 
244
- content = res.choices[0].message.content
245
  plan = safe_json_loads(content)
246
 
247
  # βœ… NORMALIZE + STABILIZE INTENT SHAPE
@@ -253,12 +416,12 @@ User question:
253
 
254
  return plan
255
 
256
- except Exception:
257
  if attempt == retries - 1:
258
- raise ValueError("LLM failed to return valid JSON")
259
 
260
  # =========================
261
- # SQL GENERATOR
262
  # =========================
263
 
264
  def build_sql(plan):
@@ -273,7 +436,7 @@ def build_sql(plan):
273
  base_table = meta["modules"][module]["base_table"]
274
 
275
  joins = []
276
- joined_tables = set()
277
  where_clauses = []
278
 
279
  # ---------- SELECT ----------
@@ -284,7 +447,7 @@ def build_sql(plan):
284
  for f in select_fields:
285
  field = resolve_field(f, module)
286
  select_columns.append(
287
- f"{field['table']}.{field['column']}"
288
  )
289
  select_sql = ", ".join(select_columns)
290
  else:
@@ -296,47 +459,61 @@ def build_sql(plan):
296
 
297
  table = field["table"]
298
  column = field["column"]
 
299
 
 
300
  if table != base_table and table not in joined_tables:
301
- join_steps = resolve_join_path(base_table, table)
302
- joins.append(build_join_sql(base_table, join_steps))
303
- joined_tables.add(table)
304
-
305
- sql_op, sql_value = resolve_operator(f["operator"], f["value"])
306
- where_clauses.append(
307
- f"{table}.{column} {sql_op} {sql_value}"
308
- )
 
 
 
 
 
 
309
 
310
  # πŸ”΄ FIX: safe WHERE clause
311
  where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
312
 
313
  # ---------- FINAL SQL ----------
314
- sql = f"""
315
- SELECT {select_sql}
316
- FROM {base_table}
317
- {' '.join(joins)}
318
- {where_sql}
319
- LIMIT 100
320
- """
 
 
 
 
 
 
 
321
 
322
  return sql.strip()
323
 
324
 
325
-
326
-
327
  # =========================
328
  # VALIDATION
329
  # =========================
330
 
331
  def validate_sql(sql):
332
- sql = sql.lower()
333
 
334
- if not sql.startswith("select"):
335
  raise ValueError("Only SELECT allowed")
336
 
337
- forbidden = ["drop", "delete", "update", "insert", "truncate"]
338
- if any(x in sql for x in forbidden):
339
- raise ValueError("Unsafe SQL")
 
340
 
341
  return sql
342
 
@@ -365,3 +542,26 @@ def run(question):
365
  "query_plan": plan,
366
  "sql": sql
367
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os
3
  from functools import lru_cache
4
  from openai import OpenAI
5
+ from datetime import datetime, date, timedelta
6
  import re
7
 
8
  # =========================
 
37
  }
38
 
39
 
40
+ # =========================
41
+ # OPERATOR RESOLUTION (COMPLETE FIXED VERSION)
42
+ # =========================
43
+
44
+ def resolve_operator(op, value, field_type=None):
45
+ """
46
+ Resolve operator and format value based on data type
47
+ FIXED: Properly handles numeric types without quotes
48
+ """
49
+ # Normalize operator input
50
+ op = op.lower().strip().replace(" ", "_")
51
 
52
+ # Extended operator aliases for all your operators
53
  OPERATOR_ALIASES = {
54
  "=": "equals",
55
  "==": "equals",
 
61
  ">=": "greater_or_equal",
62
  "<=": "less_or_equal",
63
  "greater than": "greater_than",
64
+ "less than": "less_than",
65
+ "greaterthan": "greater_than",
66
+ "lessthan": "less_than",
67
+ "greaterthanorequal": "greater_or_equal",
68
+ "lessthanorequal": "less_or_equal",
69
+ "does_not_contain": "not_contains",
70
+ "is_blank": "is_empty",
71
+ "is_not_blank": "is_not_empty",
72
+ "on": "equals",
73
+ "date_equals": "equals",
74
+ "date_between": "between",
75
+ "startswith": "starts_with",
76
+ "endswith": "ends_with"
77
  }
78
 
79
  op = OPERATOR_ALIASES.get(op, op)
80
 
81
+ # SQL operator mapping
82
  mapping = {
83
  "equals": "=",
84
  "not_equals": "!=",
 
87
  "greater_or_equal": ">=",
88
  "less_or_equal": "<=",
89
  "contains": "LIKE",
90
+ "not_contains": "NOT LIKE",
91
  "starts_with": "LIKE",
92
  "ends_with": "LIKE",
93
  "in": "IN",
94
+ "not_in": "NOT IN",
95
+ "is_empty": "IS NULL",
96
+ "is_not_empty": "IS NOT NULL",
97
+ "between": "BETWEEN",
98
+ "not_between": "NOT BETWEEN",
99
+ "before": "<",
100
+ "after": ">",
101
+ # Date relative operators
102
+ "today": "=",
103
+ "yesterday": "=",
104
+ "tomorrow": "=",
105
+ "this_week": "BETWEEN",
106
+ "last_week": "BETWEEN",
107
+ "next_week": "BETWEEN",
108
+ "this_month": "BETWEEN",
109
+ "last_month": "BETWEEN",
110
+ "next_month": "BETWEEN",
111
+ "this_quarter": "BETWEEN",
112
+ "last_quarter": "BETWEEN",
113
+ "next_quarter": "BETWEEN",
114
+ "this_year": "BETWEEN",
115
+ "last_year": "BETWEEN"
116
  }
117
 
118
  if op not in mapping:
 
120
 
121
  sql_op = mapping[op]
122
 
123
+ # βœ… Determine if field is numeric
124
+ is_numeric = field_type in ['integer', 'decimal', 'float', 'number', 'int', 'bigint']
125
+ is_date = field_type in ['date', 'datetime', 'timestamp']
126
+ is_boolean = field_type in ['boolean', 'bool']
127
+
128
+ # Escape string values safely
129
  def sql_escape(val):
130
+ if val is None:
131
+ return 'NULL'
132
  return str(val).replace("'", "''")
133
 
134
+ # Handle NULL operators
135
+ if op in ("is_empty", "is_not_empty"):
136
+ return sql_op, ""
137
+
138
+ # Handle date relative operators
139
+ if op in ("today", "yesterday", "tomorrow", "this_week", "last_week", "next_week",
140
+ "this_month", "last_month", "next_month", "this_quarter", "last_quarter",
141
+ "next_quarter", "this_year", "last_year"):
142
+ today = date.today()
143
+
144
+ if op == "today":
145
+ return "=", f"'{today}'"
146
+ elif op == "yesterday":
147
+ return "=", f"'{today - timedelta(days=1)}'"
148
+ elif op == "tomorrow":
149
+ return "=", f"'{today + timedelta(days=1)}'"
150
+ elif op == "this_week":
151
+ start = today - timedelta(days=today.weekday())
152
+ end = start + timedelta(days=6)
153
+ return "BETWEEN", f"'{start}' AND '{end}'"
154
+ elif op == "this_month":
155
+ start = today.replace(day=1)
156
+ if today.month == 12:
157
+ end = today.replace(day=31)
158
+ else:
159
+ end = (today.replace(month=today.month+1, day=1) - timedelta(days=1))
160
+ return "BETWEEN", f"'{start}' AND '{end}'"
161
+ elif op == "this_year":
162
+ start = today.replace(month=1, day=1)
163
+ end = today.replace(month=12, day=31)
164
+ return "BETWEEN", f"'{start}' AND '{end}'"
165
+ # Add more as needed
166
+
167
+ # Handle LIKE operators
168
  if op == "contains":
169
  return sql_op, f"'%{sql_escape(value)}%'"
170
 
171
+ if op == "not_contains":
172
+ return sql_op, f"'%{sql_escape(value)}%'"
173
+
174
  if op == "starts_with":
175
  return sql_op, f"'{sql_escape(value)}%'"
176
 
177
  if op == "ends_with":
178
  return sql_op, f"'%{sql_escape(value)}'"
179
 
180
+ # Handle BETWEEN operator
181
+ if op in ("between", "not_between"):
182
+ if not isinstance(value, (list, tuple)) or len(value) != 2:
183
+ raise ValueError("BETWEEN operator requires array of 2 values")
184
+
185
+ if is_numeric:
186
+ return sql_op, f"{value[0]} AND {value[1]}"
187
+ else:
188
+ return sql_op, f"'{sql_escape(value[0])}' AND '{sql_escape(value[1])}'"
189
+
190
+ # βœ… Handle IN operators with type checking
191
  if op in ("in", "not_in"):
192
  if not isinstance(value, list):
193
+ value = [value]
194
+
195
+ if is_numeric:
196
+ escaped = [str(v) for v in value] # βœ… No quotes for numbers
197
+ else:
198
+ escaped = [f"'{sql_escape(v)}'" for v in value]
199
+
200
+ return sql_op, f"({', '.join(escaped)})"
201
+
202
+ # βœ… Handle regular comparison operators with type awareness
203
+ if is_numeric:
204
+ return sql_op, str(value) # βœ… No quotes for numbers
205
+ elif is_boolean:
206
+ if isinstance(value, bool):
207
+ return sql_op, "1" if value else "0"
208
+ return sql_op, str(value)
209
+ elif is_date:
210
+ return sql_op, f"'{sql_escape(value)}'"
211
+ else:
212
+ return sql_op, f"'{sql_escape(value)}'"
213
 
214
 
215
  # =========================
216
+ # JOIN RESOLUTION (FIXED)
217
  # =========================
218
 
219
  def resolve_join_path(start_table, end_table):
220
+ """
221
+ Find join path between two tables
222
+ FIXED: Handles your join_graph.json structure
223
+ """
224
  joins = load_metadata()["joins"]
225
+
226
+ # Try direct lookup with double underscore
227
+ key = f"{start_table}__{end_table}"
228
+ if key in joins:
229
+ return joins[key]
230
+
231
+ # Try searching by start and end table
232
+ for path_key, path in joins.items():
233
  if path["start_table"] == start_table and path["end_table"] == end_table:
234
+ return path
235
 
236
  raise ValueError(
237
  f"No join path found from {start_table} to {end_table}"
238
  )
239
 
240
 
241
+ def build_join_sql(base_table, join_path):
242
+ """
243
+ Build JOIN SQL from join path
244
+ FIXED: Properly handles multi-step joins with from_previous_step flag
245
+ """
246
+ steps = join_path["steps"]
247
+ sql = []
248
+
249
+ # Sort steps by step number
250
+ sorted_steps = sorted(steps, key=lambda x: x.get("step", 0))
251
+
252
+ for i, step in enumerate(sorted_steps):
253
+ alias = step["alias"]
254
+ table = step["table"]
255
+ join_type = step["join_type"].upper()
256
+
257
+ # βœ… Determine the left side of the join
258
+ if i == 0:
259
+ # First join always references base table
260
+ left_ref = base_table
261
+ else:
262
+ # Subsequent joins: check from_previous_step flag
263
+ if step.get("from_previous_step", False):
264
+ left_ref = sorted_steps[i-1]["alias"] # βœ… Use previous alias
265
+ else:
266
+ left_ref = base_table
267
+
268
+ # Build basic join condition
269
+ join_condition = f"{left_ref}.{step['base_column']} = {alias}.{step['foreign_column']}"
270
+
271
+ # βœ… Add extra conditions if present
272
+ if "extra_conditions" in step and step["extra_conditions"]:
273
+ for extra in step["extra_conditions"]:
274
+ condition = f"{alias}.{extra['column']} {extra['operator']} {extra['value']}"
275
+ join_condition += f" AND {condition}"
276
+
277
+ sql.append(
278
+ f"{join_type} JOIN {table} {alias} ON {join_condition}"
279
+ )
280
+
281
+ return "\n".join(sql)
282
+
283
+
284
+ # =========================
285
+ # FIELD RESOLUTION
286
+ # =========================
287
+
288
  FIELD_ALIASES = {
289
+ "join_date": "date_of_joining",
290
+ "joining_date": "date_of_joining",
291
+ "joined": "date_of_joining",
292
+ "hire_date": "date_of_joining",
293
+ "emp_code": "employee_code",
294
+ "emp_name": "full_name",
295
+ "dept": "department"
296
  }
297
 
298
  def resolve_field(field_name, module):
 
300
  fields = meta["fields"]
301
 
302
  # πŸ”Ή Normalize field name
303
+ field_name = field_name.lower().strip().replace(" ", "_")
304
  field_name = FIELD_ALIASES.get(field_name, field_name)
305
 
306
  # πŸ”Ή Validate existence
 
324
  return field
325
 
326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  # =========================
328
  # JSON SAFETY
329
  # =========================
 
332
  try:
333
  return json.loads(text)
334
  except json.JSONDecodeError:
335
+ # Try to extract JSON from markdown
336
+ match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
337
+ if match:
338
+ return json.loads(match.group(1))
339
+
340
+ match = re.search(r"\{.*\}", text, re.DOTALL)
341
  if match:
342
  return json.loads(match.group())
343
  raise ValueError("LLM returned invalid JSON")
 
357
  if (fields := [
358
  f for f in meta["fields"]
359
  if meta["fields"][f]["module"] == module
360
+ ][:20]) # Limit to 20 fields per module for token efficiency
361
  ])
362
 
363
  prompt = f"""
 
393
  for attempt in range(retries):
394
  try:
395
  res = client.chat.completions.create(
396
+ model="gpt-4o-mini",
397
  messages=[
398
  {
399
  "role": "system",
 
404
  temperature=0
405
  )
406
 
407
+ content = res.choices[0].message.content.strip()
408
  plan = safe_json_loads(content)
409
 
410
  # βœ… NORMALIZE + STABILIZE INTENT SHAPE
 
416
 
417
  return plan
418
 
419
+ except Exception as e:
420
  if attempt == retries - 1:
421
+ raise ValueError(f"LLM failed to return valid JSON: {str(e)}")
422
 
423
  # =========================
424
+ # SQL GENERATOR (FIXED)
425
  # =========================
426
 
427
  def build_sql(plan):
 
436
  base_table = meta["modules"][module]["base_table"]
437
 
438
  joins = []
439
+ joined_tables = {base_table} # βœ… Track all joined tables
440
  where_clauses = []
441
 
442
  # ---------- SELECT ----------
 
447
  for f in select_fields:
448
  field = resolve_field(f, module)
449
  select_columns.append(
450
+ f"{field['table']}.{field['column']} AS {f}"
451
  )
452
  select_sql = ", ".join(select_columns)
453
  else:
 
459
 
460
  table = field["table"]
461
  column = field["column"]
462
+ field_type = field.get("type") # βœ… Get field type
463
 
464
+ # Add join if needed
465
  if table != base_table and table not in joined_tables:
466
+ join_path = resolve_join_path(base_table, table)
467
+ joins.append(build_join_sql(base_table, join_path))
468
+
469
+ # βœ… Track all tables in join path
470
+ for step in join_path["steps"]:
471
+ joined_tables.add(step["table"])
472
+
473
+ # βœ… Pass field_type to resolve_operator
474
+ sql_op, sql_value = resolve_operator(f["operator"], f["value"], field_type)
475
+
476
+ if sql_value: # Has value
477
+ where_clauses.append(f"{table}.{column} {sql_op} {sql_value}")
478
+ else: # IS NULL / IS NOT NULL
479
+ where_clauses.append(f"{table}.{column} {sql_op}")
480
 
481
  # πŸ”΄ FIX: safe WHERE clause
482
  where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
483
 
484
  # ---------- FINAL SQL ----------
485
+ sql_parts = [
486
+ f"SELECT {select_sql}",
487
+ f"FROM {base_table}"
488
+ ]
489
+
490
+ if joins:
491
+ sql_parts.extend(joins)
492
+
493
+ if where_sql:
494
+ sql_parts.append(where_sql)
495
+
496
+ sql_parts.append("LIMIT 100")
497
+
498
+ sql = "\n".join(sql_parts)
499
 
500
  return sql.strip()
501
 
502
 
 
 
503
  # =========================
504
  # VALIDATION
505
  # =========================
506
 
507
  def validate_sql(sql):
508
+ sql_lower = sql.lower()
509
 
510
+ if not sql_lower.strip().startswith("select"):
511
  raise ValueError("Only SELECT allowed")
512
 
513
+ forbidden = ["drop", "delete", "update", "insert", "truncate", "alter", "create"]
514
+ for keyword in forbidden:
515
+ if re.search(rf'\b{keyword}\b', sql_lower):
516
+ raise ValueError(f"Unsafe SQL: '{keyword}' not allowed")
517
 
518
  return sql
519
 
 
542
  "query_plan": plan,
543
  "sql": sql
544
  }
545
+
546
+
547
+ # =========================
548
+ # TEST
549
+ # =========================
550
+
551
+ if __name__ == "__main__":
552
+ test_queries = [
553
+ "Show all employees",
554
+ "Find departments with more than 50 employees",
555
+ "Show employees in departments 1, 2, 3",
556
+ "List employees who joined this month"
557
+ ]
558
+
559
+ for q in test_queries:
560
+ print(f"\n{'='*80}")
561
+ print(f"Q: {q}")
562
+ print('='*80)
563
+ try:
564
+ result = run(q)
565
+ print("SQL:", result["sql"])
566
+ except Exception as e:
567
+ print("ERROR:", e)