Spaces:
Sleeping
Sleeping
| import sqlglot | |
| from sqlglot import exp | |
| from typing import Dict, Any, Tuple | |
# Catalogue of SQL-rewrite tasks, keyed by integer task id.
#
# Every entry carries the same six fields, in this order:
#   name           - short slug identifying the task
#   difficulty     - "easy", "medium" or "hard"
#   schema_context - DDL for the tables the query runs against
#   hint           - natural-language description of what to fix
#   initial_query  - the flawed SQL that has to be rewritten
#   max_steps      - integer step budget (presumably the number of attempts
#                    allowed; nothing in this block enforces it)
TASKS = {
    # Task 1: the JOIN has no join condition.
    1: {
        "name": "fix-broken-join",
        "difficulty": "easy",
        "schema_context": (
            "CREATE TABLE users (id INT, name VARCHAR); "
            "CREATE TABLE orders (id INT, user_id INT, amount DECIMAL);"
        ),
        "hint": (
            "The query is trying to join users and orders, but it is missing "
            "the ON clause, creating a cross join."
        ),
        "initial_query": "SELECT users.name, orders.amount FROM users JOIN orders;",
        "max_steps": 3,
    },
    # Task 2: replace the IN (SELECT ...) filter with a JOIN.
    # NOTE(review): the hint says "correlated subquery", but the subquery in
    # initial_query only references its own table alias (d) - confirm wording.
    2: {
        "name": "eliminate-n-plus-one",
        "difficulty": "medium",
        "schema_context": (
            "CREATE TABLE employees (id INT, dept_id INT, name VARCHAR); "
            "CREATE TABLE departments (id INT, name VARCHAR);"
        ),
        "hint": (
            "The query uses a correlated subquery in the WHERE clause. "
            "Rewrite it using a JOIN to improve performance."
        ),
        "initial_query": (
            "SELECT e.name FROM employees e WHERE e.dept_id IN "
            "(SELECT d.id FROM departments d WHERE d.name = 'Engineering');"
        ),
        "max_steps": 4,
    },
    # Task 3: several independent clean-ups on a single-table query.
    3: {
        "name": "full-optimization",
        "difficulty": "hard",
        "schema_context": (
            "CREATE TABLE sales (id INT, product_id INT, sale_date DATE, amount DECIMAL); "
            "CREATE INDEX idx_sales_date ON sales(sale_date);"
        ),
        "hint": (
            "Optimize the query: remove redundant DISTINCT, avoid SELECT *, "
            "use index hint if applicable, and fix implicit type casts."
        ),
        "initial_query": "SELECT DISTINCT * FROM sales s WHERE CAST(s.sale_date AS VARCHAR) = '2023-01-01';",
        "max_steps": 5,
    },
}
def grade_task_1(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade a rewrite for task 1 (a JOIN that lacks its join condition).

    The query is parsed with sqlglot's Postgres dialect. Full credit is given
    when the query contains at least one JOIN, references both ``users`` and
    ``orders``, and every JOIN carries a condition (``ON`` or ``USING``).

    Args:
        rewritten_query: SQL text submitted for grading.

    Returns:
        A ``(score, breakdown, feedback)`` tuple. ``score`` is 0.0 or 1.0,
        ``breakdown`` maps a check name to the points it contributed (or, for
        an early rejection, a single flag such as ``parse_error``), and
        ``feedback`` is a human-readable summary.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:
        # Any failure while parsing is reported as a zero score, not raised.
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    joins = list(parsed.find_all(exp.Join))
    if not joins:
        return 0.0, {"missing_join": 1.0}, "No JOIN found in the query."

    # A conditioned JOIN between unrelated tables must not earn credit.
    referenced = {table.name.lower() for table in parsed.find_all(exp.Table)}
    if not {"users", "orders"} <= referenced:
        return 0.0, {"wrong_tables": 1.0}, "Query must join the users and orders tables."

    # Check every JOIN, not just the first one, and accept USING (...) as a
    # join condition alongside ON.
    all_conditioned = all(
        join.args.get("on") or join.args.get("using") for join in joins
    )

    if all_conditioned:
        return 1.0, {"has_on_clause": 1.0}, "Successfully added the ON clause."
    return 0.0, {"has_on_clause": 0.0}, "The JOIN is still missing an ON clause."
def grade_task_2(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade a rewrite for task 2 (replace a subquery filter with a JOIN).

    The query is parsed with sqlglot's Postgres dialect and scored on two
    independent checks worth 0.5 each: no nested query / ``IN`` predicate
    remains, and at least one JOIN is present. Queries that do not reference
    both ``employees`` and ``departments`` are rejected with a zero score.

    Args:
        rewritten_query: SQL text submitted for grading.

    Returns:
        A ``(score, breakdown, feedback)`` tuple. ``score`` is 0.0, 0.5 or
        1.0, ``breakdown`` maps each check name to the points it contributed
        (or, for an early rejection, a single flag such as ``parse_error``),
        and ``feedback`` joins one sentence per check.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:
        # Any failure while parsing is reported as a zero score, not raised.
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    # Without this gate an unrelated query such as "SELECT 1" would collect
    # the "subquery removed" half of the score.
    referenced = {table.name.lower() for table in parsed.find_all(exp.Table)}
    if not {"employees", "departments"} <= referenced:
        return (
            0.0,
            {"wrong_tables": 1.0},
            "Query must use the employees and departments tables.",
        )

    score = 0.0
    breakdown = {}
    feedback = []

    # Original rule: any parenthesised subquery or IN predicate counts.
    has_nested_query = any(True for _ in parsed.find_all(exp.Subquery, exp.In))
    if not has_nested_query:
        # EXISTS / ANY / ALL appear to wrap their SELECT directly, without an
        # exp.Subquery node (per sqlglot's parser), so look inside them too.
        has_nested_query = any(
            predicate.find(exp.Select) is not None
            for predicate in parsed.find_all(exp.Exists, exp.Any, exp.All)
        )

    if has_nested_query:
        breakdown["removed_correlated_subquery"] = 0.0
        feedback.append("Correlated subquery still present.")
    else:
        score += 0.5
        breakdown["removed_correlated_subquery"] = 0.5
        feedback.append("Removed correlated subquery.")

    if parsed.find(exp.Join) is not None:
        score += 0.5
        breakdown["added_join"] = 0.5
        feedback.append("Added JOIN successfully.")
    else:
        breakdown["added_join"] = 0.0
        feedback.append("Missing JOIN.")

    return score, breakdown, " ".join(feedback)
def grade_task_3(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade a rewrite for task 3 (multi-part clean-up of a ``sales`` query).

    The query is parsed with sqlglot's Postgres dialect. It must be a SELECT
    that reads from ``sales``; it is then scored on four independent checks
    worth 0.25 each: no DISTINCT on the outer SELECT, no ``*``, no CAST
    wrapping ``sale_date``, and the text "INDEX" appearing in the query.

    Args:
        rewritten_query: SQL text submitted for grading.

    Returns:
        A ``(score, breakdown, feedback)`` tuple. ``score`` is a multiple of
        0.25 in [0.0, 1.0], ``breakdown`` maps each check name to the points
        it contributed (or, for an early rejection, a single flag such as
        ``parse_error``), and ``feedback`` joins one sentence per check.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:
        # Any failure while parsing is reported as a zero score, not raised.
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    # Three of the four checks reward the *absence* of a construct, so any
    # unrelated statement (e.g. "SELECT 1") would otherwise score 0.75.
    if not isinstance(parsed, exp.Select):
        return 0.0, {"not_a_select": 1.0}, "Query must be a SELECT statement."
    if all(table.name.lower() != "sales" for table in parsed.find_all(exp.Table)):
        return 0.0, {"wrong_tables": 1.0}, "Query must read from the sales table."

    # Flag sale_date anywhere inside a CAST, not only as its direct operand
    # (e.g. CAST(COALESCE(s.sale_date, ...) AS VARCHAR)).
    cast_on_date = any(
        column.name.lower() == "sale_date"
        for cast in parsed.find_all(exp.Cast)
        for column in cast.find_all(exp.Column)
    )

    # NOTE(review): the index-hint check is a plain substring match on the
    # raw text, so a comment or string literal containing "index" passes too.
    has_index_text = "INDEX" in rewritten_query.upper()

    # (breakdown key, passed?, success message, failure message)
    checks = [
        (
            "no_distinct",
            not parsed.args.get("distinct"),
            "Removed redundant DISTINCT.",
            "DISTINCT is still present.",
        ),
        (
            "no_select_star",
            parsed.find(exp.Star) is None,
            "Replaced SELECT * with explicit columns.",
            "SELECT * is still used.",
        ),
        (
            "fixed_cast",
            not cast_on_date,
            "Fixed implicit type cast on sale_date.",
            "sale_date is still wrapped in a CAST.",
        ),
        (
            "has_index_hint",
            has_index_text,
            "Added index hint.",
            "Missing index hint.",
        ),
    ]

    score = 0.0
    breakdown = {}
    feedback = []
    for key, passed, success_message, failure_message in checks:
        if passed:
            score += 0.25
            breakdown[key] = 0.25
            feedback.append(success_message)
        else:
            # Report failures as well, matching the feedback style of the
            # other graders.
            breakdown[key] = 0.0
            feedback.append(failure_message)

    return score, breakdown, " ".join(feedback)
def grade_action(task_id: int, rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Route a submitted query to the grader that belongs to ``task_id``.

    Args:
        task_id: Identifier of the task being graded (1, 2 or 3).
        rewritten_query: SQL text to hand to the task's grader.

    Returns:
        Whatever the selected grader returns, or ``(0.0, {}, "Unknown task.")``
        when ``task_id`` matches none of the known ids.
    """
    # Pick the grader first, then make a single call at the end.
    if task_id == 1:
        grader = grade_task_1
    elif task_id == 2:
        grader = grade_task_2
    elif task_id == 3:
        grader = grade_task_3
    else:
        return 0.0, {}, "Unknown task."
    return grader(rewritten_query)
def get_task(task_id: int) -> Dict[str, Any]:
    """Look up a task definition by its id.

    Args:
        task_id: Key into the module-level ``TASKS`` mapping.

    Returns:
        The entry stored in ``TASKS`` for ``task_id`` (the stored dict itself,
        not a copy), or ``None`` when no such id is registered. Note that the
        return annotation does not mention the ``None`` case.
    """
    if task_id in TASKS:
        return TASKS[task_id]
    return None