# apishift-env / server/memory/memory_agent.py
# Initial APIShift env push (commit 3040bf7, by yaswanth169)
"""MemoryAgent — extracts a lesson from a finished episode and edits lessons.md.
Implementation is deterministic and dependency-free for the hackathon
submission. It mines:
- the action history (what dispatches were issued, in what order)
- the per-step rewards (which steps moved the needle)
- the final score and component breakdown
into a structured lesson keyed by the dominant change_type encountered
during the episode.
If a lesson with the same pattern already exists in lessons.md, the
MemoryAgent updates its confidence using a moving average and appends
the new episode_id; otherwise it creates a fresh entry.
"""
import os
from typing import Dict, List, Optional
from server.memory.lesson_schema import (
Lesson,
parse_lessons_md,
render_lessons_md,
)
class MemoryAgent:
    """Distil finished episodes into lessons persisted in a markdown file.

    Lessons are read from and written back to ``lessons_path`` via
    ``parse_lessons_md`` / ``render_lessons_md``. Episodes scoring below
    ``min_score_to_remember`` (or lacking a dominant change type) are ignored.
    """

    # Bounds applied to every confidence value this agent stores.
    _MIN_CONFIDENCE = 0.05
    _MAX_CONFIDENCE = 0.99
    # How many of the most recent episode ids are kept per lesson.
    _MAX_EPISODES = 20
    # Component rewards strictly below this value are reported as weak.
    _WEAK_COMPONENT_THRESHOLD = 0.4
    # Human-readable pattern label for each known change_type.
    _PATTERN_LABELS = {
        "field_renamed": "Field renames in response shape",
        "type_narrowed": "Type narrowing in response fields",
        "required_field_added": "Required field additions",
        "endpoint_removed": "Endpoint removals or moves",
        "endpoint_renamed": "Endpoint renames or moves",
        "enum_narrowed": "Enum value narrowing",
        "response_shape_changed": "Response shape nesting changes",
        "auth_scheme_changed": "Authentication scheme changes",
        "field_removed": "Field removals in response shape",
        "param_required_added": "Required query parameter additions",
        "default_changed": "Default value changes",
        "method_changed": "HTTP method changes on endpoints",
        "status_code_removed": "Status code removals",
    }

    def __init__(self, lessons_path: str = "lessons.md", min_score_to_remember: float = 0.55):
        """
        Args:
            lessons_path: Path of the markdown file holding the lessons.
            min_score_to_remember: Episodes whose score is strictly below this
                value are not turned into lessons.
        """
        self.lessons_path = lessons_path
        self.min_score = min_score_to_remember

    def update(
        self,
        episode_id: str,
        score: float,
        n_steps: int,
        action_log: List[Dict],
        dominant_change_type: Optional[str],
        component_rewards: Dict[str, float],
    ) -> Optional[Lesson]:
        """Fold one finished episode into lessons.md.

        If ``score`` is below ``min_score`` or ``dominant_change_type`` is
        empty, nothing is read or written. Otherwise the lesson whose pattern
        matches the change type is updated (or a new one is created), and
        the whole lessons file is rewritten.

        Args:
            episode_id: Identifier recorded in the lesson's episode list.
            score: Final episode score; also the input to the confidence.
            n_steps: Step count quoted in the "what works" summary.
            action_log: Per-step action dicts; their "command" values are
                summarised.
            dominant_change_type: Change type keying the lesson pattern.
            component_rewards: Component name -> reward; low ones are listed
                in the "what fails" summary.

        Returns:
            The created or updated Lesson, or None if the episode was skipped.
        """
        if score < self.min_score or not dominant_change_type:
            return None

        lessons = self._load()
        pattern_key = self._pattern_label(dominant_change_type)
        match = next((lesson for lesson in lessons if lesson.pattern == pattern_key), None)
        what_works = self._summarize_what_works(action_log, n_steps)
        what_fails = self._summarize_what_fails(component_rewards)

        if match is None:
            written = Lesson(
                lesson_id=self._next_lesson_id(lessons),
                pattern=pattern_key,
                first_seen=episode_id,
                what_works=what_works,
                what_fails=what_fails,
                confidence=self._clamp_confidence(score),
                episodes=[episode_id],
            )
            lessons.append(written)
        else:
            # Running mean weighted by the number of remembered episodes.
            # Because the episode list is capped, n never exceeds
            # _MAX_EPISODES + 1, so once the cap is reached each new score
            # keeps a fixed weight of 1/n instead of shrinking further.
            # NOTE(review): re-submitting an episode_id already in the list
            # counts its score again — confirm callers never do that.
            n = len(match.episodes) + 1
            match.confidence = self._clamp_confidence(
                (match.confidence * (n - 1) + score) / n
            )
            if episode_id not in match.episodes:
                match.episodes.append(episode_id)
            match.episodes = match.episodes[-self._MAX_EPISODES:]
            # Existing summaries win; fresh ones only fill empty slots.
            if not match.what_works:
                match.what_works = what_works
            if not match.what_fails:
                match.what_fails = what_fails
            written = match

        self._save(lessons)
        return written

    def _load(self) -> List[Lesson]:
        """Parse all lessons from lessons_path; a missing file yields []."""
        try:
            with open(self.lessons_path, "r", encoding="utf-8") as f:
                text = f.read()
        except FileNotFoundError:
            return []
        return parse_lessons_md(text)

    def _save(self, lessons: List[Lesson]) -> None:
        """Render ``lessons`` and atomically replace lessons_path with them.

        The text is written to a sibling temp file first and then moved over
        the target with ``os.replace``, so an interrupted write cannot leave
        a truncated lessons file behind.
        """
        text = render_lessons_md(lessons)
        directory = os.path.dirname(self.lessons_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{self.lessons_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(text)
            os.replace(tmp_path, self.lessons_path)
        except BaseException:
            # Clean up the partial temp file, then propagate the error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    @classmethod
    def _clamp_confidence(cls, value: float) -> float:
        """Clamp a confidence into [_MIN_CONFIDENCE, _MAX_CONFIDENCE]."""
        return min(cls._MAX_CONFIDENCE, max(cls._MIN_CONFIDENCE, value))

    @staticmethod
    def _next_lesson_id(lessons: List[Lesson]) -> str:
        """Return an id "L###" that does not collide with existing "L<digits>" ids.

        Uses one past the highest numeric suffix seen (never less than
        ``len(lessons) + 1``), so gaps left by hand-deleted lessons cannot
        cause a new lesson to reuse an existing id.
        """
        highest = 0
        for lesson in lessons:
            lesson_id = str(getattr(lesson, "lesson_id", "") or "")
            if lesson_id.startswith("L") and lesson_id[1:].isdecimal():
                highest = max(highest, int(lesson_id[1:]))
        return f"L{max(highest, len(lessons)) + 1:03d}"

    @classmethod
    def _pattern_label(cls, change_type: str) -> str:
        """Map a change_type to its lesson pattern label ("Other: <type>" if unknown)."""
        return cls._PATTERN_LABELS.get(change_type, f"Other: {change_type}")

    @staticmethod
    def _summarize_what_works(action_log: List[Dict], n_steps: int) -> str:
        """Chain the truthy "command" values of the log with " -> ".

        Returns "" when no action carries a command.
        """
        commands = [str(a["command"]) for a in (action_log or []) if a.get("command")]
        if not commands:
            return ""
        return f"In {n_steps} steps: " + " -> ".join(commands)

    @staticmethod
    def _summarize_what_fails(component_rewards: Dict[str, float]) -> str:
        """List components whose reward is below the weak threshold."""
        weak = [
            name
            for name, value in (component_rewards or {}).items()
            if value < MemoryAgent._WEAK_COMPONENT_THRESHOLD
        ]
        if not weak:
            return "No weak components observed."
        return "Weak components last episode: " + ", ".join(weak)