# Hugging Face Hub page residue (not part of the module):
#   uploader: jaivardhan2409
#   commit message: "Upload folder using huggingface_hub" (e4c32ce, verified)
import re
from typing import Any, Dict, Optional, Tuple

import sqlglot
from sqlglot import exp
# Catalogue of SQL-rewrite tasks, keyed by integer task id.
#
# Every entry carries the same six fields, in this order:
#   name           - short slug identifying the task
#   difficulty     - "easy" / "medium" / "hard"
#   schema_context - DDL for the tables the query runs against
#   hint           - natural-language description of what is wrong
#   initial_query  - the flawed SQL the rewrite starts from
#   max_steps      - integer budget; presumably the number of attempts the
#                    caller allows (not consumed in this file - TODO confirm)
TASKS = {
    1: dict(
        name="fix-broken-join",
        difficulty="easy",
        schema_context="CREATE TABLE users (id INT, name VARCHAR); CREATE TABLE orders (id INT, user_id INT, amount DECIMAL);",
        hint="The query is trying to join users and orders, but it is missing the ON clause, creating a cross join.",
        initial_query="SELECT users.name, orders.amount FROM users JOIN orders;",
        max_steps=3,
    ),
    # NOTE(review): the hint calls the subquery "correlated", but the
    # IN-subquery in initial_query never references the outer alias `e`.
    2: dict(
        name="eliminate-n-plus-one",
        difficulty="medium",
        schema_context="CREATE TABLE employees (id INT, dept_id INT, name VARCHAR); CREATE TABLE departments (id INT, name VARCHAR);",
        hint="The query uses a correlated subquery in the WHERE clause. Rewrite it using a JOIN to improve performance.",
        initial_query="SELECT e.name FROM employees e WHERE e.dept_id IN (SELECT d.id FROM departments d WHERE d.name = 'Engineering');",
        max_steps=4,
    ),
    3: dict(
        name="full-optimization",
        difficulty="hard",
        schema_context="CREATE TABLE sales (id INT, product_id INT, sale_date DATE, amount DECIMAL); CREATE INDEX idx_sales_date ON sales(sale_date);",
        hint="Optimize the query: remove redundant DISTINCT, avoid SELECT *, use index hint if applicable, and fix implicit type casts.",
        initial_query="SELECT DISTINCT * FROM sales s WHERE CAST(s.sale_date AS VARCHAR) = '2023-01-01';",
        max_steps=5,
    ),
}
def grade_task_1(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade task 1 ("fix-broken-join"): the JOIN must get a real ON clause.

    The query is parsed with sqlglot's Postgres dialect. Full credit (1.0)
    requires that *every* JOIN in the query has an ON condition that
    references at least one column; an ON made only of constants
    (``ON TRUE``, ``ON 1 = 1``) still produces a cross join and earns nothing.

    Args:
        rewritten_query: The SQL text submitted for grading.

    Returns:
        ``(score, breakdown, feedback)`` where ``score`` is 0.0 or 1.0,
        ``breakdown`` maps a criterion name to the points it contributed
        (``parse_error`` / ``missing_join`` flag the early-exit cases), and
        ``feedback`` is a human-readable explanation.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:
        # Grader boundary: whatever the parser raises on submitted SQL must
        # become a zero score, never a crash.
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    joins = list(parsed.find_all(exp.Join))
    if not joins:
        return 0.0, {"missing_join": 1.0}, "No JOIN found in the query."

    # Check every join, not just the first one found: a single unconditioned
    # JOIN anywhere in the query is enough to reintroduce the cross join.
    conditions = [join.args.get("on") for join in joins]
    if not all(conditions):
        return 0.0, {"has_on_clause": 0.0}, "The JOIN is still missing an ON clause."

    # An ON clause with no column reference cannot relate the two tables.
    if any(on.find(exp.Column) is None for on in conditions):
        return (
            0.0,
            {"has_on_clause": 0.0},
            "The ON clause does not reference any column, so the JOIN is still a cross join.",
        )

    return 1.0, {"has_on_clause": 1.0}, "Successfully added the ON clause."
def grade_task_2(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade task 2 ("eliminate-n-plus-one"): replace the subquery with a JOIN.

    Two independent criteria are worth 0.5 each:

    * ``removed_correlated_subquery`` - the query contains no nested query
      used as a value or predicate: no parenthesised subquery, no
      ``IN (SELECT ...)``, and no ``EXISTS`` / ``ANY`` / ``ALL`` wrapping a
      SELECT. A plain value list such as ``IN ('Engineering')`` is fine.
    * ``added_join`` - the query contains at least one JOIN.

    Args:
        rewritten_query: The SQL text submitted for grading.

    Returns:
        ``(score, breakdown, feedback)`` with ``score`` in {0.0, 0.5, 1.0};
        a query that cannot be parsed yields ``{"parse_error": 1.0}``.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:
        # Grader boundary: any parser failure is a zero score, not a crash.
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    score = 0.0
    breakdown = {}
    feedback = []

    # IN nodes carry a `query` arg only when their right-hand side is a
    # SELECT; literal lists live under `expressions` and must not be penalised.
    in_with_query = any(node.args.get("query") for node in parsed.find_all(exp.In))
    # EXISTS / ANY / ALL hold their SELECT directly (no Subquery wrapper), so
    # they need their own check; ANY/ALL over a plain array are left alone.
    predicate_with_query = any(
        node.find(exp.Select) is not None
        for node in parsed.find_all(exp.Exists, exp.Any, exp.All)
    )
    has_subquery = (
        parsed.find(exp.Subquery) is not None
        or in_with_query
        or predicate_with_query
    )

    if not has_subquery:
        score += 0.5
        breakdown["removed_correlated_subquery"] = 0.5
        feedback.append("Removed correlated subquery.")
    else:
        breakdown["removed_correlated_subquery"] = 0.0
        feedback.append("Correlated subquery still present.")

    if parsed.find(exp.Join) is not None:
        score += 0.5
        breakdown["added_join"] = 0.5
        feedback.append("Added JOIN successfully.")
    else:
        breakdown["added_join"] = 0.0
        feedback.append("Missing JOIN.")

    return score, breakdown, " ".join(feedback)
def grade_task_3(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade task 3 ("full-optimization") on four criteria worth 0.25 each.

    * ``no_distinct`` - the outer SELECT has no DISTINCT.
    * ``no_select_star`` - no ``*`` appears anywhere in the query.
    * ``fixed_cast`` - the WHERE clause still filters on ``sale_date`` and no
      CAST is applied directly to that column. Dropping the filter altogether
      does not count as a fix.
    * ``has_index_hint`` - the raw query text contains the standalone word
      ``INDEX`` (case-insensitive). This is a text heuristic: identifiers
      such as ``reindexed`` or ``index_col`` do not match, but the word
      inside a comment or string literal still does.

    Args:
        rewritten_query: The SQL text submitted for grading.

    Returns:
        ``(score, breakdown, feedback)``; ``breakdown`` always holds all four
        criteria, and ``feedback`` has one sentence per criterion, met or
        not. A query that cannot be parsed yields ``{"parse_error": 1.0}``.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:
        # Grader boundary: any parser failure is a zero score, not a crash.
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    def is_sale_date(node) -> bool:
        return isinstance(node, exp.Column) and node.name.lower() == "sale_date"

    where = parsed.args.get("where")
    filters_on_date = where is not None and any(
        is_sale_date(column) for column in where.find_all(exp.Column)
    )
    cast_on_date = any(
        is_sale_date(cast.args.get("this")) for cast in parsed.find_all(exp.Cast)
    )
    if cast_on_date:
        cast_failure = "sale_date is still wrapped in a cast."
    else:
        cast_failure = "The filter on sale_date is missing."

    # (criterion, satisfied, message when met, message when not met)
    checks = [
        (
            "no_distinct",
            not parsed.args.get("distinct"),
            "Removed redundant DISTINCT.",
            "DISTINCT is still present.",
        ),
        (
            "no_select_star",
            parsed.find(exp.Star) is None,
            "Replaced SELECT * with explicit columns.",
            "SELECT * is still present.",
        ),
        (
            "fixed_cast",
            filters_on_date and not cast_on_date,
            "Fixed implicit type cast on sale_date.",
            cast_failure,
        ),
        (
            "has_index_hint",
            # Word boundaries keep identifiers that merely contain "index"
            # from earning the credit.
            re.search(r"\bINDEX\b", rewritten_query, re.IGNORECASE) is not None,
            "Added index hint.",
            "No index hint found.",
        ),
    ]

    breakdown = {}
    feedback = []
    for criterion, satisfied, met_message, unmet_message in checks:
        breakdown[criterion] = 0.25 if satisfied else 0.0
        feedback.append(met_message if satisfied else unmet_message)

    score = sum(breakdown.values())
    return score, breakdown, " ".join(feedback)
def grade_action(task_id: int, rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Route a submitted query to the grader for ``task_id``.

    Args:
        task_id: Task identifier; 1, 2 and 3 are the known tasks.
        rewritten_query: The SQL text to grade.

    Returns:
        The selected grader's ``(score, breakdown, feedback)`` tuple, or
        ``(0.0, {}, "Unknown task.")`` when ``task_id`` is not a known task.
    """
    if task_id not in (1, 2, 3):
        return 0.0, {}, "Unknown task."

    graders = {1: grade_task_1, 2: grade_task_2, 3: grade_task_3}
    return graders[task_id](rewritten_query)
def get_task(task_id: int) -> Optional[Dict[str, Any]]:
    """Look up a task definition by id.

    Args:
        task_id: Key into ``TASKS``.

    Returns:
        The entry stored in ``TASKS`` for ``task_id`` (the shared dict
        itself, not a copy), or ``None`` when no task has that id.
    """
    return TASKS.get(task_id)