"""Task definitions and graders for SQL query-rewriting exercises.

Each task pairs a schema, an ``initial_query`` to fix and a hint.
``grade_action`` scores an agent's rewritten query by parsing it with sqlglot
(Postgres dialect) and checking structural properties of the resulting AST.
"""

import sqlglot
from sqlglot import exp
from typing import Any, Dict, Optional, Tuple

# (score, per-criterion breakdown, human-readable feedback)
GradeResult = Tuple[float, Dict[str, float], str]

TASKS: Dict[int, Dict[str, Any]] = {
    1: {
        "name": "fix-broken-join",
        "difficulty": "easy",
        "schema_context": "CREATE TABLE users (id INT, name VARCHAR); CREATE TABLE orders (id INT, user_id INT, amount DECIMAL);",
        "hint": "The query is trying to join users and orders, but it is missing the ON clause, creating a cross join.",
        "initial_query": "SELECT users.name, orders.amount FROM users JOIN orders;",
        "max_steps": 3,
    },
    2: {
        "name": "eliminate-n-plus-one",
        "difficulty": "medium",
        "schema_context": "CREATE TABLE employees (id INT, dept_id INT, name VARCHAR); CREATE TABLE departments (id INT, name VARCHAR);",
        # The initial query's subquery does not reference the outer row, so it
        # is an IN-subquery filter rather than a correlated subquery.
        "hint": "The query filters with an IN (subquery) in the WHERE clause. Rewrite it using a JOIN to improve performance.",
        "initial_query": "SELECT e.name FROM employees e WHERE e.dept_id IN (SELECT d.id FROM departments d WHERE d.name = 'Engineering');",
        "max_steps": 4,
    },
    3: {
        "name": "full-optimization",
        "difficulty": "hard",
        "schema_context": "CREATE TABLE sales (id INT, product_id INT, sale_date DATE, amount DECIMAL); CREATE INDEX idx_sales_date ON sales(sale_date);",
        "hint": "Optimize the query: remove redundant DISTINCT, avoid SELECT *, use index hint if applicable, and fix implicit type casts.",
        "initial_query": "SELECT DISTINCT * FROM sales s WHERE CAST(s.sale_date AS VARCHAR) = '2023-01-01';",
        "max_steps": 5,
    },
}


def _parse(rewritten_query: str) -> Tuple[Optional[exp.Expression], Optional[GradeResult]]:
    """Parse *rewritten_query* as Postgres SQL.

    Returns ``(tree, None)`` on success, or ``(None, result)`` where *result*
    is the zero-score grade every grader reports for unparseable input.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:  # agent output is untrusted: any parser failure scores 0
        return None, (0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}")
    if parsed is None:
        return None, (0.0, {"parse_error": 1.0}, "Query could not be parsed: empty statement")
    return parsed, None


def _is_star(node: exp.Expression) -> bool:
    """Return True for a ``*`` or ``alias.*`` projection."""
    return isinstance(node, exp.Star) or (
        isinstance(node, exp.Column) and isinstance(node.this, exp.Star)
    )


def _references_column(node: Optional[exp.Expression], name: str) -> bool:
    """Return True if *node*'s subtree contains a column named *name* (case-insensitive)."""
    if node is None:
        return False
    return any(col.name.lower() == name for col in node.find_all(exp.Column))


def _is_column_equality_join(on: Optional[exp.Expression]) -> bool:
    """Return True if the ON condition contains ``a = b`` between two different columns."""
    if on is None:
        return False
    return any(
        isinstance(eq.left, exp.Column)
        and isinstance(eq.right, exp.Column)
        and eq.left != eq.right
        for eq in on.find_all(exp.EQ)
    )


def _has_filter_subquery(parsed: exp.Expression) -> bool:
    """Return True if a subquery is still used as a filtering predicate.

    Covers ``EXISTS (...)``, ``x IN (SELECT ...)`` and any subquery nested in
    a WHERE clause. Plain ``IN (literal, ...)`` lists and derived tables that
    are JOINed in the FROM clause do not count.
    """
    if next(parsed.find_all(exp.Exists), None) is not None:
        return True
    if any(node.args.get("query") is not None for node in parsed.find_all(exp.In)):
        return True
    return any(
        next(where.find_all(exp.Subquery, exp.Select), None) is not None
        for where in parsed.find_all(exp.Where)
    )


def grade_task_1(rewritten_query: str) -> GradeResult:
    """Grade task 1: the users/orders JOIN must have a real join condition.

    Full credit (1.0) requires every JOIN to carry an ON clause that equates
    two different columns; ``ON TRUE`` or ``ON 1 = 1`` is still a cross join.
    """
    parsed, failure = _parse(rewritten_query)
    if failure is not None:
        return failure

    joins = list(parsed.find_all(exp.Join))
    if not joins:
        return 0.0, {"missing_join": 1.0}, "No JOIN found in the query."

    if all(_is_column_equality_join(join.args.get("on")) for join in joins):
        return 1.0, {"has_on_clause": 1.0}, "Successfully added the ON clause."
    if all(join.args.get("on") is not None for join in joins):
        # An ON clause exists but does not relate the two tables.
        return (
            0.0,
            {"has_on_clause": 0.0},
            "The ON clause does not equate columns from the joined tables.",
        )
    return 0.0, {"has_on_clause": 0.0}, "The JOIN is still missing an ON clause."


def grade_task_2(rewritten_query: str) -> GradeResult:
    """Grade task 2: replace the IN-subquery filter with a JOIN.

    0.5 for no longer filtering through a subquery, 0.5 for using a JOIN.
    """
    parsed, failure = _parse(rewritten_query)
    if failure is not None:
        return failure

    score = 0.0
    breakdown: Dict[str, float] = {}
    feedback = []

    # Breakdown key kept for compatibility with existing consumers.
    if _has_filter_subquery(parsed):
        breakdown["removed_correlated_subquery"] = 0.0
        feedback.append("Subquery still present in the filter.")
    else:
        score += 0.5
        breakdown["removed_correlated_subquery"] = 0.5
        feedback.append("Removed the filtering subquery.")

    if next(parsed.find_all(exp.Join), None) is not None:
        score += 0.5
        breakdown["added_join"] = 0.5
        feedback.append("Added JOIN successfully.")
    else:
        breakdown["added_join"] = 0.0
        feedback.append("Missing JOIN.")

    return score, breakdown, " ".join(feedback)


def grade_task_3(rewritten_query: str) -> GradeResult:
    """Grade task 3: four independent optimizations worth 0.25 each.

    - no DISTINCT in any SELECT;
    - no ``*`` / ``alias.*`` projection (``COUNT(*)`` is fine);
    - ``sale_date`` is still filtered on in a WHERE clause but is no longer
      wrapped in a CAST (which would prevent use of ``idx_sales_date``);
    - an index hint is present.
    """
    parsed, failure = _parse(rewritten_query)
    if failure is not None:
        return failure

    score = 0.0
    breakdown = {"no_distinct": 0.0, "no_select_star": 0.0, "fixed_cast": 0.0, "has_index_hint": 0.0}
    feedback = []
    selects = list(parsed.find_all(exp.Select))

    if not any(sel.args.get("distinct") for sel in selects):
        score += 0.25
        breakdown["no_distinct"] = 0.25
        feedback.append("Removed redundant DISTINCT.")
    else:
        feedback.append("DISTINCT is still present.")

    # Only projections matter: aggregate arguments such as COUNT(*) are fine.
    if not any(_is_star(proj) for sel in selects for proj in sel.expressions):
        score += 0.25
        breakdown["no_select_star"] = 0.25
        feedback.append("Replaced SELECT * with explicit columns.")
    else:
        feedback.append("SELECT * is still used; list explicit columns.")

    # exp.Cast also covers Postgres ``::`` casts; casting the literal side is fine.
    cast_on_date = any(
        _references_column(cast.this, "sale_date") for cast in parsed.find_all(exp.Cast)
    )
    keeps_date_filter = any(
        _references_column(where, "sale_date") for where in parsed.find_all(exp.Where)
    )
    if cast_on_date:
        feedback.append("sale_date is still wrapped in a CAST.")
    elif not keeps_date_filter:
        # Dropping the predicate changes the query's result; not a valid fix.
        feedback.append("The sale_date filter was removed; keep it and compare against a DATE value.")
    else:
        score += 0.25
        breakdown["fixed_cast"] = 0.25
        feedback.append("Fixed implicit type cast on sale_date.")

    # NOTE(review): deliberately loose raw-text match -- any occurrence of
    # "INDEX" (e.g. in a hint comment) earns credit, including occurrences
    # inside identifiers or string literals.
    if "INDEX" in rewritten_query.upper():
        score += 0.25
        breakdown["has_index_hint"] = 0.25
        feedback.append("Added index hint.")
    else:
        feedback.append("No index hint found.")

    return score, breakdown, " ".join(feedback)


_GRADERS = {1: grade_task_1, 2: grade_task_2, 3: grade_task_3}


def grade_action(task_id: int, rewritten_query: str) -> GradeResult:
    """Grade *rewritten_query* with the grader for *task_id*.

    Returns ``(0.0, {}, "Unknown task.")`` when no grader is registered.
    """
    grader = _GRADERS.get(task_id)
    if grader is None:
        return 0.0, {}, "Unknown task."
    return grader(rewritten_query)


def get_task(task_id: int) -> Optional[Dict[str, Any]]:
    """Return the task definition for *task_id*, or None if it does not exist."""
    return TASKS.get(task_id)