File size: 5,184 Bytes
e4c32ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import sqlglot
from sqlglot import exp
from typing import Dict, Any, Tuple

# Catalogue of SQL-rewriting exercises, keyed by integer task id.
# Every entry carries the same six fields, in this order:
#   name           - short slug identifying the task
#   difficulty     - "easy", "medium" or "hard"
#   schema_context - DDL statements describing the tables involved
#   hint           - natural-language description of what should be fixed
#   initial_query  - the flawed SQL statement the task starts from
#   max_steps      - integer step budget (presumably enforced by the caller)
TASKS = {
    1: dict(
        name="fix-broken-join",
        difficulty="easy",
        schema_context="CREATE TABLE users (id INT, name VARCHAR); CREATE TABLE orders (id INT, user_id INT, amount DECIMAL);",
        hint="The query is trying to join users and orders, but it is missing the ON clause, creating a cross join.",
        initial_query="SELECT users.name, orders.amount FROM users JOIN orders;",
        max_steps=3,
    ),
    2: dict(
        name="eliminate-n-plus-one",
        difficulty="medium",
        schema_context="CREATE TABLE employees (id INT, dept_id INT, name VARCHAR); CREATE TABLE departments (id INT, name VARCHAR);",
        hint="The query uses a correlated subquery in the WHERE clause. Rewrite it using a JOIN to improve performance.",
        initial_query="SELECT e.name FROM employees e WHERE e.dept_id IN (SELECT d.id FROM departments d WHERE d.name = 'Engineering');",
        max_steps=4,
    ),
    3: dict(
        name="full-optimization",
        difficulty="hard",
        schema_context="CREATE TABLE sales (id INT, product_id INT, sale_date DATE, amount DECIMAL); CREATE INDEX idx_sales_date ON sales(sale_date);",
        hint="Optimize the query: remove redundant DISTINCT, avoid SELECT *, use index hint if applicable, and fix implicit type casts.",
        initial_query="SELECT DISTINCT * FROM sales s WHERE CAST(s.sale_date AS VARCHAR) = '2023-01-01';",
        max_steps=5,
    ),
}

def grade_task_1(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade task 1 ("fix-broken-join"): every JOIN must carry an ON clause.

    Args:
        rewritten_query: SQL text to grade; parsed with the "postgres" dialect.

    Returns:
        A ``(score, breakdown, feedback)`` tuple:

        * SQL that fails to parse -> ``(0.0, {"parse_error": 1.0}, message)``.
        * No JOIN anywhere in the query -> ``(0.0, {"missing_join": 1.0}, ...)``.
        * Otherwise the score and ``breakdown["has_on_clause"]`` are 1.0 when
          all JOINs have an ON clause and 0.0 when any JOIN lacks one.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:
        # Grading boundary: any parser failure is reported as a zero score
        # instead of propagating to the caller.
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    joins = list(parsed.find_all(exp.Join))
    if not joins:
        return 0.0, {"missing_join": 1.0}, "No JOIN found in the query."

    # Inspect every JOIN rather than only the first one found: a rewrite that
    # fixes one join but leaves another without ON still contains a cross join.
    if all(join.args.get("on") for join in joins):
        return 1.0, {"has_on_clause": 1.0}, "Successfully added the ON clause."

    return 0.0, {"has_on_clause": 0.0}, "The JOIN is still missing an ON clause."

def grade_task_2(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade task 2 ("eliminate-n-plus-one"): replace the subquery with a JOIN.

    Two independent checks are worth 0.5 each:

    * ``removed_correlated_subquery`` - the query contains no subquery.
    * ``added_join`` - the query contains at least one JOIN.

    Args:
        rewritten_query: SQL text to grade; parsed with the "postgres" dialect.

    Returns:
        A ``(score, breakdown, feedback)`` tuple. SQL that fails to parse
        yields ``(0.0, {"parse_error": 1.0}, message)``.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:
        # Grading boundary: any parser failure is reported as a zero score
        # instead of propagating to the caller.
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    score = 0.0
    breakdown = {}
    feedback = []

    # A subquery can show up in three AST shapes:
    #   * an explicit Subquery node (derived table, scalar subquery, ...);
    #   * an IN predicate whose "query" arg is set -- a plain value list such
    #     as IN ('a', 'b') has no "query" arg and must not be penalised;
    #   * an EXISTS predicate, which wraps a SELECT directly.
    has_subquery = (
        parsed.find(exp.Subquery) is not None
        or any(
            node.args.get("query") is not None
            for node in parsed.find_all(exp.In)
        )
        or parsed.find(exp.Exists) is not None
    )

    if not has_subquery:
        score += 0.5
        breakdown["removed_correlated_subquery"] = 0.5
        feedback.append("Removed correlated subquery.")
    else:
        breakdown["removed_correlated_subquery"] = 0.0
        feedback.append("Correlated subquery still present.")

    if parsed.find(exp.Join) is not None:
        score += 0.5
        breakdown["added_join"] = 0.5
        feedback.append("Added JOIN successfully.")
    else:
        breakdown["added_join"] = 0.0
        feedback.append("Missing JOIN.")

    return score, breakdown, " ".join(feedback)

def grade_task_3(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade task 3 ("full-optimization") on four checks worth 0.25 each.

    * ``no_distinct`` - no DISTINCT on the root statement or any SELECT in it.
    * ``no_select_star`` - no ``*`` / ``alias.*`` in any SELECT projection list.
    * ``fixed_cast`` - no CAST applied directly to a ``sale_date`` column.
    * ``has_index_hint`` - the raw query text contains "INDEX"
      (case-insensitive substring match, not an AST check).

    Args:
        rewritten_query: SQL text to grade; parsed with the "postgres" dialect.

    Returns:
        A ``(score, breakdown, feedback)`` tuple. Feedback only lists the
        checks that passed. SQL that fails to parse yields
        ``(0.0, {"parse_error": 1.0}, message)``.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:
        # Grading boundary: any parser failure is reported as a zero score
        # instead of propagating to the caller.
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    def _is_star(node):
        # Matches both a bare `*` and a qualified `alias.*` projection.
        return isinstance(node, exp.Star) or (
            isinstance(node, exp.Column) and isinstance(node.this, exp.Star)
        )

    score = 0.0
    breakdown = {"no_distinct": 0.0, "no_select_star": 0.0, "fixed_cast": 0.0, "has_index_hint": 0.0}
    feedback = []

    selects = list(parsed.find_all(exp.Select))

    # Keep the root-level check (it also covers non-SELECT roots that carry a
    # "distinct" arg) and additionally look at every SELECT, so a DISTINCT
    # moved into a nested SELECT does not earn credit.
    has_distinct = bool(parsed.args.get("distinct")) or any(
        select.args.get("distinct") for select in selects
    )
    if not has_distinct:
        score += 0.25
        breakdown["no_distinct"] = 0.25
        feedback.append("Removed redundant DISTINCT.")

    # Only projection lists are inspected: a Star node nested elsewhere, e.g.
    # the argument of COUNT(*), is not a "SELECT *".
    has_select_star = any(
        _is_star(projection)
        for select in selects
        for projection in select.expressions
    )
    if not has_select_star:
        score += 0.25
        breakdown["no_select_star"] = 0.25
        feedback.append("Replaced SELECT * with explicit columns.")

    # A CAST wrapped directly around the sale_date column is what the task
    # asks to remove; casts on other operands are not penalised.
    cast_on_date = any(
        isinstance(cast.args.get("this"), exp.Column)
        and cast.args.get("this").name.lower() == "sale_date"
        for cast in parsed.find_all(exp.Cast)
    )
    if not cast_on_date:
        score += 0.25
        breakdown["fixed_cast"] = 0.25
        feedback.append("Fixed implicit type cast on sale_date.")

    if "INDEX" in rewritten_query.upper():
        score += 0.25
        breakdown["has_index_hint"] = 0.25
        feedback.append("Added index hint.")

    return score, breakdown, " ".join(feedback)

def grade_action(task_id: int, rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Route a rewritten query to the grader that belongs to ``task_id``.

    Args:
        task_id: Task identifier; 1, 2 and 3 are recognised.
        rewritten_query: SQL text handed unchanged to the selected grader.

    Returns:
        The selected grader's ``(score, breakdown, feedback)`` tuple, or
        ``(0.0, {}, "Unknown task.")`` for any other ``task_id``.
    """
    # Pick the grader first, then invoke it once at the end.
    if task_id == 1:
        grader = grade_task_1
    elif task_id == 2:
        grader = grade_task_2
    elif task_id == 3:
        grader = grade_task_3
    else:
        return 0.0, {}, "Unknown task."

    return grader(rewritten_query)

def get_task(task_id: int) -> Dict[str, Any]:
    """Look up a task definition in ``TASKS``.

    Args:
        task_id: Key of the task to fetch.

    Returns:
        The dict stored in ``TASKS`` under ``task_id`` (the stored object
        itself, not a copy), or ``None`` when no such key exists.
    """
    if task_id in TASKS:
        return TASKS[task_id]
    return None