Spaces:
Sleeping
Sleeping
| """ | |
| SQLDebugEnvironment — Gym-like RL environment for the SQL debug loop. | |
| Lifecycle: | |
| 1. env.reset(question) → start new episode | |
| 2. env.observe_error(error, sql) → classify error, build state | |
| 3. env.select_action() → bandit picks repair strategy | |
| 4. env.get_repair_prompt(...) → get specialized prompt for chosen action | |
| 5. env.record_step(success) → record outcome, compute reward | |
| 6. Repeat 2-5 until success or max attempts | |
| 7. env.end_episode(success) → finalize, HER relabeling, bandit update | |
| This module is a stateful singleton — one active episode at a time. | |
| """ | |
| from __future__ import annotations | |
| import time | |
| from typing import Optional | |
| from rl.types import ( | |
| RLState, | |
| RepairAction, | |
| ErrorClass, | |
| EpisodeStep, | |
| RLMetrics, | |
| featurize, | |
| REPAIR_ACTION_NAMES, | |
| ERROR_CLASS_NAMES, | |
| ) | |
| from rl.error_classifier import classify_error, extract_offending_token | |
| from rl.grader import GraderInput, compute_reward, _clamp | |
| from rl.linucb import LinUCB | |
| from rl.experience import record_episode, get_metrics, reset_experience | |
| from rl.repair_strategies import ( | |
| RepairContext, | |
| get_repair_system_suffix, | |
| build_repair_user_message, | |
| ) | |
| # ─── Singleton State ───────────────────────────────────────────── | |
# Module-wide LinUCB bandit shared by every episode. It starts as None and is
# created on demand by _get_bandit().
_bandit: Optional[LinUCB] = None
| class _EpisodeContext: | |
| def __init__(self, question: str) -> None: | |
| self.question = question | |
| self.steps: list[EpisodeStep] = [] | |
| self.previous_error_class: Optional[ErrorClass] = None | |
| self.consecutive_same_error: int = 0 | |
| self.last_action: Optional[RepairAction] = None | |
| self.current_state: Optional[RLState] = None | |
| self.current_features: Optional[list[float]] = None | |
# The single episode currently in progress, or None when no episode is active.
_current_episode: Optional[_EpisodeContext] = None
def _get_bandit() -> LinUCB:
    """Return the shared LinUCB bandit, constructing it on first use."""
    global _bandit
    if _bandit is not None:
        return _bandit
    _bandit = LinUCB()
    return _bandit
| # ─── Environment Interface ──────────────────────────────────────── | |
def reset(question: str) -> None:
    """Begin a fresh episode for ``question``.

    If an episode is still open and already has recorded steps, it is first
    closed as a failure via ``end_episode(False)``.
    """
    global _current_episode
    unfinished = _current_episode
    if unfinished is not None and unfinished.steps:
        end_episode(False)
    _current_episode = _EpisodeContext(question)
def observe_error(
    error_message: str,
    failing_sql: str,
    attempt_number: int,
) -> dict:
    """Classify a SQL execution error and build the RL state for it.

    Args:
        error_message: Error text, passed to ``classify_error``.
        failing_sql: The SQL that failed. Accepted for interface symmetry;
            this function does not read it.
        attempt_number: Attempt counter stored on the resulting state.

    Returns:
        A dict with keys ``error_class``, ``error_class_name`` and ``state``.

    Raises:
        RuntimeError: If no episode is active (``reset()`` was not called).
    """
    episode = _current_episode
    if episode is None:
        raise RuntimeError("Call reset() before observe_error()")

    error_class = classify_error(error_message)
    prior_class = episode.previous_error_class

    # A "change" only counts when there is an earlier error to compare with.
    error_changed = prior_class is not None and prior_class != error_class

    # Extend the streak on a repeat of the same class, otherwise restart at 1.
    if prior_class == error_class:
        streak = episode.consecutive_same_error + 1
    else:
        streak = 1
    episode.consecutive_same_error = streak

    state = RLState(
        error_class=error_class,
        attempt_number=attempt_number,
        previous_action=episode.last_action,
        error_changed=error_changed,
        consecutive_same_error=streak,
    )
    episode.current_state = state
    episode.current_features = featurize(state)

    observation = {
        "error_class": error_class,
        "error_class_name": ERROR_CLASS_NAMES[error_class],
        "state": state,
    }
    return observation
def select_action() -> dict:
    """Let the bandit choose a repair action for the current state.

    Returns:
        A dict with keys ``action``, ``action_name`` and ``scores``.

    Raises:
        RuntimeError: If there is no active episode, or no features have
            been computed yet by ``observe_error()``.
    """
    episode = _current_episode
    features = None if episode is None else episode.current_features
    if features is None:
        raise RuntimeError("Call observe_error() before select_action()")

    bandit = _get_bandit()
    chosen, scores = bandit.select_action(features)
    # Remember the choice so the next observed state can reference it.
    episode.last_action = chosen

    return {
        "action": chosen,
        "action_name": REPAIR_ACTION_NAMES[chosen],
        "scores": scores,
    }
def get_repair_prompt(
    action: RepairAction,
    schema: str,
    question: str,
    failing_sql: str,
    error_message: str,
) -> dict:
    """Assemble the prompt pieces for the chosen repair action.

    The offending token is extracted from ``error_message`` and bundled with
    the other arguments into a ``RepairContext`` for the message builder.

    Returns:
        A dict with keys ``system_suffix`` and ``user_message``.
    """
    repair_ctx = RepairContext(
        schema=schema,
        question=question,
        failing_sql=failing_sql,
        error_message=error_message,
        offending_token=extract_offending_token(error_message),
    )
    system_suffix = get_repair_system_suffix(action)
    user_message = build_repair_user_message(action, repair_ctx)
    return {
        "system_suffix": system_suffix,
        "user_message": user_message,
    }
def record_step(
    action: RepairAction,
    success: bool,
    error_message: str,
    sql: str,
) -> dict:
    """Record the outcome of one repair step and compute its reward.

    Args:
        action: The repair action that was applied.
        success: Whether the step succeeded.
        error_message: Error text for this step; only classified when
            ``success`` is false.
        sql: The SQL associated with this step.

    Returns:
        A dict with keys ``reward`` (passed through ``_clamp``) and
        ``breakdown`` (``base``, ``attempt_penalty``, ``severity_bonus``,
        ``change_bonus``).

    Raises:
        RuntimeError: If there is no active episode or no state has been
            built yet by ``observe_error()``.
    """
    episode = _current_episode
    if episode is None or episode.current_state is None:
        raise RuntimeError("Call observe_error() before record_step()")

    state = episode.current_state

    # A successful step has no current error to classify.
    if success:
        current_class = None
    else:
        current_class = classify_error(error_message)

    result = compute_reward(
        GraderInput(
            success=success,
            attempt_number=state.attempt_number,
            current_error_class=current_class,
            previous_error_class=episode.previous_error_class,
        )
    )

    # The stored step keeps the grader's reward as-is; only the value
    # returned to the caller goes through _clamp.
    episode.steps.append(
        EpisodeStep(
            state=state,
            featurized=episode.current_features or featurize(state),
            action=action,
            reward=result.reward,
            error_message=error_message,
            sql=sql,
            success=success,
        )
    )
    episode.previous_error_class = state.error_class

    reward = _clamp(result.reward)
    parts = result.breakdown
    return {
        "reward": reward,
        "breakdown": {
            "base": parts.base,
            "attempt_penalty": parts.attempt_penalty,
            "severity_bonus": parts.severity_bonus,
            "change_bonus": parts.change_bonus,
        },
    }
def end_episode(success: bool) -> Optional[dict]:
    """Finish the current episode and feed its experiences to the bandit.

    The episode's steps are handed to ``record_episode``; every relabeled
    experience it returns is used to update the bandit, after which the
    bandit's alpha is decayed once.

    Returns:
        A dict with keys ``total_reward`` and ``episode_length``, or ``None``
        when there is no active episode or it has no recorded steps.
    """
    global _current_episode
    ctx = _current_episode

    # Nothing to learn from: just make sure no episode is left active.
    if ctx is None or not ctx.steps:
        _current_episode = None
        return None

    bandit = _get_bandit()
    episode, relabeled = record_episode(ctx.question, ctx.steps, success)

    for experience in relabeled:
        bandit.update(experience.state, experience.action, experience.reward)
    bandit.decay_alpha()

    summary = {
        "total_reward": _clamp(episode.total_reward),
        "episode_length": len(episode.steps),
    }
    _current_episode = None
    return summary
| # ─── Query Interface ────────────────────────────────────────────── | |
def get_rl_metrics() -> RLMetrics:
    """Return the metrics reported by ``rl.experience.get_metrics``."""
    metrics = get_metrics()
    return metrics
def get_bandit_state() -> dict:
    """Return a snapshot of the shared bandit's counters and parameters.

    Returns:
        A dict with keys ``action_counts``, ``total_updates``, ``alpha`` and
        ``action_distribution``, each read from the bandit's matching getter.
    """
    bandit = _get_bandit()
    action_counts = bandit.get_action_counts()
    total_updates = bandit.get_total_updates()
    alpha = bandit.get_alpha()
    action_distribution = bandit.get_action_distribution()
    return {
        "action_counts": action_counts,
        "total_updates": total_updates,
        "alpha": alpha,
        "action_distribution": action_distribution,
    }
def is_episode_active() -> bool:
    """Return True while an episode has been started and not yet ended."""
    if _current_episode is None:
        return False
    return True
def reset_rl() -> None:
    """Reset the entire RL system — bandit weights and experience store.

    Any episode currently in progress is discarded without being recorded.
    """
    global _bandit, _current_episode
    # Compare against None explicitly: relying on the bandit's truthiness
    # would skip the reset if LinUCB ever defines __bool__ or __len__ and
    # evaluates as falsy.
    if _bandit is not None:
        _bandit.reset()
    reset_experience()
    _current_episode = None