# sql-agent-openenv / backend/rl/environment.py
# ar9avg
# Nuclear clamp: every reward source in the codebase now returns (0.05, 0.95)
# commit 719c147
"""
SQLDebugEnvironment — Gym-like RL environment for the SQL debug loop.
Lifecycle:
1. env.reset(question) — start new episode
2. env.observe_error(error, sql) — classify error, build state
3. env.select_action() — bandit picks repair strategy
4. env.get_repair_prompt(...) — get specialized prompt for chosen action
5. env.record_step(success) — record outcome, compute reward
6. Repeat 2-5 until success or max attempts
7. env.end_episode(success) — finalize, HER relabeling, bandit update
This module is a stateful singleton — one active episode at a time.
"""
from __future__ import annotations
import time
from typing import Optional
from rl.types import (
RLState,
RepairAction,
ErrorClass,
EpisodeStep,
RLMetrics,
featurize,
REPAIR_ACTION_NAMES,
ERROR_CLASS_NAMES,
)
from rl.error_classifier import classify_error, extract_offending_token
from rl.grader import GraderInput, compute_reward, _clamp
from rl.linucb import LinUCB
from rl.experience import record_episode, get_metrics, reset_experience
from rl.repair_strategies import (
RepairContext,
get_repair_system_suffix,
build_repair_user_message,
)
# ─── Singleton State ─────────────────────────────────────────────
# Module-level bandit shared by every episode. It is created lazily on first
# use by _get_bandit() and reset in place (not replaced) by reset_rl().
_bandit: Optional[LinUCB] = None
class _EpisodeContext:
def __init__(self, question: str) -> None:
self.question = question
self.steps: list[EpisodeStep] = []
self.previous_error_class: Optional[ErrorClass] = None
self.consecutive_same_error: int = 0
self.last_action: Optional[RepairAction] = None
self.current_state: Optional[RLState] = None
self.current_features: Optional[list[float]] = None
# The one active episode, or None when no episode is running. It is set by
# reset() and cleared by end_episode() and reset_rl().
_current_episode: Optional[_EpisodeContext] = None
def _get_bandit() -> LinUCB:
global _bandit
if _bandit is None:
_bandit = LinUCB()
return _bandit
# ─── Environment Interface ────────────────────────────────────────
def reset(question: str) -> None:
    """Begin a fresh episode for ``question``.

    If the episode still in progress has recorded at least one step, it is
    first closed out as a failure via ``end_episode(False)``. An in-progress
    episode with no steps is simply replaced.
    """
    global _current_episode
    previous = _current_episode
    has_unfinished_steps = previous is not None and bool(previous.steps)
    if has_unfinished_steps:
        end_episode(False)
    _current_episode = _EpisodeContext(question)
def observe_error(
    error_message: str,
    failing_sql: str,
    attempt_number: int,
) -> dict:
    """
    Classify an SQL execution error and derive the bandit's RL state.

    ``failing_sql`` is not used by this function.

    The new state and its featurization are stored on the active episode
    for select_action() and record_step().

    Returns a dict with keys ``error_class``, ``error_class_name`` and ``state``.
    Raises RuntimeError when reset() has not started an episode.
    """
    episode = _current_episode
    if episode is None:
        raise RuntimeError("Call reset() before observe_error()")

    error_class = classify_error(error_message)
    # previous_error_class is only advanced by record_step().
    prior = episode.previous_error_class

    # A change only counts once some earlier class has been recorded.
    error_changed = prior is not None and prior != error_class
    episode.consecutive_same_error = (
        episode.consecutive_same_error + 1 if prior == error_class else 1
    )

    state = RLState(
        error_class=error_class,
        attempt_number=attempt_number,
        previous_action=episode.last_action,
        error_changed=error_changed,
        consecutive_same_error=episode.consecutive_same_error,
    )
    episode.current_state = state
    episode.current_features = featurize(state)

    return {
        "error_class": error_class,
        "error_class_name": ERROR_CLASS_NAMES[error_class],
        "state": state,
    }
def select_action() -> dict:
    """
    Let the bandit choose a repair action for the current featurized state.

    The choice is remembered as the episode's ``last_action``, which the
    next observe_error() call uses as ``previous_action``.

    Returns a dict with keys ``action``, ``action_name`` and ``scores``.
    Raises RuntimeError unless observe_error() has produced features.
    """
    episode = _current_episode
    if episode is None or episode.current_features is None:
        raise RuntimeError("Call observe_error() before select_action()")

    chosen, scores = _get_bandit().select_action(episode.current_features)
    episode.last_action = chosen
    return {
        "action": chosen,
        "action_name": REPAIR_ACTION_NAMES[chosen],
        "scores": scores,
    }
def get_repair_prompt(
    action: RepairAction,
    schema: str,
    question: str,
    failing_sql: str,
    error_message: str,
) -> dict:
    """
    Assemble the prompt pieces for the chosen repair action.

    This function does not read or modify episode state.

    Returns a dict with keys ``system_suffix`` and ``user_message``.
    """
    context = RepairContext(
        schema=schema,
        question=question,
        failing_sql=failing_sql,
        error_message=error_message,
        offending_token=extract_offending_token(error_message),
    )
    suffix = get_repair_system_suffix(action)
    message = build_repair_user_message(action, context)
    return {"system_suffix": suffix, "user_message": message}
def record_step(
    action: RepairAction,
    success: bool,
    error_message: str,
    sql: str,
) -> dict:
    """
    Log the result of one repair attempt and score it with the grader.

    The step is appended to the active episode with the grader's reward.
    The episode's ``previous_error_class`` is then advanced to the class of
    the state this step was taken from.

    Returns a dict with ``reward`` (passed through ``_clamp``) and a
    ``breakdown`` of the grader's reward components.
    Raises RuntimeError unless observe_error() has set a current state.
    """
    episode = _current_episode
    if episode is None or episode.current_state is None:
        raise RuntimeError("Call observe_error() before record_step()")
    state = episode.current_state

    # On success there is no resulting error to classify.
    resulting_class = None if success else classify_error(error_message)
    # NOTE(review): the grader's "previous" class is the value stored by the
    # prior record_step() call (the class *before* `state`), not
    # `state.error_class`. If `error_message` is the post-repair error, the
    # intended comparison may be state.error_class -> resulting_class.
    # Confirm against rl.grader before changing.
    graded = compute_reward(
        GraderInput(
            success=success,
            attempt_number=state.attempt_number,
            current_error_class=resulting_class,
            previous_error_class=episode.previous_error_class,
        )
    )

    features = episode.current_features or featurize(state)
    # NOTE(review): the stored step keeps the unclamped grader reward, while
    # the caller receives the clamped value below. Harmless if compute_reward
    # already clamps — confirm.
    episode.steps.append(
        EpisodeStep(
            state=state,
            featurized=features,
            action=action,
            reward=graded.reward,
            error_message=error_message,
            sql=sql,
            success=success,
        )
    )
    episode.previous_error_class = state.error_class

    parts = graded.breakdown
    return {
        "reward": _clamp(graded.reward),
        "breakdown": {
            "base": parts.base,
            "attempt_penalty": parts.attempt_penalty,
            "severity_bonus": parts.severity_bonus,
            "change_bonus": parts.change_bonus,
        },
    }
def end_episode(success: bool) -> Optional[dict]:
    """
    Finish the current episode: run HER relabeling and update the bandit.

    The episode's question and steps are passed to ``record_episode``. Each
    relabeled experience it returns is fed to the bandit's ``update``, and
    the bandit's alpha is decayed once.

    The active episode is always cleared, even if recording or a bandit call
    raises. Otherwise a stale episode would stay active and the next
    ``reset()``, which calls ``end_episode(False)`` on it, would hit the same
    failure again.

    Returns a dict with keys ``total_reward`` (clamped) and ``episode_length``,
    or None when no episode is active or it recorded no steps.
    """
    global _current_episode
    active = _current_episode
    if active is None or not active.steps:
        _current_episode = None
        return None

    try:
        bandit = _get_bandit()
        episode, relabeled = record_episode(
            active.question,
            active.steps,
            success,
        )
        for exp in relabeled:
            bandit.update(exp.state, exp.action, exp.reward)
        bandit.decay_alpha()
        return {
            "total_reward": _clamp(episode.total_reward),
            "episode_length": len(episode.steps),
        }
    finally:
        # Runs after the return value is built, and also on exceptions.
        _current_episode = None
# ─── Query Interface ──────────────────────────────────────────────
def get_rl_metrics() -> RLMetrics:
    """Return the aggregate RL metrics reported by ``rl.experience.get_metrics``."""
    metrics = get_metrics()
    return metrics
def get_bandit_state() -> dict:
    """
    Snapshot the shared bandit's statistics, creating the bandit if needed.

    Returns a dict with keys ``action_counts``, ``total_updates``, ``alpha``
    and ``action_distribution``, each taken from the bandit's getters.
    """
    bandit = _get_bandit()
    snapshot = {}
    snapshot["action_counts"] = bandit.get_action_counts()
    snapshot["total_updates"] = bandit.get_total_updates()
    snapshot["alpha"] = bandit.get_alpha()
    snapshot["action_distribution"] = bandit.get_action_distribution()
    return snapshot
def is_episode_active() -> bool:
    """Report whether an episode has been started and not yet cleared."""
    no_episode = _current_episode is None
    return not no_episode
def reset_rl() -> None:
    """Reset the entire RL system — bandit weights and experience store.

    An existing bandit is reset in place; if none has been created yet, there
    is nothing to reset. Any in-progress episode is discarded without being
    recorded.
    """
    global _bandit, _current_episode
    # Explicit None check: whether the bandit is reset must not depend on the
    # bandit object's own truthiness.
    if _bandit is not None:
        _bandit.reset()
    reset_experience()
    _current_episode = None