# Team404_FORGE / memory.py
# sanjay7676's picture
# Final Submission Upgrade: Advanced tier progression, professional README, and hackathon blog
# bb6d47c
"""Coach memory storage and summarization for FORGE-v4."""
from __future__ import annotations

import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any

from config import MEMORY_FILE, MEMORY_MAX_LESSONS
class CoachMemory:
"""
Persistent memory that accumulates lessons learned across training episodes.
Lessons are stored as a list of dicts in a JSON file and loaded on startup.
Each lesson includes a human-readable "coach_note" derived from the metrics
so the history is understandable without post-processing.
"""
def __init__(self, filepath: str = MEMORY_FILE):
self.filepath = filepath
self.lessons: list[dict[str, Any]] = []
self._ensure_data_dir()
self.load()
# ──────────────────────────────────────────────
# Public API
# ──────────────────────────────────────────────
def add_lesson(
self,
episode: int,
agent: str,
observation: str,
coder_reward: float,
breaker_reward: float,
extra: dict[str, Any] | None = None,
) -> None:
"""
Record a lesson from one episode step.
Args:
episode: Episode index.
agent: "coder" | "breaker" | "env".
observation: Raw observation string from the environment.
coder_reward: Total coder reward for this step.
breaker_reward: Total breaker reward for this step.
extra: Optional metadata (pass_rate, fail_count, etc.).
"""
coach_note = self._derive_coach_note(episode, extra or {})
reward_delta = round(coder_reward - breaker_reward, 4)
reward_weight = round(abs(coder_reward) + abs(breaker_reward), 4)
lesson = {
"timestamp": datetime.utcnow().isoformat(),
"episode": episode,
"agent": agent,
"observation": observation,
"coach_note": coach_note,
"coder_reward": coder_reward,
"breaker_reward": breaker_reward,
"reward_delta": reward_delta,
"reward_weight": reward_weight,
}
if extra:
lesson["extra"] = extra
self.lessons.append(lesson)
if len(self.lessons) > MEMORY_MAX_LESSONS:
self.lessons = self.lessons[-MEMORY_MAX_LESSONS:]
self.save()
def get_lessons(
self,
agent: str | None = None,
last_n: int | None = None,
) -> list[dict[str, Any]]:
"""
Retrieve stored lessons, optionally filtered by agent and/or limited to the last N.
"""
result = self.lessons
if agent is not None:
result = [l for l in result if l.get("agent") == agent]
if last_n is not None:
result = result[-last_n:]
return result
def get_coach_notes(self, last_n: int = 5) -> list[str]:
"""Return the most recent human-readable coach notes."""
return [l["coach_note"] for l in self.lessons[-last_n:] if l.get("coach_note")]
def summary(self) -> dict[str, Any]:
"""Return a high-level summary of stored lessons."""
if not self.lessons:
return {
"total_lessons": 0,
"episodes_seen": 0,
"weighted_signal": 0.0,
"top_lessons": [],
}
episodes = {l["episode"] for l in self.lessons}
coder_rewards = [l["coder_reward"] for l in self.lessons]
breaker_rewards = [l["breaker_reward"] for l in self.lessons]
weighted_signal = sum((l.get("reward_delta", 0.0)) * max(1.0, l.get("reward_weight", 0.0)) for l in self.lessons)
top_lessons = sorted(self.lessons, key=lambda item: item.get("reward_weight", 0.0), reverse=True)[:3]
return {
"total_lessons": len(self.lessons),
"episodes_seen": len(episodes),
"avg_coder_reward": round(sum(coder_rewards) / len(coder_rewards), 4),
"avg_breaker_reward": round(sum(breaker_rewards) / len(breaker_rewards), 4),
"weighted_signal": round(weighted_signal, 4),
"recent_coach_notes": self.get_coach_notes(last_n=3),
"top_lessons": [
{
"episode": lesson.get("episode"),
"coach_note": lesson.get("coach_note", ""),
"reward_weight": lesson.get("reward_weight", 0.0),
}
for lesson in top_lessons
],
}
def clear(self) -> None:
"""Wipe all stored lessons (use with caution)."""
self.lessons = []
self.save()
# ──────────────────────────────────────────────
# Persistence helpers
# ──────────────────────────────────────────────
def save(self) -> None:
"""Persist lessons to JSON file."""
try:
with open(self.filepath, "w", encoding="utf-8") as f:
json.dump(self.lessons, f, indent=2)
except (IOError, PermissionError):
# Fallback for Windows file locks
tmp_path = f"{self.filepath}.{uuid.uuid4().hex[:6]}.tmp"
try:
with open(tmp_path, "w", encoding="utf-8") as f:
json.dump(self.lessons, f, indent=2)
import os
if os.path.exists(self.filepath):
os.remove(self.filepath)
os.rename(tmp_path, self.filepath)
except:
pass # Silently fail if even fallback fails to avoid crashing training
def load(self) -> None:
"""Load lessons from JSON file if it exists."""
if os.path.exists(self.filepath):
try:
with open(self.filepath, "r", encoding="utf-8") as f:
loaded = json.load(f)
self.lessons = loaded if isinstance(loaded, list) else []
except (json.JSONDecodeError, IOError):
self.lessons = []
else:
self.lessons = []
# ──────────────────────────────────────────────
# Coach note derivation
# ──────────────────────────────────────────────
def _derive_coach_note(self, episode: int, extra: dict[str, Any]) -> str:
"""
Generate a human-readable coaching note from step metadata.
Examples:
"Episode 4: Coder failed on duplicates β†’ handle duplicate values safely"
"Episode 8: Coder timed out on large arrays β†’ avoid O(nΒ²) for large inputs"
"Episode 2: Strong performance (pass_rate=1.00) β†’ keep current strategy"
"""
pass_rate = extra.get("coder_pass_rate", None)
fail_count = extra.get("fail_count", 0)
error_count = extra.get("error_count", 0)
timeout_count = extra.get("timeout_count", 0)
breaker_tier = extra.get("breaker_tier", 1)
coder_version = extra.get("coder_version", "unknown")
recent_case = extra.get("recent_breaker_case", [])
prefix = f"Episode {episode}"
# Timeout note
if timeout_count > 0:
return (
f"{prefix}: Coder timed out on {timeout_count} test(s)"
f" [tier={breaker_tier}] β†’ avoid O(nΒ²) or infinite loops for large inputs"
)
# Error note
if error_count > 0 and pass_rate is not None and pass_rate < 0.5:
return (
f"{prefix}: Coder raised errors on {error_count} test(s)"
f" β†’ add input validation and handle edge cases"
)
# Negative/duplicate failure detection from recent breaker case
if fail_count > 0 and recent_case:
has_neg = any(x < 0 for x in recent_case)
has_dups = len(recent_case) != len(set(recent_case))
is_large = len(recent_case) >= 10
if has_neg and has_dups:
return (
f"{prefix}: Coder ({coder_version}) failed on negatives+duplicates"
f" β†’ ensure sort key uses true value, not abs()"
)
if has_neg:
return (
f"{prefix}: Coder ({coder_version}) failed on negative values"
f" β†’ handle negative integers in comparison logic"
)
if has_dups:
return (
f"{prefix}: Coder ({coder_version}) failed on duplicate values"
f" β†’ ensure stable sort handles equal elements correctly"
)
if is_large:
return (
f"{prefix}: Coder ({coder_version}) failed on large array (n={len(recent_case)})"
f" β†’ consider O(n log n) algorithm"
)
return (
f"{prefix}: Coder ({coder_version}) failed {fail_count} test(s)"
f" at breaker {breaker_tier} β†’ review edge case handling"
)
# Good performance
if pass_rate is not None and pass_rate >= 0.8:
return (
f"{prefix}: Strong performance (pass_rate={pass_rate:.2f})"
f" [{coder_version}] β†’ breaker should escalate tier"
)
# Generic fallback
pr = f"{pass_rate:.2f}" if pass_rate is not None else "N/A"
return f"{prefix}: pass_rate={pr}, fail={fail_count}, errors={error_count}"
# ──────────────────────────────────────────────
# Internal helpers
# ──────────────────────────────────────────────
def _ensure_data_dir(self) -> None:
"""Create the directory for the memory file if it doesn't exist."""
directory = os.path.dirname(self.filepath)
if directory:
os.makedirs(directory, exist_ok=True)