File size: 7,937 Bytes
3c665d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719c147
3c665d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719c147
3c665d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719c147
3c665d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""
SQLDebugEnvironment β€” Gym-like RL environment for the SQL debug loop.

Lifecycle:
  1. env.reset(question)           β€” start new episode
  2. env.observe_error(error, sql) β€” classify error, build state
  3. env.select_action()           β€” bandit picks repair strategy
  4. env.get_repair_prompt(...)    β€” get specialized prompt for chosen action
  5. env.record_step(success)      β€” record outcome, compute reward
  6. Repeat 2-5 until success or max attempts
  7. env.end_episode(success)      β€” finalize, HER relabeling, bandit update

This module is a stateful singleton β€” one active episode at a time.
"""

from __future__ import annotations

import time
from typing import Optional

from rl.types import (
    RLState,
    RepairAction,
    ErrorClass,
    EpisodeStep,
    RLMetrics,
    featurize,
    REPAIR_ACTION_NAMES,
    ERROR_CLASS_NAMES,
)
from rl.error_classifier import classify_error, extract_offending_token
from rl.grader import GraderInput, compute_reward, _clamp
from rl.linucb import LinUCB
from rl.experience import record_episode, get_metrics, reset_experience
from rl.repair_strategies import (
    RepairContext,
    get_repair_system_suffix,
    build_repair_user_message,
)

# ─── Singleton State ────────────────────────────────────────────

# Shared LinUCB instance; stays None until _get_bandit() first constructs it.
_bandit: Optional[LinUCB] = None


class _EpisodeContext:
    """Mutable bookkeeping for the single episode currently in progress."""

    def __init__(self, question: str) -> None:
        # Question text handed to reset() when the episode was opened.
        self.question = question

        # Filled in by observe_error() / select_action(); empty until then.
        self.current_state: Optional[RLState] = None
        self.current_features: Optional[list[float]] = None
        self.last_action: Optional[RepairAction] = None

        # Error history used to derive "changed" and "same error" signals.
        self.previous_error_class: Optional[ErrorClass] = None
        self.consecutive_same_error: int = 0

        # One entry per record_step() call, in call order.
        self.steps: list[EpisodeStep] = []


# The episode in progress, or None when no episode is active. Assigned by
# reset() and cleared by end_episode() and reset_rl().
_current_episode: Optional[_EpisodeContext] = None


def _get_bandit() -> LinUCB:
    """Return the module-wide LinUCB bandit, constructing it on first use."""
    global _bandit
    if _bandit is not None:
        return _bandit
    _bandit = LinUCB()
    return _bandit


# ─── Environment Interface ──────────────────────────────────────

def reset(question: str) -> None:
    """Begin a fresh episode for ``question``.

    An episode that is still open and has at least one recorded step is first
    closed as a failure via ``end_episode(False)``; an open episode without
    any steps is simply replaced.
    """
    global _current_episode
    stale = _current_episode
    if stale is not None and stale.steps:
        end_episode(False)
    _current_episode = _EpisodeContext(question)


def observe_error(
    error_message: str,
    failing_sql: str,
    attempt_number: int,
) -> dict:
    """
    Classify a SQL execution error and derive the RL state for this attempt.

    The new state and its feature vector are stored on the active episode so
    that select_action() and record_step() can use them. ``failing_sql`` is
    accepted but not used by this function.

    Returns a dict with keys: error_class, error_class_name, state.
    Raises RuntimeError when no episode is active.
    """
    episode = _current_episode
    if episode is None:
        raise RuntimeError("Call reset() before observe_error()")

    error_class = classify_error(error_message)
    previous = episode.previous_error_class

    # A change can only be reported once a previous class is known, i.e.
    # after record_step() has stored one.
    error_changed = previous is not None and previous != error_class

    # Extend the streak when the class repeats; otherwise restart it at 1.
    if previous == error_class:
        streak = episode.consecutive_same_error + 1
    else:
        streak = 1
    episode.consecutive_same_error = streak

    state = RLState(
        error_class=error_class,
        attempt_number=attempt_number,
        previous_action=episode.last_action,
        error_changed=error_changed,
        consecutive_same_error=streak,
    )
    episode.current_state = state
    episode.current_features = featurize(state)

    observation = {
        "error_class": error_class,
        "error_class_name": ERROR_CLASS_NAMES[error_class],
        "state": state,
    }
    return observation


def select_action() -> dict:
    """
    Let the bandit choose a repair action for the most recently observed state.

    The chosen action is remembered on the episode as ``last_action``.

    Returns a dict with keys: action, action_name, scores.
    Raises RuntimeError when there is no active episode or no features have
    been computed yet.
    """
    episode = _current_episode
    features = None if episode is None else episode.current_features
    if features is None:
        raise RuntimeError("Call observe_error() before select_action()")

    action, scores = _get_bandit().select_action(features)
    episode.last_action = action

    selection = {
        "action": action,
        "action_name": REPAIR_ACTION_NAMES[action],
        "scores": scores,
    }
    return selection


def get_repair_prompt(
    action: RepairAction,
    schema: str,
    question: str,
    failing_sql: str,
    error_message: str,
) -> dict:
    """
    Assemble the prompt pieces for the given repair action.

    The offending token is extracted from ``error_message`` and bundled with
    the other arguments into a RepairContext, which is used to build the
    user message.

    Returns a dict with keys: system_suffix, user_message.
    """
    ctx = RepairContext(
        schema=schema,
        question=question,
        failing_sql=failing_sql,
        error_message=error_message,
        offending_token=extract_offending_token(error_message),
    )
    system_suffix = get_repair_system_suffix(action)
    user_message = build_repair_user_message(action, ctx)
    return {
        "system_suffix": system_suffix,
        "user_message": user_message,
    }


def record_step(
    action: RepairAction,
    success: bool,
    error_message: str,
    sql: str,
) -> dict:
    """
    Grade one repair attempt and append it to the active episode.

    The step stored on the episode carries the reward exactly as computed by
    compute_reward(); the reward in the returned dict is passed through
    _clamp().

    Returns a dict with keys: reward, breakdown.
    Raises RuntimeError when there is no active episode or no observed state.
    """
    episode = _current_episode
    if episode is None or episode.current_state is None:
        raise RuntimeError("Call observe_error() before record_step()")

    state = episode.current_state

    # On success there is no resulting error to classify.
    if success:
        outcome_class = None
    else:
        outcome_class = classify_error(error_message)

    graded = compute_reward(
        GraderInput(
            success=success,
            attempt_number=state.attempt_number,
            current_error_class=outcome_class,
            previous_error_class=episode.previous_error_class,
        )
    )

    # Reuse the cached feature vector; recompute only if it is missing/empty.
    features = episode.current_features or featurize(state)
    episode.steps.append(
        EpisodeStep(
            state=state,
            featurized=features,
            action=action,
            reward=graded.reward,
            error_message=error_message,
            sql=sql,
            success=success,
        )
    )

    # This state's class becomes the "previous" one for the next observation.
    episode.previous_error_class = state.error_class

    parts = graded.breakdown
    return {
        "reward": _clamp(graded.reward),
        "breakdown": {
            "base": parts.base,
            "attempt_penalty": parts.attempt_penalty,
            "severity_bonus": parts.severity_bonus,
            "change_bonus": parts.change_bonus,
        },
    }


def end_episode(success: bool) -> Optional[dict]:
    """
    Close the active episode and train the bandit on it.

    The episode's steps are handed to record_episode(); every relabeled
    experience it returns is fed to the bandit, after which the bandit's
    alpha is decayed. The active episode is cleared afterwards.

    Returns a dict with keys: total_reward, episode_length, or None when
    there is no active episode or it has no recorded steps.
    """
    global _current_episode
    finished = _current_episode
    if finished is None or not finished.steps:
        # Nothing to learn from; just make sure no episode stays active.
        _current_episode = None
        return None

    bandit = _get_bandit()
    recorded, relabeled = record_episode(
        finished.question,
        finished.steps,
        success,
    )

    for experience in relabeled:
        bandit.update(experience.state, experience.action, experience.reward)
    bandit.decay_alpha()

    summary = {
        "total_reward": _clamp(recorded.total_reward),
        "episode_length": len(recorded.steps),
    }
    _current_episode = None
    return summary


# ─── Query Interface ────────────────────────────────────────────

def get_rl_metrics() -> RLMetrics:
    """Return the metrics reported by the experience module's get_metrics()."""
    metrics = get_metrics()
    return metrics


def get_bandit_state() -> dict:
    """
    Snapshot the bandit's counters and exploration parameter.

    Returns a dict with keys: action_counts, total_updates, alpha,
    action_distribution. Constructs the bandit if it does not exist yet.
    """
    bandit = _get_bandit()
    action_counts = bandit.get_action_counts()
    total_updates = bandit.get_total_updates()
    alpha = bandit.get_alpha()
    action_distribution = bandit.get_action_distribution()
    return {
        "action_counts": action_counts,
        "total_updates": total_updates,
        "alpha": alpha,
        "action_distribution": action_distribution,
    }


def is_episode_active() -> bool:
    """Report whether an episode context currently exists."""
    if _current_episode is None:
        return False
    return True


def reset_rl() -> None:
    """Reset the entire RL system — bandit weights and experience store.

    Resets the bandit if one has been created, clears the experience store
    via reset_experience(), and discards any active episode without
    recording it.
    """
    global _bandit, _current_episode
    # Compare against None explicitly: an existing bandit must be reset even
    # if LinUCB ever defines __bool__/__len__ and evaluates as falsy. This
    # also matches the `is None` check used by _get_bandit().
    if _bandit is not None:
        _bandit.reset()
    reset_experience()
    _current_episode = None