# Text_to_sql / engine.py
# (Hugging Face upload by bhavika24 — file-viewer page chrome removed.)
import json
import os
import re
from datetime import datetime
from functools import lru_cache

from openai import OpenAI
# =========================
# CONFIG
# =========================
# OpenAI client is created at import time. It requires the OPENAI_API_KEY
# environment variable; if unset, api_key is None and API calls will fail
# at request time rather than here.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# =========================
# METADATA LOADING
# =========================
@lru_cache(maxsize=1)
def load_metadata():
    """Load and cache the schema metadata from the four JSON config files.

    Returns:
        dict with keys "modules", "joins", "field_types", "fields" holding
        the parsed contents of modules.json, join_graph.json,
        field_types.json and fields.json respectively.

    Raises:
        FileNotFoundError / json.JSONDecodeError on missing or malformed files.

    Cached for the process lifetime (lru_cache), so the files are read once.
    """
    def _read(path):
        # Explicit encoding so parsing doesn't depend on the platform default.
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    return {
        "modules": _read("modules.json"),
        "joins": _read("join_graph.json"),
        "field_types": _read("field_types.json"),
        "fields": _read("fields.json"),
    }
def resolve_operator(op, value):
    """Translate a symbolic operator into (sql_operator, sql_literal).

    Args:
        op: one of equals/not_equals/greater_than/less_than/greater_or_equal/
            less_or_equal/contains/starts_with/ends_with/in/not_in.
        value: scalar filter value, or a list for in/not_in.

    Returns:
        (sql_op, rendered_value) where rendered_value is a quoted SQL literal
        (with LIKE wildcards applied for the substring operators).

    Raises:
        ValueError: unsupported operator, or non-list value for in/not_in.
    """
    mapping = {
        "equals": "=",
        "not_equals": "!=",
        "greater_than": ">",
        "less_than": "<",
        "greater_or_equal": ">=",
        "less_or_equal": "<=",
        "contains": "LIKE",
        "starts_with": "LIKE",
        "ends_with": "LIKE",
        "in": "IN",
        "not_in": "NOT IN"
    }
    if op not in mapping:
        raise ValueError(f"Unsupported operator: {op}")
    sql_op = mapping[op]

    def _quote(v):
        # SECURITY: values originate from the LLM / end user. Double any
        # embedded single quotes (SQL-standard escaping) so a value cannot
        # break out of the string literal and inject SQL.
        return str(v).replace("'", "''")

    if op == "contains":
        return sql_op, f"'%{_quote(value)}%'"
    if op == "starts_with":
        return sql_op, f"'{_quote(value)}%'"
    if op == "ends_with":
        return sql_op, f"'%{_quote(value)}'"
    if op in ("in", "not_in"):
        if not isinstance(value, list):
            raise ValueError("IN operator requires list")
        # Numbers pass through bare; everything else is quoted and escaped
        # (the old repr() produced Python-style quoting for strings with ').
        items = ",".join(
            str(v) if isinstance(v, (int, float)) else f"'{_quote(v)}'"
            for v in value
        )
        return sql_op, f"({items})"
    return sql_op, f"'{_quote(value)}'"
# =========================
# JOIN RESOLUTION
# =========================
def resolve_join_path(start_table, end_table):
    """Return the configured join steps linking start_table to end_table.

    Scans the join graph loaded from join_graph.json for an entry whose
    endpoints match exactly; raises ValueError when none exists.
    """
    join_graph = load_metadata()["joins"]
    steps = next(
        (
            entry["steps"]
            for entry in join_graph.values()
            if entry["start_table"] == start_table
            and entry["end_table"] == end_table
        ),
        None,
    )
    if steps is None:
        raise ValueError(
            f"No join path found from {start_table} to {end_table}"
        )
    return steps
# Natural-language synonyms normalized to the canonical "start_date" field
# name before lookup in fields.json (applied by resolve_field).
FIELD_ALIASES = {
    "join_date": "start_date",
    "joining_date": "start_date",
    "joined": "start_date",
    "hire_date": "start_date"
}
def resolve_field(field_name, module):
    """Resolve a user-facing field name to its metadata entry.

    Lowercases/trims the name, applies FIELD_ALIASES, then checks that the
    field exists, belongs to `module`, and has a table/column mapping.
    Raises ValueError on any of those failures; returns the field dict.
    """
    fields = load_metadata()["fields"]

    # Normalize, then map known synonyms onto canonical names.
    normalized = field_name.lower().strip()
    normalized = FIELD_ALIASES.get(normalized, normalized)

    try:
        field = fields[normalized]
    except KeyError:
        raise ValueError(f"Unknown field: {normalized}") from None

    if field["module"] != module:
        raise ValueError(
            f"Field '{normalized}' does not belong to module '{module}'"
        )

    # A field entry is unusable without its physical mapping.
    if not {"table", "column"} <= field.keys():
        raise ValueError(
            f"Field '{normalized}' is missing table/column mapping"
        )
    return field
def build_join_sql(base_table, steps):
    """Render a chain of join steps as SQL JOIN clauses, one per line.

    Each step joins step["table"] (aliased as step["alias"]) to the previous
    table/alias in the chain, which starts at base_table.
    """
    def render(prev, step):
        # One "<TYPE> JOIN table alias ON prev.col = alias.col" clause.
        return (
            f"{step['join_type'].upper()} JOIN {step['table']} {step['alias']} "
            f"ON {prev}.{step['base_column']} = {step['alias']}.{step['foreign_column']}"
        )

    # The left-hand side of step i is base_table for i == 0, else the alias
    # introduced by step i-1.
    lefts = [base_table] + [s["alias"] for s in steps]
    return "\n".join(render(prev, step) for prev, step in zip(lefts, steps))
# =========================
# INTENT PARSING (LLM)
# =========================
def parse_intent(question):
    """Ask the LLM to turn a natural-language question into a query plan.

    Builds a schema description from the cached metadata (skipping modules
    with no fields), prompts the model, and parses its JSON reply.

    Returns:
        dict query plan (module, filters, selected fields).

    Raises:
        ValueError: the model response is not valid JSON.
    """
    meta = load_metadata()
    # Build the schema description safely (skip empty modules).
    schema_description = "\n".join([
        f"{module}: {', '.join(fields)}"
        for module in meta["modules"]
        if (fields := [
            f for f in meta["fields"]
            if meta["fields"][f]["module"] == module
        ])
    ])
    prompt = f"""
You are a SQL query planner.
You MUST only use fields listed below.
If a field does not exist, choose the closest valid field.
Do NOT invent column names.
Available schema:
{schema_description}
Extract:
- module
- filters (field, operator, value)
- selected fields
Return ONLY valid JSON.
User question:
{question}
"""
    res = client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0
    )
    raw = res.choices[0].message.content.strip()
    # Models frequently wrap JSON in markdown fences despite the "ONLY valid
    # JSON" instruction; strip a single ``` / ```json wrapper before parsing.
    if raw.startswith("```"):
        raw = raw.strip("`").strip()
        if raw.lower().startswith("json"):
            raw = raw[4:]
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        raise ValueError("LLM returned invalid JSON")
# =========================
# SQL GENERATOR
# =========================
def build_sql(plan):
    """Build a SELECT statement from a query plan.

    Args:
        plan: dict with "module" (required), optional "select" (list of field
              names) and optional "filters" (list of {field, operator, value}).

    Returns:
        SQL string: SELECT ... FROM base [JOINs] [WHERE ...] LIMIT 100.

    Raises:
        ValueError: unknown module, field, operator, or missing join path.
    """
    meta = load_metadata()
    module = plan["module"]
    if module not in meta["modules"]:
        raise ValueError(f"Unknown module: {module}")
    base_table = meta["modules"][module]["base_table"]

    joins = []
    joined_tables = set()
    where_clauses = []

    # ---------- SELECT ----------
    select_fields = plan.get("select", [])
    if select_fields:
        columns = []
        for name in select_fields:
            field = resolve_field(name, module)
            columns.append(f"{field['table']}.{field['column']}")
        select_sql = ", ".join(columns)
    else:
        select_sql = f"{base_table}.*"

    # ---------- FILTERS ----------
    for flt in plan.get("filters", []):
        field = resolve_field(flt["field"], module)
        table = field["table"]
        column = field["column"]
        # Pull in the join chain the first time a non-base table is touched.
        if table != base_table and table not in joined_tables:
            join_steps = resolve_join_path(base_table, table)
            joins.append(build_join_sql(base_table, join_steps))
            joined_tables.add(table)
        sql_op, sql_value = resolve_operator(flt["operator"], flt["value"])
        where_clauses.append(f"{table}.{column} {sql_op} {sql_value}")

    # ---------- FINAL SQL ----------
    parts = [f"SELECT {select_sql}", f"FROM {base_table}"]
    parts.extend(joins)
    # BUG FIX: previously an empty filter list still emitted a bare "WHERE"
    # with nothing after it, producing invalid SQL.
    if where_clauses:
        parts.append("WHERE " + " AND ".join(where_clauses))
    parts.append("LIMIT 100")
    return "\n".join(parts)
# =========================
# VALIDATION
# =========================
def validate_sql(sql):
    """Guard that `sql` is a read-only SELECT statement.

    Args:
        sql: the generated SQL string.

    Returns:
        The SQL unchanged. (Previously the *lowercased* string was returned,
        which mangled identifier and literal casing for downstream use.)

    Raises:
        ValueError: not a SELECT, or contains a write/DDL keyword.
    """
    lowered = sql.strip().lower()
    if not lowered.startswith("select"):
        raise ValueError("Only SELECT allowed")
    # Word-boundary match so identifiers like "updated_at" or "inserted"
    # are not false positives, while real keywords are still caught.
    forbidden = ("drop", "delete", "update", "insert", "truncate")
    if re.search(r"\b(" + "|".join(forbidden) + r")\b", lowered):
        raise ValueError("Unsafe SQL")
    return sql
# =========================
# MAIN ENTRY POINT
# =========================
def run(question):
    """End-to-end pipeline: question -> LLM plan -> SQL -> safety check.

    Returns a dict with the raw query plan and the validated SQL string.
    """
    query_plan = parse_intent(question)
    generated_sql = validate_sql(build_sql(query_plan))
    return {"query_plan": query_plan, "sql": generated_sql}