""" Experience store: logs episodes, persists to disk, and implements Hindsight Experience Replay (HER) for reward relabeling. HER (Andrychowicz et al., 2017): If a later attempt in the same episode succeeded, relabel earlier failed steps with partial credit proportional to their distance from the success step. This multiplies the effective training signal from sparse rewards. """ from __future__ import annotations import json import os import time import random from pathlib import Path from typing import Optional from rl.types import ( Episode, EpisodeStep, Experience, RLMetrics, RepairAction, REPAIR_ACTION_NAMES, ERROR_CLASS_NAMES, ) from rl.grader import compute_episode_reward _DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data")) EXPERIENCE_PATH = _DATA_DIR / "rl_experiences.json" MAX_EPISODES = 500 _episodes: list[Episode] = [] _loaded: bool = False def _ensure_loaded() -> None: global _loaded, _episodes if _loaded: return _loaded = True try: if EXPERIENCE_PATH.exists(): raw = json.loads(EXPERIENCE_PATH.read_text()) _episodes = [Episode(**ep) for ep in raw] except Exception: _episodes = [] def _persist() -> None: try: EXPERIENCE_PATH.parent.mkdir(parents=True, exist_ok=True) data = [ep.model_dump() for ep in _episodes[-MAX_EPISODES:]] EXPERIENCE_PATH.write_text(json.dumps(data, default=str)) except Exception: pass def record_episode( question: str, steps: list[EpisodeStep], success: bool, ) -> tuple[Episode, list[Experience]]: """ Record a completed episode, run HER relabeling, and persist. Returns (episode, relabeled_experiences). """ _ensure_loaded() step_rewards = [s.reward for s in steps] total_reward = compute_episode_reward(step_rewards, success) episode = Episode( id=f"ep-{int(time.time() * 1000)}-{random.randint(1000, 9999)}", question=question, steps=steps, total_reward=total_reward, success=success, timestamp=time.time(), ) _episodes.append(episode) if len(_episodes) > MAX_EPISODES: _episodes[:] = _episodes[-MAX_EPISODES:] _persist() relabeled = _apply_her(episode) return episode, relabeled def _apply_her(episode: Episode) -> list[Experience]: """ Hindsight Experience Replay. If the episode eventually succeeded at step T, relabel earlier failed steps with a hindsight bonus: bonus(t) = 0.3 * (1 - (T - t) / T) Steps closer to the eventual success receive more credit. """ experiences: list[Experience] = [] success_step_idx = next( (i for i, s in enumerate(episode.steps) if s.success), -1 ) for t, step in enumerate(episode.steps): reward = step.reward if success_step_idx > t: distance = success_step_idx - t total_steps = len(episode.steps) her_bonus = 0.3 * (1.0 - distance / total_steps) reward += her_bonus next_step = episode.steps[t + 1] if t < len(episode.steps) - 1 else None experiences.append( Experience( state=step.featurized, action=step.action, reward=reward, next_state=next_step.featurized if next_step else None, done=(t == len(episode.steps) - 1), timestamp=episode.timestamp, metadata={ "question": episode.question, "error_message": step.error_message, "sql": step.sql, "error_class": int(step.state.error_class), "attempt_number": step.state.attempt_number, }, ) ) return experiences def replay_all(bandit) -> int: """ Replay all stored experiences through the bandit to rebuild weights. Useful after a reset or if weights are lost. """ _ensure_loaded() count = 0 for ep in _episodes: relabeled = _apply_her(ep) for exp in relabeled: bandit.update(exp.state, exp.action, exp.reward) count += 1 return count def get_metrics() -> RLMetrics: _ensure_loaded() recent_window = 50 recent = _episodes[-recent_window:] all_steps = [s for ep in _episodes for s in ep.steps] action_dist: dict[str, int] = {} error_dist: dict[str, int] = {} for step in all_steps: a_name = REPAIR_ACTION_NAMES[step.action] action_dist[a_name] = action_dist.get(a_name, 0) + 1 e_name = ERROR_CLASS_NAMES[step.state.error_class] error_dist[e_name] = error_dist.get(e_name, 0) + 1 return RLMetrics( total_episodes=len(_episodes), total_steps=len(all_steps), cumulative_reward=sum(ep.total_reward for ep in _episodes), success_rate=( sum(1 for ep in recent if ep.success) / len(recent) if recent else 0.0 ), avg_attempts=( sum(len(ep.steps) for ep in recent) / len(recent) if recent else 0.0 ), action_distribution=action_dist, error_distribution=error_dist, reward_history=[ep.total_reward for ep in _episodes], ) def get_episodes() -> list[Episode]: _ensure_loaded() return list(_episodes) def get_recent_episodes(n: int) -> list[Episode]: _ensure_loaded() return _episodes[-n:] def reset_experience() -> None: global _episodes, _loaded _episodes = [] _loaded = True try: EXPERIENCE_PATH.unlink(missing_ok=True) except Exception: pass