File size: 5,751 Bytes
3c665d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""
Experience store: logs episodes, persists to disk, and implements
Hindsight Experience Replay (HER) for reward relabeling.

HER (Andrychowicz et al., 2017): If a later attempt in the same episode
succeeded, relabel earlier failed steps with partial credit that shrinks
as their distance from the success step grows. This multiplies the
effective training signal from sparse rewards.
"""

from __future__ import annotations

import json
import os
import time
import random
from pathlib import Path
from typing import Optional

from rl.types import (
    Episode,
    EpisodeStep,
    Experience,
    RLMetrics,
    RepairAction,
    REPAIR_ACTION_NAMES,
    ERROR_CLASS_NAMES,
)
from rl.grader import compute_episode_reward

# Directory for persisted data: the DATA_DIR environment variable when set,
# otherwise the "data" directory next to the directory containing this file.
_DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
# JSON file the episode list is written to and loaded from.
EXPERIENCE_PATH = _DATA_DIR / "rl_experiences.json"
# Upper bound on the number of episodes kept in memory and written to disk.
MAX_EPISODES = 500

# In-memory episode list (oldest first) and a flag recording whether the
# one-time load from EXPERIENCE_PATH has already been attempted.
_episodes: list[Episode] = []
_loaded: bool = False


def _ensure_loaded() -> None:
    """Load persisted episodes into memory the first time it is called.

    Later calls return immediately. If ``EXPERIENCE_PATH`` does not exist the
    in-memory list is left as it is. If the file cannot be read, parsed, or
    turned into ``Episode`` objects, the store falls back to an empty list and
    a warning is logged, so a corrupt store file is visible in the logs
    rather than being silently dropped.
    """
    import logging  # local import: the module does not import logging at top level

    global _loaded, _episodes
    if _loaded:
        return
    # Flag the attempt up front so a failing load is not retried on every call.
    _loaded = True
    try:
        if EXPERIENCE_PATH.exists():
            raw = json.loads(EXPERIENCE_PATH.read_text(encoding="utf-8"))
            _episodes = [Episode(**ep) for ep in raw]
    except Exception:
        # Best effort: keep the process running with an empty store, but
        # leave a trace because the next persist will overwrite the file.
        logging.getLogger(__name__).warning(
            "Could not load RL experiences from %s; starting with an empty store",
            EXPERIENCE_PATH,
            exc_info=True,
        )
        _episodes = []


def _persist() -> None:
    """Write the most recent ``MAX_EPISODES`` episodes to ``EXPERIENCE_PATH``.

    The JSON is written to a sibling temporary file first and then moved over
    the real path with ``os.replace``, so an interrupted write cannot leave a
    truncated store file behind. Persistence stays best effort: any failure
    is logged as a warning and never raised to the caller.
    """
    import logging  # local import: the module does not import logging at top level

    tmp_path = EXPERIENCE_PATH.with_name(EXPERIENCE_PATH.name + ".tmp")
    try:
        EXPERIENCE_PATH.parent.mkdir(parents=True, exist_ok=True)
        data = [ep.model_dump() for ep in _episodes[-MAX_EPISODES:]]
        # default=str stringifies any value json cannot serialise natively.
        tmp_path.write_text(json.dumps(data, default=str), encoding="utf-8")
        os.replace(tmp_path, EXPERIENCE_PATH)
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not persist RL experiences to %s",
            EXPERIENCE_PATH,
            exc_info=True,
        )
        # Do not leave a half-written temporary file lying around.
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass


def record_episode(
    question: str,
    steps: list[EpisodeStep],
    success: bool,
) -> tuple[Episode, list[Experience]]:
    """
    Store a finished episode and return it with its HER-relabeled experiences.

    The episode reward is computed from the per-step rewards and the success
    flag, the episode is appended to the in-memory store (trimmed to the
    newest MAX_EPISODES entries) and written to disk, and only then is
    hindsight relabeling applied. The relabeled experiences are returned to
    the caller; they are not stored.
    """
    _ensure_loaded()

    total_reward = compute_episode_reward([step.reward for step in steps], success)

    # Identifier: millisecond wall-clock time plus a random 4-digit suffix.
    millis = int(time.time() * 1000)
    suffix = random.randint(1000, 9999)
    episode = Episode(
        id=f"ep-{millis}-{suffix}",
        question=question,
        steps=steps,
        total_reward=total_reward,
        success=success,
        timestamp=time.time(),
    )

    _episodes.append(episode)
    # Drop the oldest entries in place once the cap is exceeded.
    overflow = len(_episodes) - MAX_EPISODES
    if overflow > 0:
        del _episodes[:overflow]
    _persist()

    return episode, _apply_her(episode)


def _apply_her(episode: Episode) -> list[Experience]:
    """
    Turn an episode into experiences, applying Hindsight Experience Replay.

    Let T be the index of the first successful step and N the number of
    steps in the episode. Every step t that comes before T has its reward
    increased by
      bonus(t) = 0.3 * (1 - (T - t) / N)
    so steps nearer to the eventual success receive more credit. Steps at or
    after T, and all steps of an episode without a successful step, keep
    their original reward.

    One Experience is produced per step; `next_state` is the following
    step's featurized state (None for the last step) and `done` is True only
    for the last step.
    """
    steps = episode.steps
    step_count = len(steps)
    final_idx = step_count - 1

    # Index of the first successful step, or -1 when no step succeeded.
    first_success = -1
    for idx, candidate in enumerate(steps):
        if candidate.success:
            first_success = idx
            break

    relabeled: list[Experience] = []
    for t, step in enumerate(steps):
        reward = step.reward
        if t < first_success:
            gap = first_success - t
            reward += 0.3 * (1.0 - gap / step_count)

        following = steps[t + 1] if t < final_idx else None
        next_state = following.featurized if following else None

        metadata = {
            "question": episode.question,
            "error_message": step.error_message,
            "sql": step.sql,
            "error_class": int(step.state.error_class),
            "attempt_number": step.state.attempt_number,
        }
        relabeled.append(
            Experience(
                state=step.featurized,
                action=step.action,
                reward=reward,
                next_state=next_state,
                done=(t == final_idx),
                timestamp=episode.timestamp,
                metadata=metadata,
            )
        )

    return relabeled


def replay_all(bandit) -> int:
    """
    Feed every stored episode back through `bandit` and return the step count.

    Each episode is relabeled with HER again and `bandit.update(state,
    action, reward)` is called once per resulting experience, oldest episode
    first. Intended for rebuilding weights after a reset or when they were
    lost.
    """
    _ensure_loaded()
    # Lazy stream so relabeling and updating interleave episode by episode.
    stream = (exp for episode in _episodes for exp in _apply_her(episode))
    replayed = 0
    for replayed, exp in enumerate(stream, start=1):
        bandit.update(exp.state, exp.action, exp.reward)
    return replayed


def get_metrics() -> RLMetrics:
    """
    Summarise the stored episodes as an RLMetrics object.

    Totals, cumulative reward, reward history and the action / error-class
    histograms cover every episode in memory. Success rate and average
    number of attempts are computed over the 50 most recent episodes only,
    and are 0.0 when no episode has been recorded.
    """
    _ensure_loaded()

    window = _episodes[-50:]
    window_size = len(window)

    action_counts: dict[str, int] = {}
    error_counts: dict[str, int] = {}
    step_total = 0
    for episode in _episodes:
        for step in episode.steps:
            step_total += 1

            action_label = REPAIR_ACTION_NAMES[step.action]
            if action_label in action_counts:
                action_counts[action_label] += 1
            else:
                action_counts[action_label] = 1

            error_label = ERROR_CLASS_NAMES[step.state.error_class]
            if error_label in error_counts:
                error_counts[error_label] += 1
            else:
                error_counts[error_label] = 1

    if window:
        success_rate = sum(1 for ep in window if ep.success) / window_size
        avg_attempts = sum(len(ep.steps) for ep in window) / window_size
    else:
        success_rate = 0.0
        avg_attempts = 0.0

    reward_history = [ep.total_reward for ep in _episodes]

    return RLMetrics(
        total_episodes=len(_episodes),
        total_steps=step_total,
        cumulative_reward=sum(reward_history),
        success_rate=success_rate,
        avg_attempts=avg_attempts,
        action_distribution=action_counts,
        error_distribution=error_counts,
        reward_history=reward_history,
    )


def get_episodes() -> list[Episode]:
    """Return a shallow copy of all stored episodes, oldest first."""
    _ensure_loaded()
    snapshot = _episodes.copy()
    return snapshot


def get_recent_episodes(n: int) -> list[Episode]:
    """Return up to the ``n`` most recently recorded episodes, oldest first.

    The result is a new list. Fewer than ``n`` episodes are returned when the
    store holds fewer, and an empty list is returned when ``n`` is zero or
    negative.
    """
    _ensure_loaded()
    if n <= 0:
        # Guard needed because `_episodes[-0:]` equals `_episodes[0:]` and
        # would hand back every episode instead of none.
        return []
    return _episodes[-n:]


def reset_experience() -> None:
    """Forget all episodes in memory and delete the persisted store file.

    The store is left marked as loaded, so the next access starts from the
    empty list instead of reading from disk. Removing the file is best
    effort: a missing file is fine and any other failure is ignored.
    """
    global _episodes, _loaded
    _episodes, _loaded = [], True
    try:
        EXPERIENCE_PATH.unlink(missing_ok=True)
    except Exception:
        # Deliberately ignored; the in-memory reset has already happened.
        pass