test-nl-to-sql / nl_to_sql.py
amirwesthoff's picture
Add application file
aa50339
Raw
History Blame Contribute Delete
5.52 kB
import re
import pandas as pd
# Lower-case English month name -> zero-padded two-digit month number,
# kept in calendar order (January first).
MONTHS = {
    month_name: f"{index:02d}"
    for index, month_name in enumerate(
        (
            "january", "february", "march", "april", "may", "june",
            "july", "august", "september", "october", "november", "december",
        ),
        start=1,
    )
}
def normalize(text: str) -> str:
    """Lower-case *text*, trim it, and collapse every whitespace run to one space."""
    # str.split() with no separator drops leading/trailing whitespace and
    # splits on any run of Unicode whitespace, so joining with a single space
    # yields the trimmed, collapsed form.
    words = text.lower().split()
    return " ".join(words)
def extract_top_n(q: str, default: int = 5) -> int:
    """Return N from a "top N" phrase in *q*, or *default* when there is none.

    The leading word boundary stops words that merely end in "top"
    (e.g. "laptop 3", "desktop 10") from being read as a ranking request.

    Args:
        q: Question text, normally already passed through ``normalize``.
        default: Value returned when no "top N" phrase is present.

    Returns:
        The requested row count as an int.
    """
    match = re.search(r"\btop\s+(\d+)", q)
    return int(match.group(1)) if match else default
def month_filter(q: str, default_year: str = "2026") -> str:
    """Build a SQL WHERE clause restricting orders to a month named in *q*.

    Month names are looked up in ``MONTHS`` order, so when several are
    mentioned the earliest calendar month wins. A name only counts when it is
    not embedded in a longer word (so "maybe" is not read as "may"). A
    four-digit year starting with "20" that is not part of a longer number
    selects the year; otherwise *default_year* is used.

    The clause compares the first seven characters of ``o.order_date`` with
    ``'YYYY-MM'``, so order_date is presumably stored as ISO-formatted text —
    confirm against the schema.

    Args:
        q: Question text, normally lower-cased by ``normalize``.
        default_year: Year used when *q* contains no recognizable year.

    Returns:
        ``"WHERE substr(o.order_date, 1, 7) = 'YYYY-MM'"`` or ``""`` when no
        month name is found.
    """
    # Digit lookarounds reject years that are part of longer numbers
    # (e.g. "120245") while still accepting "march2025".
    year_match = re.search(r"(?<!\d)(20\d{2})(?!\d)", q)
    year = year_match.group(1) if year_match else default_year
    for name, number in MONTHS.items():
        # Letter lookarounds (not \b) so "may2025" still counts as May.
        if re.search(rf"(?<![a-z]){name}(?![a-z])", q):
            return f"WHERE substr(o.order_date, 1, 7) = '{year}-{number}'"
    return ""
def category_filter(q: str):
    """Return the title-cased product category mentioned in *q*, or None.

    Categories are tried in a fixed order and the first one that occurs as a
    substring of *q* is returned (e.g. "electronics" -> "Electronics").
    """
    known_categories = ("electronics", "furniture", "accessories", "stationery")
    return next(
        (category.title() for category in known_categories if category in q),
        None,
    )
def parse_question_to_sql(question: str):
    """Translate a natural-language question into SQL using ordered keyword rules.

    The question is normalized with ``normalize`` and tested against the rules
    below; the first rule that matches returns immediately. Rules run from most
    to least specific: the schema question, then aggregation intents, then
    generic listings, and finally a fallback query that returns an apology
    message.

    Args:
        question: Free-text question from the user.

    Returns:
        A tuple ``(sql, notes)``: ``sql`` is a single SQL statement string and
        ``notes`` is a list of strings describing how the question was read.
    """
    q = normalize(question)
    notes = []

    # --- Meta question about the database layout ---------------------------
    if "schema" in q or "tables" in q:
        sql = (
            "SELECT 'customers' AS table_name UNION ALL SELECT 'products' "
            "UNION ALL SELECT 'orders' UNION ALL SELECT 'order_items';"
        )
        notes.append("Interpreted as a schema discovery question.")
        return sql, notes

    # --- Aggregation intents -----------------------------------------------
    # Checked before the listing rules: those key on common words ("show",
    # "list", "customers", "orders") that also appear in analytical questions,
    # e.g. "show revenue by country for all customers" must not become a
    # plain customer listing.
    if "delayed" in q and ("how many" in q or "count" in q):
        sql = "SELECT COUNT(*) AS delayed_orders FROM orders WHERE status = 'delayed';"
        notes.append("Mapped to delayed order count.")
        return sql, notes

    if "revenue by country" in q or ("total revenue" in q and "country" in q):
        sql = """
            SELECT c.country,
                   ROUND(SUM(oi.quantity * oi.unit_price), 2) AS revenue
            FROM order_items oi
            JOIN orders o ON o.id = oi.order_id
            JOIN customers c ON c.id = o.customer_id
            GROUP BY c.country
            ORDER BY revenue DESC;
        """.strip()
        notes.append("Revenue interpreted as SUM(quantity * unit_price).")
        return sql, notes

    if "revenue by month" in q or "sales by month" in q:
        sql = """
            SELECT substr(o.order_date, 1, 7) AS month,
                   ROUND(SUM(oi.quantity * oi.unit_price), 2) AS revenue
            FROM order_items oi
            JOIN orders o ON o.id = oi.order_id
            GROUP BY substr(o.order_date, 1, 7)
            ORDER BY month;
        """.strip()
        notes.append("Grouped by order month.")
        return sql, notes

    if "top" in q and "product" in q and "revenue" in q:
        # n is an int parsed from digits, so interpolating it is injection-safe.
        n = extract_top_n(q, default=5)
        sql = f"""
            SELECT p.name,
                   ROUND(SUM(oi.quantity * oi.unit_price), 2) AS revenue
            FROM order_items oi
            JOIN products p ON p.id = oi.product_id
            GROUP BY p.name
            ORDER BY revenue DESC
            LIMIT {n};
        """.strip()
        notes.append(f"Detected ranking request with top {n}.")
        return sql, notes

    if "average order value by customer" in q or (
        "average" in q and "customer" in q and "order value" in q
    ):
        # Average over per-order totals, not over individual line items.
        sql = """
            WITH order_totals AS (
                SELECT o.id AS order_id,
                       c.name AS customer_name,
                       SUM(oi.quantity * oi.unit_price) AS order_total
                FROM orders o
                JOIN customers c ON c.id = o.customer_id
                JOIN order_items oi ON oi.order_id = o.id
                GROUP BY o.id, c.name
            )
            SELECT customer_name,
                   ROUND(AVG(order_total), 2) AS avg_order_value
            FROM order_totals
            GROUP BY customer_name
            ORDER BY avg_order_value DESC;
        """.strip()
        notes.append("Calculated average from per-order totals.")
        return sql, notes

    # --- Generic listings --------------------------------------------------
    if "show all customers" in q or (
        "customers" in q and ("list" in q or "show" in q or q == "customers")
    ):
        sql = "SELECT * FROM customers ORDER BY id;"
        notes.append("Mapped to a simple customer listing.")
        return sql, notes

    if ("products" in q or "product" in q) and (
        "category" in q or "in the" in q or "list" in q or "show" in q
    ):
        # category_filter only returns names from a fixed list, so the
        # interpolated literal cannot carry user-controlled text.
        category = category_filter(q)
        if category:
            sql = f"SELECT * FROM products WHERE category = '{category}' ORDER BY id;"
            notes.append(f"Detected product category filter: {category}.")
            return sql, notes

    if "orders from" in q or ("orders" in q and any(m in q for m in MONTHS)):
        filt = month_filter(q)
        where = f" {filt}" if filt else ""
        sql = (
            "SELECT o.*, c.name AS customer_name FROM orders o "
            f"JOIN customers c ON c.id = o.customer_id{where} ORDER BY o.order_date;"
        )
        if filt:
            notes.append("Detected an order date filter.")
        else:
            # e.g. "orders from germany": the rule fired but no month was
            # recognized, so do not claim that a date filter was applied.
            notes.append("No month recognized; listing orders without a date filter.")
        return sql, notes

    if "all orders" in q or ("orders" in q and ("list" in q or "show" in q)):
        sql = "SELECT * FROM orders ORDER BY order_date;"
        notes.append("Mapped to a full orders listing.")
        return sql, notes

    # --- Fallback ----------------------------------------------------------
    sql = "SELECT 'Sorry, I do not understand that question yet.' AS message;"
    notes.append("No rule matched; returned a fallback query.")
    return sql, notes
def format_answer(question: str, sql: str, df: pd.DataFrame, notes, max_rows: int = 20):
    """Render the translated SQL, interpretation notes and a result preview as Markdown.

    Args:
        question: The original user question (accepted for interface symmetry;
            it is not included in the output).
        sql: The SQL statement that produced *df*.
        df: Query result to preview.
        notes: Sequence of interpretation notes; a falsy value (None or empty)
            produces a single placeholder bullet.
        max_rows: Maximum number of result rows rendered in the preview.

    Returns:
        A Markdown string containing the SQL in a fenced block, the notes as
        bullets, and a preview of the first *max_rows* rows. The header states
        the total row count and whether the preview was truncated.
    """
    if df.empty:
        preview = "(no rows returned)"
    else:
        head = df.head(max_rows)
        try:
            # DataFrame.to_markdown depends on the optional "tabulate" package
            # and raises ImportError without it; fall back to plain text.
            preview = head.to_markdown(index=False)
        except ImportError:
            preview = head.to_string(index=False)

    explanation = "\n".join(f"- {n}" for n in notes) if notes else "- No special interpretation notes."

    total_rows = len(df)
    # Make truncation visible instead of silently showing a subset.
    truncated = f", showing first {max_rows}" if total_rows > max_rows else ""
    return (
        "I translated your question into this SQL:\n\n"
        f"```sql\n{sql}\n```\n\n"
        f"Interpretation notes:\n{explanation}\n\n"
        f"Result preview ({total_rows} row(s){truncated}):\n\n{preview}"
    )