Spaces:
Sleeping
Sleeping
File size: 5,184 Bytes
from typing import Any, Dict, Optional, Tuple

import sqlglot
from sqlglot import exp
# Registry of SQL-rewriting exercises, keyed by task id (1, 2, 3 in order).
# Every entry holds: name, difficulty, schema_context (DDL shown with the task),
# hint, initial_query (the query to be rewritten) and max_steps (presumably the
# per-episode step budget -- confirm against the environment that reads it).
# NOTE(review): task 2's hint calls the IN subquery "correlated", yet it never
# references the outer row; task 3's hint speaks of "implicit" casts although
# the CAST is written explicitly. These texts are runtime data, kept verbatim.
TASKS = {
    task_id: spec
    for task_id, spec in enumerate(
        [
            dict(
                name="fix-broken-join",
                difficulty="easy",
                schema_context="CREATE TABLE users (id INT, name VARCHAR); CREATE TABLE orders (id INT, user_id INT, amount DECIMAL);",
                hint="The query is trying to join users and orders, but it is missing the ON clause, creating a cross join.",
                initial_query="SELECT users.name, orders.amount FROM users JOIN orders;",
                max_steps=3,
            ),
            dict(
                name="eliminate-n-plus-one",
                difficulty="medium",
                schema_context="CREATE TABLE employees (id INT, dept_id INT, name VARCHAR); CREATE TABLE departments (id INT, name VARCHAR);",
                hint="The query uses a correlated subquery in the WHERE clause. Rewrite it using a JOIN to improve performance.",
                initial_query="SELECT e.name FROM employees e WHERE e.dept_id IN (SELECT d.id FROM departments d WHERE d.name = 'Engineering');",
                max_steps=4,
            ),
            dict(
                name="full-optimization",
                difficulty="hard",
                schema_context="CREATE TABLE sales (id INT, product_id INT, sale_date DATE, amount DECIMAL); CREATE INDEX idx_sales_date ON sales(sale_date);",
                hint="Optimize the query: remove redundant DISTINCT, avoid SELECT *, use index hint if applicable, and fix implicit type casts.",
                initial_query="SELECT DISTINCT * FROM sales s WHERE CAST(s.sale_date AS VARCHAR) = '2023-01-01';",
                max_steps=5,
            ),
        ],
        start=1,
    )
}
def _links_users_to_orders(condition: "exp.Expression") -> bool:
    """Return True if *condition* contains an ``id = user_id`` column equality.

    Column names are compared case-insensitively and table qualifiers or
    aliases are ignored, so ``users.id = orders.user_id`` and
    ``o.user_id = u.id`` both match.
    """
    for eq in condition.find_all(exp.EQ):
        left, right = eq.left, eq.right
        if isinstance(left, exp.Column) and isinstance(right, exp.Column):
            if {left.name.lower(), right.name.lower()} == {"id", "user_id"}:
                return True
    return False


def grade_task_1(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade a rewrite of task 1 (add the missing JOIN condition).

    Args:
        rewritten_query: SQL text submitted for grading; parsed with sqlglot's
            PostgreSQL dialect.

    Returns:
        ``(score, breakdown, feedback)`` where ``score`` is:
          * 1.0 when the first JOIN has an ON clause equating ``id`` with
            ``user_id``;
          * 0.5 when it has an ON clause that never links those keys (e.g.
            ``ON 1 = 1``, which still produces a cross join);
          * 0.0 when the ON clause is missing.
        A query that fails to parse returns 0.0 with a ``parse_error`` key;
        one without any JOIN returns 0.0 with a ``missing_join`` key.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:  # any parser failure is reported back as feedback
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    joins = list(parsed.find_all(exp.Join))
    if not joins:
        return 0.0, {"missing_join": 1.0}, "No JOIN found in the query."

    on_clause = joins[0].args.get("on")
    if not on_clause:
        return 0.0, {"has_on_clause": 0.0}, "The JOIN is still missing an ON clause."

    if _links_users_to_orders(on_clause):
        return 1.0, {"has_on_clause": 1.0}, "Successfully added the ON clause."

    # An ON clause that does not tie the two keys together (ON TRUE, ON 1 = 1,
    # ON users.id = users.id, ...) still behaves like a cross join.
    return (
        0.5,
        {"has_on_clause": 0.5},
        "Added an ON clause, but it does not match users.id to orders.user_id.",
    )
def grade_task_2(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade a rewrite of task 2 (replace the IN-subquery filter with a JOIN).

    Two criteria, 0.5 each:
      * ``removed_correlated_subquery``: no subquery remains anywhere --
        neither a parenthesised subquery, an ``IN (SELECT ...)`` nor an
        ``EXISTS (...)``. Plain value lists such as ``IN ('a', 'b')`` are
        not subqueries and are allowed.
      * ``added_join``: at least one JOIN carries an ON or USING condition.
        A bare JOIN is a cross join and earns nothing.

    NOTE(review): the grader does not verify that the rewrite still filters
    on the 'Engineering' department, so a semantically different query can
    still score 1.0.

    Args:
        rewritten_query: SQL text submitted for grading; parsed with sqlglot's
            PostgreSQL dialect.

    Returns:
        ``(score, breakdown, feedback)``. A query that fails to parse returns
        0.0 with a ``parse_error`` breakdown key.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:  # any parser failure is reported back as feedback
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    score = 0.0
    breakdown: Dict[str, float] = {}
    feedback = []

    # EXISTS (SELECT ...) is not wrapped in a Subquery node, so look for it
    # explicitly; IN only counts when it holds a query, not a value list.
    has_subquery = next(parsed.find_all(exp.Subquery, exp.Exists), None) is not None or any(
        in_node.args.get("query") is not None for in_node in parsed.find_all(exp.In)
    )
    if not has_subquery:
        score += 0.5
        breakdown["removed_correlated_subquery"] = 0.5
        feedback.append("Removed correlated subquery.")
    else:
        breakdown["removed_correlated_subquery"] = 0.0
        feedback.append("Correlated subquery still present.")

    joins = list(parsed.find_all(exp.Join))
    conditional_joins = [j for j in joins if j.args.get("on") or j.args.get("using")]
    if conditional_joins:
        score += 0.5
        breakdown["added_join"] = 0.5
        feedback.append("Added JOIN successfully.")
    elif joins:
        breakdown["added_join"] = 0.0
        feedback.append("JOIN has no ON or USING condition, so it is a cross join.")
    else:
        breakdown["added_join"] = 0.0
        feedback.append("Missing JOIN.")

    return score, breakdown, " ".join(feedback)
def _is_star_projection(projection: "exp.Expression") -> bool:
    """Return True if a SELECT-list item is a bare ``*`` or a qualified ``alias.*``."""
    if isinstance(projection, exp.Star):
        return True
    return isinstance(projection, exp.Column) and isinstance(projection.this, exp.Star)


def grade_task_3(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade a rewrite of task 3 (general optimisation of the sales query).

    Four criteria worth 0.25 each, always reported under the same keys:
      * ``no_distinct``: the top-level query has no DISTINCT.
      * ``no_select_star``: no SELECT list (outer or nested) uses ``*`` or
        ``alias.*``. Aggregates such as ``COUNT(*)`` do not count as SELECT *.
      * ``fixed_cast``: no CAST is applied directly to a ``sale_date`` column.
      * ``has_index_hint``: the raw query text contains "INDEX"
        (case-insensitive).

    Args:
        rewritten_query: SQL text submitted for grading; parsed with sqlglot's
            PostgreSQL dialect.

    Returns:
        ``(score, breakdown, feedback)``. A query that fails to parse returns
        0.0 with a ``parse_error`` breakdown key.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:  # any parser failure is reported back as feedback
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    score = 0.0
    breakdown = {"no_distinct": 0.0, "no_select_star": 0.0, "fixed_cast": 0.0, "has_index_hint": 0.0}
    feedback = []

    if not parsed.args.get("distinct"):
        score += 0.25
        breakdown["no_distinct"] = 0.25
        feedback.append("Removed redundant DISTINCT.")
    else:
        feedback.append("DISTINCT is still present.")

    # Only inspect SELECT lists: a Star also appears inside COUNT(*), which is
    # not the SELECT * anti-pattern this criterion targets.
    uses_select_star = any(
        _is_star_projection(projection)
        for select in parsed.find_all(exp.Select)
        for projection in select.expressions
    )
    if not uses_select_star:
        score += 0.25
        breakdown["no_select_star"] = 0.25
        feedback.append("Replaced SELECT * with explicit columns.")
    else:
        feedback.append("SELECT * is still used.")

    cast_on_date = any(
        isinstance(cast.args.get("this"), exp.Column)
        and cast.args.get("this").name.lower() == "sale_date"
        for cast in parsed.find_all(exp.Cast)
    )
    if not cast_on_date:
        score += 0.25
        breakdown["fixed_cast"] = 0.25
        feedback.append("Fixed implicit type cast on sale_date.")
    else:
        feedback.append("sale_date is still wrapped in a CAST.")

    # NOTE(review): plain substring test on the raw text, so "index" inside an
    # identifier, string literal or comment also earns this credit.
    if "INDEX" in rewritten_query.upper():
        score += 0.25
        breakdown["has_index_hint"] = 0.25
        feedback.append("Added index hint.")
    else:
        feedback.append("No index hint found.")

    return score, breakdown, " ".join(feedback)
def grade_action(task_id: int, rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Route *rewritten_query* to the grader registered for *task_id*.

    Returns the chosen grader's ``(score, breakdown, feedback)`` tuple, or
    ``(0.0, {}, "Unknown task.")`` when *task_id* is not 1, 2 or 3.
    """
    if task_id not in (1, 2, 3):
        return 0.0, {}, "Unknown task."
    dispatch = {1: grade_task_1, 2: grade_task_2, 3: grade_task_3}
    return dispatch[task_id](rewritten_query)
def get_task(task_id: int) -> Optional[Dict[str, Any]]:
    """Return the definition of task *task_id*, or ``None`` if there is none.

    A shallow copy of the ``TASKS`` entry is returned so that callers cannot
    mutate the shared table. Every value stored there is a ``str`` or ``int``,
    so a shallow copy is sufficient.
    """
    task = TASKS.get(task_id)
    return dict(task) if task is not None else None
|