""" SQLDebugEnvironment — Gym-like RL environment for the SQL debug loop. Lifecycle: 1. env.reset(question) — start new episode 2. env.observe_error(error, sql) — classify error, build state 3. env.select_action() — bandit picks repair strategy 4. env.get_repair_prompt(...) — get specialized prompt for chosen action 5. env.record_step(success) — record outcome, compute reward 6. Repeat 2-5 until success or max attempts 7. env.end_episode(success) — finalize, HER relabeling, bandit update This module is a stateful singleton — one active episode at a time. """ from __future__ import annotations import time from typing import Optional from rl.types import ( RLState, RepairAction, ErrorClass, EpisodeStep, RLMetrics, featurize, REPAIR_ACTION_NAMES, ERROR_CLASS_NAMES, ) from rl.error_classifier import classify_error, extract_offending_token from rl.grader import GraderInput, compute_reward, _clamp from rl.linucb import LinUCB from rl.experience import record_episode, get_metrics, reset_experience from rl.repair_strategies import ( RepairContext, get_repair_system_suffix, build_repair_user_message, ) # ─── Singleton State ───────────────────────────────────────────── _bandit: Optional[LinUCB] = None class _EpisodeContext: def __init__(self, question: str) -> None: self.question = question self.steps: list[EpisodeStep] = [] self.previous_error_class: Optional[ErrorClass] = None self.consecutive_same_error: int = 0 self.last_action: Optional[RepairAction] = None self.current_state: Optional[RLState] = None self.current_features: Optional[list[float]] = None _current_episode: Optional[_EpisodeContext] = None def _get_bandit() -> LinUCB: global _bandit if _bandit is None: _bandit = LinUCB() return _bandit # ─── Environment Interface ──────────────────────────────────────── def reset(question: str) -> None: """Start a new episode. If a previous episode was active, end it as failure.""" global _current_episode if _current_episode and _current_episode.steps: end_episode(False) _current_episode = _EpisodeContext(question) def observe_error( error_message: str, failing_sql: str, attempt_number: int, ) -> dict: """ Classify the SQL execution error and build the RL state. Returns a dict with keys: error_class, error_class_name, state. """ if _current_episode is None: raise RuntimeError("Call reset() before observe_error()") error_class = classify_error(error_message) error_changed = ( _current_episode.previous_error_class is not None and _current_episode.previous_error_class != error_class ) if _current_episode.previous_error_class == error_class: _current_episode.consecutive_same_error += 1 else: _current_episode.consecutive_same_error = 1 state = RLState( error_class=error_class, attempt_number=attempt_number, previous_action=_current_episode.last_action, error_changed=error_changed, consecutive_same_error=_current_episode.consecutive_same_error, ) _current_episode.current_state = state _current_episode.current_features = featurize(state) return { "error_class": error_class, "error_class_name": ERROR_CLASS_NAMES[error_class], "state": state, } def select_action() -> dict: """ Ask the bandit to select a repair action based on current state. Returns dict with keys: action, action_name, scores. """ if _current_episode is None or _current_episode.current_features is None: raise RuntimeError("Call observe_error() before select_action()") b = _get_bandit() action, scores = b.select_action(_current_episode.current_features) _current_episode.last_action = action return { "action": action, "action_name": REPAIR_ACTION_NAMES[action], "scores": scores, } def get_repair_prompt( action: RepairAction, schema: str, question: str, failing_sql: str, error_message: str, ) -> dict: """ Build the system suffix and user message for the chosen repair action. Returns dict with keys: system_suffix, user_message. """ offending_token = extract_offending_token(error_message) ctx = RepairContext( schema=schema, question=question, failing_sql=failing_sql, error_message=error_message, offending_token=offending_token, ) return { "system_suffix": get_repair_system_suffix(action), "user_message": build_repair_user_message(action, ctx), } def record_step( action: RepairAction, success: bool, error_message: str, sql: str, ) -> dict: """ Record the outcome of a repair step and compute shaped reward. Returns dict with keys: reward, breakdown. """ if _current_episode is None or _current_episode.current_state is None: raise RuntimeError("Call observe_error() before record_step()") state = _current_episode.current_state grader_input = GraderInput( success=success, attempt_number=state.attempt_number, current_error_class=None if success else classify_error(error_message), previous_error_class=_current_episode.previous_error_class, ) result = compute_reward(grader_input) step = EpisodeStep( state=state, featurized=_current_episode.current_features or featurize(state), action=action, reward=result.reward, error_message=error_message, sql=sql, success=success, ) _current_episode.steps.append(step) _current_episode.previous_error_class = state.error_class return { "reward": _clamp(result.reward), "breakdown": { "base": result.breakdown.base, "attempt_penalty": result.breakdown.attempt_penalty, "severity_bonus": result.breakdown.severity_bonus, "change_bonus": result.breakdown.change_bonus, }, } def end_episode(success: bool) -> Optional[dict]: """ End the current episode. Runs HER relabeling and updates the bandit. Returns dict with keys: total_reward, episode_length. """ global _current_episode if _current_episode is None or not _current_episode.steps: _current_episode = None return None b = _get_bandit() episode, relabeled = record_episode( _current_episode.question, _current_episode.steps, success, ) for exp in relabeled: b.update(exp.state, exp.action, exp.reward) b.decay_alpha() result = { "total_reward": _clamp(episode.total_reward), "episode_length": len(episode.steps), } _current_episode = None return result # ─── Query Interface ────────────────────────────────────────────── def get_rl_metrics() -> RLMetrics: return get_metrics() def get_bandit_state() -> dict: b = _get_bandit() return { "action_counts": b.get_action_counts(), "total_updates": b.get_total_updates(), "alpha": b.get_alpha(), "action_distribution": b.get_action_distribution(), } def is_episode_active() -> bool: return _current_episode is not None def reset_rl() -> None: """Reset the entire RL system — bandit weights and experience store.""" global _bandit, _current_episode if _bandit: _bandit.reset() reset_experience() _current_episode = None