sql-agent-openenv / backend /rl /experience.py
ar9avg's picture
Initial submission: SQL Agent OpenEnv for Meta+HF hackathon
3c665d2
"""
Experience store: logs episodes, persists to disk, and implements
Hindsight Experience Replay (HER) for reward relabeling.
HER (Andrychowicz et al., 2017): If a later attempt in the same episode
succeeded, relabel the earlier failed steps with partial credit that
decreases with their distance from the success step. This multiplies the
effective training signal from sparse rewards.
"""
from __future__ import annotations

import json
import logging
import os
import random
import time
from pathlib import Path
from typing import Optional

from rl.types import (
    Episode,
    EpisodeStep,
    Experience,
    RLMetrics,
    RepairAction,
    REPAIR_ACTION_NAMES,
    ERROR_CLASS_NAMES,
)
from rl.grader import compute_episode_reward
# Directory for persisted RL data. The DATA_DIR environment variable takes
# precedence; otherwise it is the "data" folder next to this file's parent
# directory.
_DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
# JSON file the episode history is read from and written to.
EXPERIENCE_PATH = _DATA_DIR / "rl_experiences.json"
# Maximum number of episodes kept in memory and written to disk.
MAX_EPISODES = 500
# In-memory episode history; filled from disk by _ensure_loaded().
_episodes: list[Episode] = []
# True once a load from disk has been attempted or the store has been reset.
_loaded: bool = False
def _ensure_loaded() -> None:
    """
    Load the persisted episode history into memory, at most once.

    The first call marks the store as loaded and then reads
    ``EXPERIENCE_PATH`` if it exists, rebuilding one ``Episode`` per stored
    record. Every later call returns immediately.

    A missing file is not an error. A file that cannot be read, parsed, or
    turned into ``Episode`` objects is reported as a warning (with
    traceback) and the store starts empty; the function never raises.
    """
    global _loaded, _episodes
    if _loaded:
        return
    # Flip the flag before touching the disk so a broken file is reported
    # once instead of being re-read on every call.
    _loaded = True
    try:
        if EXPERIENCE_PATH.exists():
            raw = json.loads(EXPERIENCE_PATH.read_text(encoding="utf-8"))
            _episodes = [Episode(**ep) for ep in raw]
    except Exception:
        # Loading is best-effort, but a silent fallback would hide the fact
        # that the next persist overwrites the damaged history file.
        logging.getLogger(__name__).warning(
            "Could not load experiences from %s; starting with an empty store",
            EXPERIENCE_PATH,
            exc_info=True,
        )
        _episodes = []
def _persist() -> None:
    """
    Write the newest ``MAX_EPISODES`` episodes to ``EXPERIENCE_PATH`` as JSON.

    The parent directory is created when missing. Values that ``json``
    cannot encode natively are stringified via ``default=str``.

    The payload is first written to a sibling ``.tmp`` file and then renamed
    over the target, so an interrupted write cannot leave a truncated
    history file behind. Persistence is best-effort: any failure is logged
    as a warning (with traceback) and never raised to the caller.
    """
    tmp_path = EXPERIENCE_PATH.with_name(EXPERIENCE_PATH.name + ".tmp")
    try:
        EXPERIENCE_PATH.parent.mkdir(parents=True, exist_ok=True)
        data = [ep.model_dump() for ep in _episodes[-MAX_EPISODES:]]
        tmp_path.write_text(json.dumps(data, default=str), encoding="utf-8")
        # Rename within the same directory: the old file stays intact until
        # the new one is completely written.
        tmp_path.replace(EXPERIENCE_PATH)
    except Exception:
        # A leftover .tmp file is harmless; the next successful write
        # overwrites it.
        logging.getLogger(__name__).warning(
            "Could not persist experiences to %s",
            EXPERIENCE_PATH,
            exc_info=True,
        )
def record_episode(
    question: str,
    steps: list[EpisodeStep],
    success: bool,
) -> tuple[Episode, list[Experience]]:
    """
    Store a finished episode and return it with its HER-relabeled experiences.

    The episode reward is derived from the per-step rewards and the final
    outcome via ``compute_episode_reward``. The new episode is appended to
    the in-memory history (trimmed to the newest ``MAX_EPISODES``) and the
    history is persisted before relabeling.

    Returns:
        ``(episode, relabeled_experiences)``.
    """
    _ensure_loaded()

    per_step_rewards = [step.reward for step in steps]
    episode_reward = compute_episode_reward(per_step_rewards, success)

    # Millisecond timestamp plus a random 4-digit suffix forms the id.
    episode_id = f"ep-{int(time.time() * 1000)}-{random.randint(1000, 9999)}"
    episode = Episode(
        id=episode_id,
        question=question,
        steps=steps,
        total_reward=episode_reward,
        success=success,
        timestamp=time.time(),
    )

    _episodes.append(episode)
    if len(_episodes) > MAX_EPISODES:
        # Trim in place so the module-level list object stays the same.
        _episodes[:] = _episodes[-MAX_EPISODES:]
    _persist()

    return episode, _apply_her(episode)
def _apply_her(episode: Episode) -> list[Experience]:
    """
    Turn an episode into one Experience per step, with hindsight bonuses.

    Let ``s`` be the index of the first step flagged as successful and
    ``N`` the number of steps in the episode. Every step ``t`` before ``s``
    has its reward increased by

        bonus(t) = 0.3 * (1 - (s - t) / N)

    so steps closer to the success receive more credit. Steps at or after
    ``s``, and all steps of an episode without a successful step, keep
    their original reward.

    Each experience carries the featurized state of the following step as
    ``next_state`` (``None`` for the final step), and ``done`` is true only
    for the final step.
    """
    steps = episode.steps
    n_steps = len(steps)
    final_idx = n_steps - 1

    # Index of the first successful step, or -1 when there is none.
    first_success = -1
    for idx, candidate in enumerate(steps):
        if candidate.success:
            first_success = idx
            break

    relabeled: list[Experience] = []
    for t, step in enumerate(steps):
        shaped_reward = step.reward
        if t < first_success:
            gap = first_success - t
            shaped_reward += 0.3 * (1.0 - gap / n_steps)

        following = steps[t + 1] if t < final_idx else None
        following_state = following.featurized if following else None

        info = {
            "question": episode.question,
            "error_message": step.error_message,
            "sql": step.sql,
            "error_class": int(step.state.error_class),
            "attempt_number": step.state.attempt_number,
        }
        relabeled.append(
            Experience(
                state=step.featurized,
                action=step.action,
                reward=shaped_reward,
                next_state=following_state,
                done=(t == final_idx),
                timestamp=episode.timestamp,
                metadata=info,
            )
        )
    return relabeled
def replay_all(bandit) -> int:
    """
    Feed every stored experience back into ``bandit`` and count the updates.

    Each stored episode is relabeled with HER again and every resulting
    experience is passed to ``bandit.update(state, action, reward)``, in
    storage order. Intended for rebuilding weights after a reset or loss.

    Returns:
        The number of experiences replayed.
    """
    _ensure_loaded()
    replayed = 0
    for episode in _episodes:
        for experience in _apply_her(episode):
            bandit.update(experience.state, experience.action, experience.reward)
            replayed += 1
    return replayed
def get_metrics() -> RLMetrics:
    """
    Summarize the stored episode history as an ``RLMetrics`` record.

    Totals, cumulative reward, reward history and the action/error
    distributions cover every stored episode. ``success_rate`` and
    ``avg_attempts`` are computed over the 50 most recent episodes only and
    are ``0.0`` when no episodes exist.
    """
    _ensure_loaded()

    window = _episodes[-50:]
    window_size = len(window)

    # Tally how often each repair action and error class occurred.
    action_counts: dict[str, int] = {}
    error_counts: dict[str, int] = {}
    step_total = 0
    for episode in _episodes:
        for step in episode.steps:
            step_total += 1
            action_label = REPAIR_ACTION_NAMES[step.action]
            action_counts[action_label] = action_counts.get(action_label, 0) + 1
            error_label = ERROR_CLASS_NAMES[step.state.error_class]
            error_counts[error_label] = error_counts.get(error_label, 0) + 1

    if window_size:
        successes = len([ep for ep in window if ep.success])
        attempts = sum(len(ep.steps) for ep in window)
        success_rate = successes / window_size
        avg_attempts = attempts / window_size
    else:
        success_rate = 0.0
        avg_attempts = 0.0

    rewards = [ep.total_reward for ep in _episodes]

    return RLMetrics(
        total_episodes=len(_episodes),
        total_steps=step_total,
        cumulative_reward=sum(rewards),
        success_rate=success_rate,
        avg_attempts=avg_attempts,
        action_distribution=action_counts,
        error_distribution=error_counts,
        reward_history=rewards,
    )
def get_episodes() -> list[Episode]:
    """Return a shallow copy of the full stored episode history."""
    _ensure_loaded()
    snapshot = _episodes[:]
    return snapshot
def get_recent_episodes(n: int) -> list[Episode]:
    """
    Return the ``n`` most recently stored episodes, in stored order.

    Fewer than ``n`` episodes are returned when the history is shorter.
    A non-positive ``n`` yields an empty list. The result is a new list, so
    callers may modify it without affecting the store.
    """
    _ensure_loaded()
    # Guard needed because ``_episodes[-0:]`` is the entire list, and a
    # negative ``n`` would slice from the front instead of the back.
    if n <= 0:
        return []
    return _episodes[-n:]
def reset_experience() -> None:
    """
    Clear the in-memory history and delete the persisted history file.

    The store is marked as loaded afterwards, so later calls do not try to
    read the file again. Deleting is best-effort: a missing file is fine,
    and an OS-level failure is logged as a warning instead of being raised.
    """
    global _episodes, _loaded
    _episodes = []
    _loaded = True
    try:
        EXPERIENCE_PATH.unlink(missing_ok=True)
    except OSError:
        # Surface the failure: a file that survives the reset would be
        # loaded again by a fresh process.
        logging.getLogger(__name__).warning(
            "Could not delete experience file %s",
            EXPERIENCE_PATH,
            exc_info=True,
        )