File size: 10,654 Bytes
ad55ab7
 
 
fc01d79
 
 
bb6d47c
fc01d79
 
ad55ab7
 
fc01d79
 
 
 
 
 
 
656b5db
 
fc01d79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656b5db
fc01d79
 
656b5db
fc01d79
656b5db
 
ad55ab7
 
 
fc01d79
 
 
 
 
656b5db
fc01d79
 
ad55ab7
 
fc01d79
 
 
 
 
ad55ab7
 
fc01d79
 
656b5db
 
 
 
 
fc01d79
 
 
 
 
 
 
 
 
 
656b5db
 
 
 
fc01d79
656b5db
fc01d79
ad55ab7
 
 
 
 
 
fc01d79
 
656b5db
fc01d79
 
ad55ab7
 
 
fc01d79
 
 
656b5db
fc01d79
ad55ab7
656b5db
ad55ab7
 
 
 
 
 
 
 
fc01d79
 
 
656b5db
fc01d79
 
 
 
 
 
 
 
 
bb6d47c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc01d79
 
 
 
 
 
ad55ab7
 
fc01d79
 
 
 
 
656b5db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc01d79
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
"""Coach memory storage and summarization for FORGE-v4."""

from __future__ import annotations

import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any

from config import MEMORY_FILE, MEMORY_MAX_LESSONS


class CoachMemory:
    """
    Persistent memory that accumulates lessons learned across training episodes.

    Lessons are stored as a list of dicts in a JSON file and loaded on startup.
    Each lesson includes a human-readable "coach_note" derived from the metrics
    so the history is understandable without post-processing.

    Persistence is best-effort: I/O failures while saving or loading are logged
    instead of raised so that a locked or unwritable file cannot abort training.
    """

    # Resolved per call site module; used only for best-effort I/O diagnostics.
    _log = logging.getLogger(__name__)

    def __init__(self, filepath: str = MEMORY_FILE):
        """
        Open (or create) the memory backed by ``filepath``.

        Args:
            filepath: Path of the JSON file holding the lessons. Its parent
                      directory is created if missing, then existing lessons
                      are loaded.
        """
        self.filepath = filepath
        self.lessons: list[dict[str, Any]] = []
        self._ensure_data_dir()
        self.load()

    # ──────────────────────────────────────────────
    # Public API
    # ──────────────────────────────────────────────

    def add_lesson(
        self,
        episode: int,
        agent: str,
        observation: str,
        coder_reward: float,
        breaker_reward: float,
        extra: dict[str, Any] | None = None,
    ) -> None:
        """
        Record a lesson from one episode step and persist the memory.

        Args:
            episode:        Episode index.
            agent:          "coder" | "breaker" | "env".
            observation:    Raw observation string from the environment.
            coder_reward:   Total coder reward for this step.
            breaker_reward: Total breaker reward for this step.
            extra:          Optional metadata (pass_rate, fail_count, etc.).
                            A shallow copy is stored, so later mutation of the
                            caller's dict does not alter the recorded lesson.

        Raises:
            TypeError / ValueError: if ``extra`` (or another field) is not
                JSON-serializable; raised by ``save()``.
        """
        metadata: dict[str, Any] = dict(extra) if extra else {}
        coach_note = self._derive_coach_note(episode, metadata)

        reward_delta = round(coder_reward - breaker_reward, 4)
        reward_weight = round(abs(coder_reward) + abs(breaker_reward), 4)

        lesson: dict[str, Any] = {
            # Naive UTC ISO string (same format datetime.utcnow() produced),
            # built without the deprecated utcnow() call.
            "timestamp":      datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
            "episode":        episode,
            "agent":          agent,
            "observation":    observation,
            "coach_note":     coach_note,
            "coder_reward":   coder_reward,
            "breaker_reward": breaker_reward,
            "reward_delta":   reward_delta,
            "reward_weight":  reward_weight,
        }
        if metadata:
            lesson["extra"] = metadata

        self.lessons.append(lesson)

        # Drop the oldest entries beyond the cap. Slicing from the front by the
        # overflow count (rather than [-cap:]) stays correct when the cap is 0.
        overflow = len(self.lessons) - MEMORY_MAX_LESSONS
        if overflow > 0:
            self.lessons = self.lessons[overflow:]
        self.save()

    def get_lessons(
        self,
        agent: str | None = None,
        last_n: int | None = None,
    ) -> list[dict[str, Any]]:
        """
        Retrieve stored lessons, optionally filtered by agent and/or limited to the last N.

        Args:
            agent:  When given, keep only lessons whose "agent" equals this value.
            last_n: When given, keep only the most recent ``last_n`` matches;
                    zero or a negative value yields an empty list.

        Returns:
            A new list (never the internal one) of the matching lesson dicts,
            oldest first.
        """
        selected = self.lessons
        if agent is not None:
            selected = [lesson for lesson in selected if lesson.get("agent") == agent]
        if last_n is not None:
            return self._tail(selected, last_n)
        return list(selected)

    def get_coach_notes(self, last_n: int = 5) -> list[str]:
        """
        Return the human-readable coach notes of the most recent lessons.

        The window is the last ``last_n`` lessons; lessons in that window with
        no (or an empty) note are skipped, so fewer than ``last_n`` notes may
        be returned. Zero or a negative ``last_n`` yields an empty list.
        """
        return [
            lesson["coach_note"]
            for lesson in self._tail(self.lessons, last_n)
            if lesson.get("coach_note")
        ]

    def summary(self) -> dict[str, Any]:
        """
        Return a high-level summary of stored lessons.

        The same keys are present whether or not any lessons exist:
        total_lessons, episodes_seen, avg_coder_reward, avg_breaker_reward,
        weighted_signal, recent_coach_notes and top_lessons (the three lessons
        with the largest reward_weight).
        """
        if not self.lessons:
            return {
                "total_lessons":      0,
                "episodes_seen":      0,
                "avg_coder_reward":   0.0,
                "avg_breaker_reward": 0.0,
                "weighted_signal":    0.0,
                "recent_coach_notes": [],
                "top_lessons":        [],
            }

        count = len(self.lessons)

        # Lessons loaded from disk may lack fields, so read them defensively.
        episodes = {lesson["episode"] for lesson in self.lessons if "episode" in lesson}
        coder_total = sum(lesson.get("coder_reward", 0.0) for lesson in self.lessons)
        breaker_total = sum(lesson.get("breaker_reward", 0.0) for lesson in self.lessons)

        # Each lesson's delta is scaled by its weight, but never by less than 1
        # so low-magnitude lessons still contribute their raw delta.
        weighted_signal = sum(
            lesson.get("reward_delta", 0.0) * max(1.0, lesson.get("reward_weight", 0.0))
            for lesson in self.lessons
        )
        top_lessons = sorted(
            self.lessons,
            key=lambda item: item.get("reward_weight", 0.0),
            reverse=True,
        )[:3]

        return {
            "total_lessons":      count,
            "episodes_seen":      len(episodes),
            "avg_coder_reward":   round(coder_total / count, 4),
            "avg_breaker_reward": round(breaker_total / count, 4),
            "weighted_signal":    round(weighted_signal, 4),
            "recent_coach_notes": self.get_coach_notes(last_n=3),
            "top_lessons": [
                {
                    "episode": lesson.get("episode"),
                    "coach_note": lesson.get("coach_note", ""),
                    "reward_weight": lesson.get("reward_weight", 0.0),
                }
                for lesson in top_lessons
            ],
        }

    def clear(self) -> None:
        """Wipe all stored lessons (use with caution) and persist the empty list."""
        self.lessons = []
        self.save()

    # ──────────────────────────────────────────────
    # Persistence helpers
    # ──────────────────────────────────────────────

    def save(self) -> None:
        """
        Persist lessons to the JSON file.

        The data is serialized before any file is touched, then written to a
        temporary sibling file and moved over the target with ``os.replace``,
        so a failure cannot leave a truncated memory file behind. If that
        route fails (e.g. the target is locked on Windows) a direct in-place
        write is attempted; if that fails too, a warning is logged and the
        method returns without raising.

        Raises:
            TypeError / ValueError: if a lesson is not JSON-serializable. The
                existing file is left untouched in that case.
        """
        payload = json.dumps(self.lessons, indent=2)

        tmp_path = f"{self.filepath}.{uuid.uuid4().hex[:6]}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(payload)
            os.replace(tmp_path, self.filepath)
            return
        except OSError as exc:
            replace_error = exc
            # Do not leave a stray temp file next to the memory file.
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # temp file was never created or is already gone

        try:
            with open(self.filepath, "w", encoding="utf-8") as f:
                f.write(payload)
        except OSError as exc:
            # Best-effort persistence: report, but never crash training.
            self._log.warning(
                "Could not persist coach memory to %s: %s (atomic write failed first: %s)",
                self.filepath,
                exc,
                replace_error,
            )

    def load(self) -> None:
        """
        Load lessons from the JSON file if it exists.

        A missing file yields an empty memory. An unreadable, undecodable or
        malformed file also yields an empty memory and is reported with a
        warning. Top-level content that is not a list is ignored, and list
        entries that are not dicts are dropped.
        """
        try:
            with open(self.filepath, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except FileNotFoundError:
            self.lessons = []
            return
        except (OSError, ValueError) as exc:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            self._log.warning("Could not load coach memory from %s: %s", self.filepath, exc)
            self.lessons = []
            return

        if not isinstance(loaded, list):
            self.lessons = []
            return
        self.lessons = [entry for entry in loaded if isinstance(entry, dict)]

    # ──────────────────────────────────────────────
    # Coach note derivation
    # ──────────────────────────────────────────────

    def _derive_coach_note(self, episode: int, extra: dict[str, Any]) -> str:
        """
        Generate a human-readable coaching note from step metadata.

        The first matching rule wins, in this order: timeouts; errors with a
        pass rate below 0.5; failures classified by the most recent breaker
        case (negatives, duplicates, large input); pass rate of at least 0.8;
        generic fallback.

        Args:
            episode: Episode index used as the note prefix.
            extra:   Step metadata. Keys read: coder_pass_rate, fail_count,
                     error_count, timeout_count, breaker_tier, coder_version,
                     recent_breaker_case.

        Examples:
            "Episode 8: Coder timed out on 2 test(s) [tier=3] → avoid O(n²) or infinite loops for large inputs"
            "Episode 4: Coder (v2) failed on duplicate values → ensure stable sort handles equal elements correctly"
            "Episode 2: Strong performance (pass_rate=1.00) [v2] → breaker should escalate tier"
        """
        pass_rate     = extra.get("coder_pass_rate")
        # "or 0" also covers keys that are present but set to None.
        fail_count    = extra.get("fail_count") or 0
        error_count   = extra.get("error_count") or 0
        timeout_count = extra.get("timeout_count") or 0
        breaker_tier  = extra.get("breaker_tier", 1)
        coder_version = extra.get("coder_version", "unknown")
        recent_case   = extra.get("recent_breaker_case") or []

        prefix = f"Episode {episode}"

        # Timeout note
        if timeout_count > 0:
            return (
                f"{prefix}: Coder timed out on {timeout_count} test(s)"
                f" [tier={breaker_tier}] → avoid O(n²) or infinite loops for large inputs"
            )

        # Error note
        if error_count > 0 and pass_rate is not None and pass_rate < 0.5:
            return (
                f"{prefix}: Coder raised errors on {error_count} test(s)"
                f" → add input validation and handle edge cases"
            )

        # Negative/duplicate failure detection from recent breaker case
        if fail_count > 0 and recent_case:
            # A case that is not a list/tuple is treated as a single item.
            if isinstance(recent_case, (list, tuple)):
                values = list(recent_case)
            else:
                values = [recent_case]

            # Only numeric items can be negative; anything else is skipped
            # instead of raising on the comparison.
            has_neg  = any(isinstance(x, (int, float)) and x < 0 for x in values)
            has_dups = self._has_duplicates(values)
            is_large = len(values) >= 10

            if has_neg and has_dups:
                return (
                    f"{prefix}: Coder ({coder_version}) failed on negatives+duplicates"
                    f" → ensure sort key uses true value, not abs()"
                )
            if has_neg:
                return (
                    f"{prefix}: Coder ({coder_version}) failed on negative values"
                    f" → handle negative integers in comparison logic"
                )
            if has_dups:
                return (
                    f"{prefix}: Coder ({coder_version}) failed on duplicate values"
                    f" → ensure stable sort handles equal elements correctly"
                )
            if is_large:
                return (
                    f"{prefix}: Coder ({coder_version}) failed on large array (n={len(values)})"
                    f" → consider O(n log n) algorithm"
                )
            return (
                f"{prefix}: Coder ({coder_version}) failed {fail_count} test(s)"
                f" at breaker {breaker_tier} → review edge case handling"
            )

        # Good performance
        if pass_rate is not None and pass_rate >= 0.8:
            return (
                f"{prefix}: Strong performance (pass_rate={pass_rate:.2f})"
                f" [{coder_version}] → breaker should escalate tier"
            )

        # Generic fallback
        pr = f"{pass_rate:.2f}" if pass_rate is not None else "N/A"
        return f"{prefix}: pass_rate={pr}, fail={fail_count}, errors={error_count}"

    # ──────────────────────────────────────────────
    # Internal helpers
    # ──────────────────────────────────────────────

    @staticmethod
    def _tail(items: list[Any], count: int) -> list[Any]:
        """Return a new list with the last ``count`` items; empty when count <= 0."""
        # items[-0:] would be the whole list, so non-positive counts are special-cased.
        return items[-count:] if count > 0 else []

    @staticmethod
    def _has_duplicates(values: list[Any]) -> bool:
        """Return True if any two items of ``values`` compare equal."""
        try:
            return len(set(values)) != len(values)
        except TypeError:
            # Unhashable items (e.g. nested lists): fall back to equality scans.
            return any(item in values[:idx] for idx, item in enumerate(values))

    def _ensure_data_dir(self) -> None:
        """Create the directory for the memory file if it doesn't exist."""
        directory = os.path.dirname(self.filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)