"""Natural-language → SQL query builder.

Pipeline: an LLM turns a user question into a JSON query plan
(module, filters, selected fields); the plan is validated against
JSON metadata files and rendered into a guarded SELECT statement.
"""

import json
import os
import re
from datetime import datetime
from functools import lru_cache

from openai import OpenAI

# =========================
# CONFIG
# =========================
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


# =========================
# METADATA LOADING
# =========================
@lru_cache(maxsize=1)
def load_metadata():
    """Load and cache the schema metadata from the working directory.

    Returns:
        dict with keys "modules", "joins", "field_types", "fields",
        each holding the parsed contents of the corresponding JSON file.

    Raises:
        FileNotFoundError / json.JSONDecodeError on missing or bad files.
    """
    def _read(path):
        # Explicit encoding so the files parse identically on every platform.
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    return {
        "modules": _read("modules.json"),
        "joins": _read("join_graph.json"),
        "field_types": _read("field_types.json"),
        "fields": _read("fields.json"),
    }


def _sql_quote(value):
    """Render *value* as a SQL string literal, doubling embedded quotes."""
    return "'" + str(value).replace("'", "''") + "'"


def resolve_operator(op, value):
    """Map a planner operator name to an (sql_operator, rendered_value) pair.

    LIKE-style operators wrap the value in the appropriate ``%`` wildcards;
    ``in`` / ``not_in`` require a list value.

    Raises:
        ValueError: for an unknown operator, or a non-list IN value.

    NOTE(security): values are escaped by doubling single quotes
    (BUG FIX — the original interpolated raw values, so an embedded
    quote broke the SQL). For untrusted input a parameterized query
    would still be safer; this builder emits literal SQL by design.
    """
    mapping = {
        "equals": "=",
        "not_equals": "!=",
        "greater_than": ">",
        "less_than": "<",
        "greater_or_equal": ">=",
        "less_or_equal": "<=",
        "contains": "LIKE",
        "starts_with": "LIKE",
        "ends_with": "LIKE",
        "in": "IN",
        "not_in": "NOT IN",
    }
    if op not in mapping:
        raise ValueError(f"Unsupported operator: {op}")
    sql_op = mapping[op]

    if op == "contains":
        return sql_op, _sql_quote(f"%{value}%")
    if op == "starts_with":
        return sql_op, _sql_quote(f"{value}%")
    if op == "ends_with":
        return sql_op, _sql_quote(f"%{value}")
    if op in ("in", "not_in"):
        if not isinstance(value, list):
            raise ValueError("IN operator requires list")
        # BUG FIX: Python repr() is not SQL quoting (it double-quotes
        # strings containing apostrophes). Quote strings as SQL literals;
        # pass numbers through bare.
        rendered = ",".join(
            str(v) if isinstance(v, (int, float)) and not isinstance(v, bool)
            else _sql_quote(v)
            for v in value
        )
        return sql_op, f"({rendered})"
    return sql_op, _sql_quote(value)


# =========================
# JOIN RESOLUTION
# =========================
def resolve_join_path(start_table, end_table):
    """Return the join steps connecting *start_table* to *end_table*.

    Looks up precomputed paths in join_graph.json; the direction matters
    (only start→end paths are matched).

    Raises:
        ValueError: when no path is recorded for the pair.
    """
    joins = load_metadata()["joins"]
    for path in joins.values():
        if path["start_table"] == start_table and path["end_table"] == end_table:
            return path["steps"]
    raise ValueError(
        f"No join path found from {start_table} to {end_table}"
    )


# Common synonyms the LLM may emit, normalized to the canonical field name.
FIELD_ALIASES = {
    "join_date": "start_date",
    "joining_date": "start_date",
    "joined": "start_date",
    "hire_date": "start_date",
}


def resolve_field(field_name, module):
    """Normalize, validate and return the metadata entry for a field.

    Lowercases/strips the name, applies FIELD_ALIASES, then checks that
    the field exists, belongs to *module*, and has a table/column mapping.

    Raises:
        ValueError: on any of the three validation failures.
    """
    fields = load_metadata()["fields"]

    # Normalize field name
    field_name = field_name.lower().strip()
    field_name = FIELD_ALIASES.get(field_name, field_name)

    # Validate existence
    if field_name not in fields:
        raise ValueError(f"Unknown field: {field_name}")
    field = fields[field_name]

    # Validate module
    if field["module"] != module:
        raise ValueError(
            f"Field '{field_name}' does not belong to module '{module}'"
        )

    # Validate mapping
    if "table" not in field or "column" not in field:
        raise ValueError(
            f"Field '{field_name}' is missing table/column mapping"
        )
    return field


def build_join_sql(base_table, steps):
    """Render a chain of join steps as newline-separated JOIN clauses.

    Each step joins the previous alias (initially *base_table*, which is
    used as its own alias) to ``step['table']`` under ``step['alias']``.
    """
    sql = []
    prev_alias = base_table  # alias == table name for the base table
    for step in steps:
        alias = step["alias"]
        sql.append(
            f"{step['join_type'].upper()} JOIN {step['table']} {alias} "
            f"ON {prev_alias}.{step['base_column']} = {alias}.{step['foreign_column']}"
        )
        prev_alias = alias
    return "\n".join(sql)


# =========================
# INTENT PARSING (LLM)
# =========================
def parse_intent(question):
    """Ask the LLM to turn *question* into a JSON query plan.

    The prompt embeds the known schema so the model only uses real fields.

    Returns:
        dict: the parsed plan (module, filters, selected fields).

    Raises:
        ValueError: when the model response is not valid JSON.
    """
    meta = load_metadata()

    # Build schema description safely (skip modules with no fields).
    schema_description = "\n".join([
        f"{module}: {', '.join(fields)}"
        for module in meta["modules"]
        if (fields := [
            f for f in meta["fields"]
            if meta["fields"][f]["module"] == module
        ])
    ])

    prompt = f"""
You are a SQL query planner.

You MUST only use fields listed below.
If a field does not exist, choose the closest valid field.
Do NOT invent column names.

Available schema:
{schema_description}

Extract:
- module
- filters (field, operator, value)
- selected fields

Return ONLY valid JSON.

User question: {question}
"""

    res = client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0
    )

    # BUG FIX: content may be None, and models frequently wrap JSON in
    # markdown code fences despite instructions — strip both before parsing.
    content = (res.choices[0].message.content or "").strip()
    if content.startswith("```"):
        content = re.sub(r"^```[A-Za-z]*\s*|\s*```$", "", content)

    try:
        return json.loads(content)
    except json.JSONDecodeError:
        raise ValueError("LLM returned invalid JSON")


# =========================
# SQL GENERATOR
# =========================
def build_sql(plan):
    """Render a validated query plan into a SELECT statement.

    Resolves every referenced field, adds the join chain for any
    non-base table exactly once, and caps results at 100 rows.

    Raises:
        ValueError: on unknown module/field/operator or missing join path.
    """
    meta = load_metadata()

    module = plan["module"]
    if module not in meta["modules"]:
        raise ValueError(f"Unknown module: {module}")
    base_table = meta["modules"][module]["base_table"]

    joins = []
    joined_tables = set()

    def _ensure_joined(table):
        # Append the join chain from base_table to `table` exactly once.
        if table != base_table and table not in joined_tables:
            steps = resolve_join_path(base_table, table)
            joins.append(build_join_sql(base_table, steps))
            joined_tables.add(table)

    # ---------- SELECT ----------
    select_fields = plan.get("select", [])
    if select_fields:
        select_columns = []
        for name in select_fields:
            field = resolve_field(name, module)
            # BUG FIX: the original never joined tables referenced only in
            # the SELECT list, producing SQL against un-joined tables.
            _ensure_joined(field["table"])
            select_columns.append(f"{field['table']}.{field['column']}")
        select_sql = ", ".join(select_columns)
    else:
        select_sql = f"{base_table}.*"

    # ---------- FILTERS ----------
    where_clauses = []
    for flt in plan.get("filters", []):
        field = resolve_field(flt["field"], module)
        _ensure_joined(field["table"])
        sql_op, sql_value = resolve_operator(flt["operator"], flt["value"])
        where_clauses.append(
            f"{field['table']}.{field['column']} {sql_op} {sql_value}"
        )

    # ---------- FINAL SQL ----------
    parts = [f"SELECT {select_sql}", f"FROM {base_table}"]
    parts.extend(joins)
    # BUG FIX: the original emitted a bare "WHERE" when the plan had no
    # filters, which is invalid SQL.
    if where_clauses:
        parts.append("WHERE " + " AND ".join(where_clauses))
    parts.append("LIMIT 100")
    return "\n".join(parts)


# =========================
# VALIDATION
# =========================
def validate_sql(sql):
    """Reject anything that is not a plain SELECT statement.

    Returns:
        str: *sql* unchanged on success.

    Raises:
        ValueError: if the statement is not a SELECT or contains a
        forbidden keyword.

    BUG FIX: the original lowercased and *returned* the SQL, silently
    corrupting string-literal values (e.g. '%Alice%' → '%alice%'), and
    used substring matching so a column like `last_updated` tripped the
    "update" check. Validation now inspects a lowered copy with
    word-boundary matching and returns the original text.
    """
    lowered = sql.lstrip().lower()
    if not lowered.startswith("select"):
        raise ValueError("Only SELECT allowed")
    forbidden = ("drop", "delete", "update", "insert", "truncate")
    if re.search(r"\b(" + "|".join(forbidden) + r")\b", lowered):
        raise ValueError("Unsafe SQL")
    return sql


# =========================
# MAIN ENTRY POINT
# =========================
def run(question):
    """Answer *question*: plan via LLM, render SQL, validate, and return both."""
    plan = parse_intent(question)
    sql = build_sql(plan)
    sql = validate_sql(sql)
    return {
        "query_plan": plan,
        "sql": sql
    }