File size: 5,531 Bytes
c15d346
 
 
 
 
 
 
7841be7
c15d346
 
 
 
 
 
 
 
 
7841be7
 
 
 
 
 
c15d346
 
 
 
 
 
 
 
 
 
7841be7
 
c15d346
 
7841be7
 
 
 
c15d346
 
 
7841be7
c15d346
 
 
7841be7
 
 
c15d346
7841be7
 
 
 
 
 
c15d346
 
7841be7
 
 
c15d346
7841be7
c15d346
7841be7
 
 
c15d346
7841be7
 
 
c15d346
 
 
 
 
 
 
 
 
 
 
 
7841be7
c15d346
 
 
7841be7
c15d346
7841be7
 
 
c15d346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7841be7
 
c15d346
7841be7
 
 
c15d346
7841be7
c15d346
 
 
7841be7
 
 
 
 
c15d346
 
 
7841be7
 
 
 
 
 
 
 
 
 
c15d346
 
 
7841be7
 
 
 
 
 
 
c15d346
7841be7
 
 
 
c15d346
7841be7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
env.py — SQLOptimEnv: Core OpenEnv Environment Class
"""

import logging
from typing import Any, Dict, Optional

from executor import get_executor
from graders import grade
from leaderboard import record as lb_record
from models import (
    Action,
    EnvironmentState,
    Observation,
    Reward,
    StepResult,
)
from tasks import TASKS


class SQLOptimEnv:
    """
    OpenEnv-compliant environment for SQL Query Optimization.

    The agent receives a SQL query + schema context, emits an Action
    containing a list of optimization suggestions AND a rewritten
    optimized_query.  The environment executes both queries against
    real DuckDB data, measures the actual speedup, and checks
    result correctness — all fed into the reward function.

    Multi-step:
      • issues_found_so_far accumulates flagged issue types.
      • last_execution carries execution metrics back to the agent
        so it can refine the optimized_query in subsequent steps.

    Lifecycle: call ``reset()`` to start an episode, then ``step()``
    until the returned ``StepResult.done`` is True.  An episode ends when
    the task's ``max_steps`` is reached or a single step scores at least
    ``SUCCESS_THRESHOLD``.
    """

    #: A step whose reward score reaches this value ends the episode early.
    SUCCESS_THRESHOLD: float = 0.95

    def __init__(self) -> None:
        # Definition dict of the active task (an entry of TASKS); None until reset().
        self._task_data: Optional[Dict[str, Any]] = None
        self._step_count: int = 0
        self._done: bool = False
        self._cumulative_reward: float = 0.0
        # Distinct non-empty issue_type values flagged so far, in first-seen order.
        self._issues_found: list = []
        # Most recent executor comparison shown to the agent. Kept across steps
        # that submit no optimized_query; cleared when a comparison fails.
        self._last_execution: Optional[Dict[str, Any]] = None

    # ── OpenEnv interface ─────────────────────────────────────────────

    def reset(
        self, task_id: str = "task_1_basic_antipatterns"
    ) -> Observation:
        """Start a new episode on *task_id* and return its initial observation.

        Raises:
            ValueError: if *task_id* is not a key of ``TASKS``.  Existing
                episode state is left untouched in that case.
        """
        if task_id not in TASKS:
            raise ValueError(
                f"Unknown task_id '{task_id}'. "
                f"Valid: {list(TASKS.keys())}"
            )
        self._task_data = TASKS[task_id]
        self._step_count = 0
        self._done = False
        self._cumulative_reward = 0.0
        self._issues_found = []
        self._last_execution = None
        return self._make_obs()

    def step(self, action: Action) -> StepResult:
        """Grade *action*, advance the episode by one step and report the outcome.

        Side effects: executes the original and optimized queries through
        the executor (when an optimized query is given), updates episode
        state, and records one leaderboard entry per call.

        Raises:
            RuntimeError: if no episode is active, or the current one is done.
        """
        if self._task_data is None:
            raise RuntimeError("No active episode — call reset() first.")
        if self._done:
            raise RuntimeError("Episode finished — call reset() to start a new one.")

        self._step_count += 1

        # Grade (runs DuckDB internally)
        reward: Reward = grade(self._task_data, action)
        self._cumulative_reward += reward.score

        # Execution metrics for THIS step only. None when no optimized query was
        # submitted or the comparison failed.
        # NOTE(review): grade() appears to execute the queries as well, so this
        # is a second run; consider having grade() return its execution result.
        opt_q = (action.optimized_query or "").strip()
        step_execution: Optional[Dict[str, Any]] = None
        if opt_q:
            step_execution = self._compare_queries(opt_q)
            # Only refresh the observed metrics when a query was submitted, so a
            # suggestions-only step still shows the agent its previous run.
            self._last_execution = step_execution

        # Track issue types for progressive context
        for suggestion in action.suggestions or ():
            itype = suggestion.get("issue_type", "")
            if itype and itype not in self._issues_found:
                self._issues_found.append(itype)

        max_steps: int = self._task_data["max_steps"]
        done = (
            self._step_count >= max_steps
            or reward.score >= self.SUCCESS_THRESHOLD
        )
        self._done = done

        # Update leaderboard with this step's own execution result. The
        # persisted _last_execution may be stale from an earlier step, which
        # would pair an old speedup with the current score.
        if step_execution:
            speedup = step_execution.get("speedup", 1.0)
            results_match = step_execution.get("results_match", False)
        else:
            speedup = 1.0
            results_match = False
        lb_record(
            task_id=self._task_data["task_id"],
            speedup=speedup,
            score=reward.score,
            results_match=results_match,
            steps=self._step_count,
        )

        return StepResult(
            observation=self._make_obs(),
            reward=reward,
            done=done,
            info={
                "step":              self._step_count,
                "cumulative_reward": round(self._cumulative_reward, 4),
                "issues_found":      len(self._issues_found),
                "execution":         self._last_execution,
            },
        )

    def state(self) -> EnvironmentState:
        """Return a snapshot of the episode; a placeholder if reset() was never called."""
        if self._task_data is None:
            return EnvironmentState(
                task_id="none", step_count=0, max_steps=0,
                episode_done=True, cumulative_reward=0.0,
                current_task="No active episode",
            )
        return EnvironmentState(
            task_id=self._task_data["task_id"],
            step_count=self._step_count,
            max_steps=self._task_data["max_steps"],
            episode_done=self._done,
            cumulative_reward=round(self._cumulative_reward, 4),
            current_task=self._task_data["task_name"],
        )

    # ── Internal ──────────────────────────────────────────────────────

    def _compare_queries(self, optimized_query: str) -> Optional[Dict[str, Any]]:
        """Compare the task's original query against *optimized_query*.

        Returns the executor's comparison dict, or None if obtaining the
        executor or running the comparison raises.  Failures are logged
        rather than propagated, so a broken rewrite cannot abort the episode.
        """
        try:
            return get_executor().compare(
                self._task_data["sql_query"], optimized_query
            )
        except Exception:
            logging.getLogger(__name__).warning(
                "Query comparison failed for task %s",
                self._task_data.get("task_id"),
                exc_info=True,
            )
            return None

    def _make_obs(self) -> Observation:
        """Build the agent-facing Observation from the active task and episode state.

        Must only be called after reset() has set ``_task_data``.
        """
        d = self._task_data
        return Observation(
            task_id=d["task_id"],
            task_name=d["task_name"],
            task_description=d["task_description"],
            sql_query=d["sql_query"],
            schema_info=d["schema_info"],
            dialect=d.get("dialect", "duckdb/postgresql"),
            difficulty=d["difficulty"],
            step_count=self._step_count,
            max_steps=d["max_steps"],
            # Copy so the observation cannot mutate the env's internal list.
            issues_found_so_far=list(self._issues_found),
            last_execution=self._last_execution,
        )