Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| from functools import lru_cache | |
| from openai import OpenAI | |
| from datetime import datetime | |
| # ========================= | |
| # CONFIG | |
| # ========================= | |
# Shared OpenAI client used by parse_intent(); the API key is taken from the
# OPENAI_API_KEY environment variable (None when the variable is unset).
# NOTE(review): the OpenAI SDK is believed to raise at construction time when
# no key is available, which would make importing this module fail — confirm
# the Space has OPENAI_API_KEY configured.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
| # ========================= | |
| # METADATA LOADING | |
| # ========================= | |
# Key in the dict returned by load_metadata() -> JSON file it is loaded from.
_METADATA_FILES = {
    "modules": "modules.json",
    "joins": "join_graph.json",
    "field_types": "field_types.json",
    "fields": "fields.json",
}


@lru_cache(maxsize=1)
def load_metadata():
    """Load the schema metadata JSON files from the working directory.

    Returns:
        dict with keys "modules", "joins", "field_types" and "fields", each
        holding the parsed content of the corresponding JSON file.

    The result is cached after the first successful call because every
    resolver (often once per field) asks for it. The same dict object is
    therefore returned on every call: treat it as read-only, and call
    ``load_metadata.cache_clear()`` to pick up changes to the files.

    Raises:
        FileNotFoundError: a metadata file is missing (nothing is cached).
        json.JSONDecodeError: a metadata file is not valid JSON.
    """
    metadata = {}
    for key, filename in _METADATA_FILES.items():
        # Explicit UTF-8 so parsing does not depend on the platform locale.
        with open(filename, encoding="utf-8") as handle:
            metadata[key] = json.load(handle)
    return metadata
# Plan operator name -> SQL comparison operator.
_SQL_OPERATORS = {
    "equals": "=",
    "not_equals": "!=",
    "greater_than": ">",
    "less_than": "<",
    "greater_or_equal": ">=",
    "less_or_equal": "<=",
    "contains": "LIKE",
    "starts_with": "LIKE",
    "ends_with": "LIKE",
    "in": "IN",
    "not_in": "NOT IN",
}


def _sql_string(value):
    """Render *value* as a single-quoted SQL string literal ('' escapes a quote)."""
    return "'" + str(value).replace("'", "''") + "'"


def _sql_list_item(value):
    """Render one IN-list element: numbers bare, booleans/None as keywords, else quoted."""
    if isinstance(value, bool):
        return "TRUE" if value else "FALSE"
    if isinstance(value, (int, float)):
        return repr(value)
    if value is None:
        return "NULL"
    return _sql_string(value)


def resolve_operator(op, value):
    """Translate a plan operator and value into SQL operator and literal text.

    Args:
        op: one of the keys of ``_SQL_OPERATORS`` ("equals", "contains", "in", ...).
        value: the comparison value. "in"/"not_in" need a non-empty list or tuple.

    Returns:
        ``(sql_operator, sql_literal)``. For example, ``("=", "'abc'")`` or
        ``("LIKE", "'%abc%'")``. Embedded single quotes are doubled, because the
        value ultimately comes from user input.

    Raises:
        ValueError: unsupported operator, or a missing/empty list for IN.

    NOTE: ``%`` and ``_`` inside LIKE values are not escaped and still act as
    wildcards; escaping them needs a dialect-specific ESCAPE clause.
    """
    if op not in _SQL_OPERATORS:
        raise ValueError(f"Unsupported operator: {op}")
    sql_op = _SQL_OPERATORS[op]

    if op == "contains":
        return sql_op, _sql_string(f"%{value}%")
    if op == "starts_with":
        return sql_op, _sql_string(f"{value}%")
    if op == "ends_with":
        return sql_op, _sql_string(f"%{value}")
    if op in ("in", "not_in"):
        if not isinstance(value, (list, tuple)):
            raise ValueError("IN operator requires list")
        if not value:
            # "IN ()" is a syntax error in SQL.
            raise ValueError("IN operator requires a non-empty list")
        return sql_op, "(" + ",".join(_sql_list_item(item) for item in value) + ")"
    # Scalar comparisons keep the original convention of quoting every value.
    return sql_op, _sql_string(value)
| # ========================= | |
| # JOIN RESOLUTION | |
| # ========================= | |
def resolve_join_path(start_table, end_table):
    """Return the join steps stored for the ``start_table -> end_table`` route.

    Scans the "joins" metadata entries in order and returns the ``steps`` of the
    first entry whose ``start_table`` and ``end_table`` both match. Only
    pre-declared direct routes are found; no graph search is performed.

    Raises:
        ValueError: no entry connects the two tables.
    """
    not_found = object()
    join_graph = load_metadata()["joins"]
    candidates = (
        entry["steps"]
        for entry in join_graph.values()
        if entry["start_table"] == start_table and entry["end_table"] == end_table
    )
    steps = next(candidates, not_found)
    if steps is not_found:
        raise ValueError(
            f"No join path found from {start_table} to {end_table}"
        )
    return steps
# Synonyms users (or the LLM) may use for the canonical "start_date" field;
# resolve_field() maps any of these onto "start_date" before lookup.
FIELD_ALIASES = dict.fromkeys(
    ("join_date", "joining_date", "joined", "hire_date"),
    "start_date",
)
def resolve_field(field_name, module):
    """Look up a field's metadata entry and check it is usable for *module*.

    The name is lower-cased, stripped, and passed through ``FIELD_ALIASES``
    before lookup in the "fields" metadata.

    Returns:
        The field's metadata dict, guaranteed to contain "table" and "column".

    Raises:
        ValueError: the field is unknown, belongs to another module, or lacks
            a table/column mapping.
    """
    catalog = load_metadata()["fields"]

    # Canonicalise: case/whitespace-insensitive, then resolve synonyms.
    key = field_name.lower().strip()
    key = FIELD_ALIASES.get(key, key)

    if key not in catalog:
        raise ValueError(f"Unknown field: {key}")
    spec = catalog[key]

    if spec["module"] != module:
        raise ValueError(
            f"Field '{key}' does not belong to module '{module}'"
        )

    has_mapping = all(part in spec for part in ("table", "column"))
    if not has_mapping:
        raise ValueError(
            f"Field '{key}' is missing table/column mapping"
        )
    return spec
def build_join_sql(base_table, steps):
    """Render a chain of join steps as newline-separated JOIN clauses.

    Each step is a dict with "alias", "join_type", "table", "base_column" and
    "foreign_column". The first step joins onto *base_table* (referenced by its
    own name). Each later step joins onto the alias of the step before it.
    An empty *steps* yields "".
    """
    clauses = []
    left_alias = base_table
    for step in steps:
        right_alias = step["alias"]
        join_kind = step["join_type"].upper()
        clauses.append(
            "{} JOIN {} {} ON {}.{} = {}.{}".format(
                join_kind,
                step["table"],
                right_alias,
                left_alias,
                step["base_column"],
                right_alias,
                step["foreign_column"],
            )
        )
        left_alias = right_alias
    return "\n".join(clauses)
| # ========================= | |
| # INTENT PARSING (LLM) | |
| # ========================= | |
def parse_intent(question):
    """Ask the LLM to turn a natural-language question into a query plan.

    The prompt lists every module with its fields (modules without fields are
    omitted). It pins the exact JSON shape that ``build_sql`` consumes::

        {"module": str,
         "filters": [{"field": str, "operator": str, "value": ...}],
         "select": [str, ...]}

    Args:
        question: the user's question in plain language.

    Returns:
        The parsed plan (a dict). Field and operator names are not validated
        here; ``build_sql`` does that.

    Raises:
        ValueError: the model's reply is empty, not valid JSON, or not a JSON
            object.
    """
    meta = load_metadata()

    # One "module: field, field, ..." line per module that has any fields.
    schema_lines = []
    for module in meta["modules"]:
        module_fields = [
            name for name, spec in meta["fields"].items()
            if spec["module"] == module
        ]
        if module_fields:
            schema_lines.append(f"{module}: {', '.join(module_fields)}")
    schema_description = "\n".join(schema_lines)

    # Lets the model resolve relative dates ("joined last month").
    today = datetime.now().date().isoformat()

    # Key and operator names below must match what build_sql/resolve_operator
    # accept; otherwise selections are dropped or operators are rejected.
    prompt = f"""
You are a SQL query planner.
You MUST only use fields listed below.
If a field does not exist, choose the closest valid field.
Do NOT invent column names.
Available schema:
{schema_description}
Today's date is {today}; use it to resolve relative dates such as "last month".
Extract:
- module
- filters (field, operator, value)
- selected fields
Return ONLY a valid JSON object with exactly these keys:
{{"module": "<module name>", "filters": [{{"field": "<field>", "operator": "<operator>", "value": <value>}}], "select": ["<field>", ...]}}
"operator" must be one of: equals, not_equals, greater_than, less_than,
greater_or_equal, less_or_equal, contains, starts_with, ends_with, in, not_in.
For "in" and "not_in" the value must be a JSON list.
Use an empty "select" list to return all columns, and an empty "filters" list
when the question has no conditions.
User question:
{question}
"""

    response = client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        # JSON mode: the reply is a bare JSON object, not prose or a ```json fence.
        response_format={"type": "json_object"},
    )

    content = response.choices[0].message.content
    if not content:
        raise ValueError("LLM returned an empty response")
    try:
        plan = json.loads(content)
    except json.JSONDecodeError as err:
        raise ValueError("LLM returned invalid JSON") from err
    if not isinstance(plan, dict):
        raise ValueError("LLM returned JSON that is not an object")
    return plan
| # ========================= | |
| # SQL GENERATOR | |
| # ========================= | |
def build_sql(plan, *, limit=100):
    """Compile a query plan into a single SELECT statement.

    Args:
        plan: dict as produced by ``parse_intent``. It has:
            - "module" (required).
            - "select" (optional list of field names); when empty or missing,
              every column of the module's base table is selected.
            - "filters" (optional list of {"field", "operator", "value"} dicts,
              combined with AND).
        limit: maximum number of rows to return (default 100).

    Returns:
        The SQL text, one clause per line, beginning with "SELECT". The WHERE
        clause is omitted when there are no filters.

    Raises:
        ValueError: unknown module/field/operator, no join path to a needed
            table, or a non-positive ``limit``.
    """
    if isinstance(limit, bool) or not isinstance(limit, int) or limit <= 0:
        raise ValueError(f"limit must be a positive integer, got {limit!r}")

    meta = load_metadata()
    module = plan["module"]
    if module not in meta["modules"]:
        raise ValueError(f"Unknown module: {module}")
    base_table = meta["modules"][module]["base_table"]

    join_clauses = []
    joined_tables = {base_table}

    def ensure_joined(table):
        """Add the join path base_table -> table unless table is already reachable."""
        if table in joined_tables:
            return
        steps = resolve_join_path(base_table, table)
        join_clauses.append(build_join_sql(base_table, steps))
        joined_tables.add(table)
        # Intermediate hops are in the FROM clause now too; remember them so a
        # later reference to one of them does not join it a second time.
        joined_tables.update(step["table"] for step in steps)

    # ---------- SELECT ----------
    # Selected columns may live on other tables, so they are joined just
    # like filter columns.
    select_fields = plan.get("select") or []
    if select_fields:
        select_columns = []
        for name in select_fields:
            field = resolve_field(name, module)
            ensure_joined(field["table"])
            select_columns.append(f"{field['table']}.{field['column']}")
        select_sql = ", ".join(select_columns)
    else:
        select_sql = f"{base_table}.*"

    # ---------- FILTERS ----------
    where_clauses = []
    for condition in plan.get("filters") or []:
        field = resolve_field(condition["field"], module)
        table = field["table"]
        ensure_joined(table)
        sql_op, sql_value = resolve_operator(condition["operator"], condition["value"])
        where_clauses.append(f"{table}.{field['column']} {sql_op} {sql_value}")

    # ---------- FINAL SQL ----------
    lines = [f"SELECT {select_sql}", f"FROM {base_table}", *join_clauses]
    if where_clauses:
        # A bare "WHERE" with no condition is a syntax error.
        lines.append("WHERE " + " AND ".join(where_clauses))
    lines.append(f"LIMIT {limit}")
    return "\n".join(lines)
| # ========================= | |
| # VALIDATION | |
| # ========================= | |
def validate_sql(sql):
    """Reject anything other than a single read-only SELECT statement.

    This is a coarse keyword guard, not a SQL parser. Before scanning:
      - Single-quoted string literals are blanked out, so user values such as
        'Update team' are not mistaken for statements.
      - Keywords are matched as whole words, so identifiers like
        ``updated_at`` are allowed.
    A semicolon anywhere except at the very end counts as a second statement
    and is rejected.

    Returns:
        *sql* unchanged. The text is not lower-cased, because that would alter
        string literals in the query.

    Raises:
        ValueError: "Only SELECT allowed" or "Unsafe SQL".
    """
    import re

    lowered = sql.lower()
    if not lowered.lstrip().startswith("select"):
        raise ValueError("Only SELECT allowed")

    # '' inside a literal is an escaped quote, so it stays within the match.
    code_only = re.sub(r"'(?:[^']|'')*'", "''", lowered)

    forbidden = (
        r"\b(?:drop|delete|update|insert|truncate|alter|create|"
        r"grant|revoke|merge|attach|detach)\b"
    )
    if re.search(forbidden, code_only):
        raise ValueError("Unsafe SQL")
    if ";" in code_only.rstrip().rstrip(";"):
        raise ValueError("Unsafe SQL")
    return sql
| # ========================= | |
| # MAIN ENTRY POINT | |
| # ========================= | |
def run(question):
    """Answer *question* with a generated, validated SQL query.

    Pipeline: LLM plan (parse_intent) -> SQL text (build_sql) -> safety check
    (validate_sql).

    Returns:
        {"query_plan": <plan dict>, "sql": <validated SQL string>}
    """
    query_plan = parse_intent(question)
    final_sql = validate_sql(build_sql(query_plan))
    return {"query_plan": query_plan, "sql": final_sql}