File size: 3,417 Bytes
8778707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from env.environment import SQLDebuggerEnvironment
from env.models import Action, ActionType

# A single environment instance is shared by every test in this script.
env = SQLDebuggerEnvironment()

# Test 1: calling state() before any reset() must not crash.
s = env.state()
initialized_flag = s.initialized
print(f"State before reset: initialized={initialized_flag}")

# Test 2: reset() starts an "easy" episode and returns the first observation.
obs = env.reset(difficulty="easy")
print(f"Reset OK: task_id={obs.task_id}, difficulty={obs.difficulty}")
print(f"Context keys: {list(obs.current_context.keys())}")
# The agent-visible context must not expose the ground-truth "fixed_query".
# Assert instead of only printing, so a leak fails the run rather than
# scrolling past as "False" in the log.
ground_truth_hidden = "fixed_query" not in obs.current_context
print(f"Ground truth NOT in context: {ground_truth_hidden}")
assert ground_truth_hidden, "reset() leaked 'fixed_query' into current_context"

# Test 3: step() with an IDENTIFY_ERROR action describing a syntax error.
identify_payload = {
    "error_location": "SELECT clause",
    "error_type": "syntax",
    "explanation": "Missing commas",
}
action1 = Action(action_type=ActionType.IDENTIFY_ERROR, payload=identify_payload)
resp1 = env.step(action1)
step1_reward = resp1.reward.score
print(f"Step 1: reward={step1_reward}, done={resp1.done}, step={resp1.observation.step_count}")

# Test 4: step() with a REQUEST_HINT action asking for a "location" hint,
# then check whether the hint shows up in the observation's context.
action2 = Action(
    action_type=ActionType.REQUEST_HINT,
    payload={"hint_type": "location"},
)
resp2 = env.step(action2)
hint_obs = resp2.observation
print(f"Step 2 hint: reward={resp2.reward.score}, hints_used={hint_obs.hints_used}")
print(f"Hint in context: {'last_hint' in hint_obs.current_context}")

# Test 5: start a fresh episode pinned to task "easy_001" and submit a
# corrected query via SUBMIT_ANSWER.
obs = env.reset(difficulty="easy", task_id="easy_001")
submission = dict(
    fixed_query="SELECT id, name, email FROM users WHERE active = 1",
    explanation="Added missing commas between column names in SELECT clause",
    error_type="syntax",
    error_location="SELECT clause",
    confidence=0.95,
)
action3 = Action(action_type=ActionType.SUBMIT_ANSWER, payload=submission)
resp3 = env.step(action3)
print(f"Submit answer: reward={resp3.reward.score}, done={resp3.done}")

# Test 6: step() after the episode has ended — must not crash.
# This only exercises the post-done path if Test 5's SUBMIT_ANSWER actually
# ended the episode; check that precondition explicitly so the test cannot
# silently degrade into an ordinary mid-episode step.
assert resp3.done, "Test 6 precondition: SUBMIT_ANSWER in Test 5 should end the episode"
resp4 = env.step(action3)
print(f"Step after done: done={resp4.done}, feedback='{resp4.reward.feedback}'")

# Test 7: step() given None instead of an Action instance.
obs = env.reset(difficulty="easy")
null_action = None
resp5 = env.step(null_action)
print(f"Null action: reward={resp5.reward.score}, done={resp5.done}")

# Test 8: calling reset() again while an episode is in progress should
# clear the previous episode's state.
obs = env.reset(difficulty="medium")
new_task, new_count = obs.task_id, obs.step_count
print(f"Mid-episode reset: new task={new_task}, step_count={new_count}")

# Test 9: drive a "hard" episode with a fixed 3-action script
# (identify -> explain -> optimize) and sum the per-step rewards.
# NOTE(review): the payloads below target an N+1 correlated-subquery task,
# but reset() is not pinned to a task_id (unlike Test 5), so they may not
# match the task that gets drawn — consider pinning task_id here too.
obs = env.reset(difficulty="hard")
print(f"Hard episode started: {obs.task_id}")
actions = [
    Action(
        action_type=ActionType.IDENTIFY_ERROR,
        payload={"error_location": "SELECT clause", "error_type": "performance"},
    ),
    Action(
        action_type=ActionType.EXPLAIN_ISSUE,
        payload={
            "explanation": "N+1 correlated subqueries cause multiple DB hits per row",
            "impact": "O(n) queries",
            "root_cause": "Subquery per user",
        },
    ),
    Action(
        action_type=ActionType.OPTIMIZE_QUERY,
        payload={
            "optimized_query": "SELECT u.id, u.name, COUNT(o.id) as order_count, COALESCE(SUM(o.total), 0) as total_spent FROM users u LEFT JOIN orders o ON u.id = o.user_id GROUP BY u.id, u.name",
            "optimization_type": "Replace N+1 correlated subqueries with LEFT JOIN aggregation",
            "explanation": "Single query replaces N+1 pattern",
            "root_cause": "Correlated subqueries in SELECT",
            "expected_improvement": "99% reduction in DB round trips",
            "confidence": 0.9,
        },
    ),
]
total = 0.0
for i, a in enumerate(actions, start=1):
    r = env.step(a)
    total += r.reward.score
    print(f"  Hard step {i}: reward={r.reward.score}, done={r.done}")
print(f"Hard episode total reward: {round(total, 4)}")

# Reached only if none of the steps or assertions above raised.
print("environment.py OK")