Spaces:
Runtime error
Runtime error
| """ | |
| Specialist Memory — records (task, output, reward) tuples per specialist. | |
| Persisted to JSON so memory survives training restarts. | |
| Used by SpecialistFinetuner to evolve specialist system prompts. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass, asdict | |
| from pathlib import Path | |
| class MemoryEntry: | |
| specialist_id: str | |
| task: str | |
| output: str | |
| reward: float | |
| class SpecialistMemory: | |
| """ | |
| Per-specialist replay buffer of (task, output, reward) tuples. | |
| Capped at MAX_PER_SPECIALIST entries; excess low-reward entries are dropped. | |
| """ | |
| MAX_PER_SPECIALIST = 50 | |
| def __init__(self, path: str = "data/specialist_memory.json"): | |
| self._path = Path(path) | |
| self._entries: dict[str, list[MemoryEntry]] = {} | |
| if self._path.exists(): | |
| self._load() | |
| def record( | |
| self, | |
| specialist_id: str, | |
| task: str, | |
| output: str, | |
| reward: float, | |
| ) -> None: | |
| entries = self._entries.setdefault(specialist_id, []) | |
| entries.append(MemoryEntry(specialist_id, task[:500], output[:800], float(reward))) | |
| if len(entries) > self.MAX_PER_SPECIALIST: | |
| entries.sort(key=lambda e: e.reward, reverse=True) | |
| self._entries[specialist_id] = entries[: self.MAX_PER_SPECIALIST] | |
| def get_top_examples(self, specialist_id: str, n: int = 5) -> list[MemoryEntry]: | |
| entries = self._entries.get(specialist_id, []) | |
| return sorted(entries, key=lambda e: e.reward, reverse=True)[:n] | |
| def get_failure_examples(self, specialist_id: str, n: int = 3) -> list[MemoryEntry]: | |
| entries = self._entries.get(specialist_id, []) | |
| return sorted(entries, key=lambda e: e.reward)[:n] | |
| def count(self, specialist_id: str) -> int: | |
| return len(self._entries.get(specialist_id, [])) | |
| def avg_reward(self, specialist_id: str) -> float: | |
| entries = self._entries.get(specialist_id, []) | |
| if not entries: | |
| return 0.0 | |
| return sum(e.reward for e in entries) / len(entries) | |
| def all_specialist_ids(self) -> list[str]: | |
| return list(self._entries.keys()) | |
| def save(self) -> None: | |
| self._path.parent.mkdir(parents=True, exist_ok=True) | |
| data = { | |
| sid: [asdict(e) for e in entries] | |
| for sid, entries in self._entries.items() | |
| } | |
| with open(self._path, "w") as f: | |
| json.dump(data, f, indent=2) | |
| def _load(self) -> None: | |
| try: | |
| with open(self._path) as f: | |
| data = json.load(f) | |
| for sid, entries in data.items(): | |
| self._entries[sid] = [MemoryEntry(**e) for e in entries] | |
| except Exception as exc: | |
| print(f"[SpecialistMemory] Could not load {self._path}: {exc}") | |