File size: 2,818 Bytes
02ff91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Specialist Memory — records (task, output, reward) tuples per specialist.
Persisted to JSON so memory survives training restarts.
Used by SpecialistFinetuner to evolve specialist system prompts.
"""

from __future__ import annotations
import json
from dataclasses import dataclass, asdict
from pathlib import Path


@dataclass
class MemoryEntry:
    """One recorded specialist interaction.

    Fields are serialized verbatim with ``dataclasses.asdict`` in
    ``SpecialistMemory.save`` and rebuilt with ``MemoryEntry(**entry)`` in
    ``SpecialistMemory._load``, so renaming a field breaks existing files.
    """

    specialist_id: str  # id of the specialist that produced ``output``
    task: str  # task text (SpecialistMemory.record stores at most 500 chars)
    output: str  # specialist output (SpecialistMemory.record stores at most 800 chars)
    reward: float  # scalar reward; entries are ranked by this, higher first


class SpecialistMemory:
    """
    Per-specialist replay buffer of (task, output, reward) tuples.

    Each specialist keeps at most ``MAX_PER_SPECIALIST`` entries. When the cap
    is exceeded, the list is re-ordered by descending reward and the
    lowest-reward entries are dropped. Memory is persisted as JSON at ``path``
    by ``save()`` and reloaded on construction if that file exists.
    """

    MAX_PER_SPECIALIST = 50

    # Stored text is truncated to these lengths to keep the JSON file small.
    _TASK_CHARS = 500
    _OUTPUT_CHARS = 800

    def __init__(self, path: str | Path = "data/specialist_memory.json"):
        self._path = Path(path)
        self._entries: dict[str, list[MemoryEntry]] = {}
        if self._path.exists():
            self._load()

    def record(
        self,
        specialist_id: str,
        task: str,
        output: str,
        reward: float,
    ) -> None:
        """Store one (task, output, reward) tuple for ``specialist_id``.

        ``task`` and ``output`` are truncated before storage and ``reward`` is
        coerced to ``float``. The change is in memory only until ``save()``.
        """
        entries = self._entries.setdefault(specialist_id, [])
        entries.append(
            MemoryEntry(
                specialist_id,
                task[: self._TASK_CHARS],
                output[: self._OUTPUT_CHARS],
                float(reward),
            )
        )
        self._entries[specialist_id] = self._trim(entries)

    def _trim(self, entries: list[MemoryEntry]) -> list[MemoryEntry]:
        """Return ``entries`` capped at MAX_PER_SPECIALIST, keeping the highest rewards."""
        if len(entries) <= self.MAX_PER_SPECIALIST:
            return entries
        # sorted() is stable even with reverse=True, so among equal rewards
        # the earlier-recorded entries are the ones kept.
        ranked = sorted(entries, key=lambda e: e.reward, reverse=True)
        return ranked[: self.MAX_PER_SPECIALIST]

    def get_top_examples(self, specialist_id: str, n: int = 5) -> list[MemoryEntry]:
        """Return up to ``n`` entries with the highest reward, best first.

        Returns ``[]`` for an unknown specialist or when ``n <= 0``.
        """
        if n <= 0:
            return []
        entries = self._entries.get(specialist_id, [])
        return sorted(entries, key=lambda e: e.reward, reverse=True)[:n]

    def get_failure_examples(self, specialist_id: str, n: int = 3) -> list[MemoryEntry]:
        """Return up to ``n`` entries with the lowest reward, worst first.

        Returns ``[]`` for an unknown specialist or when ``n <= 0``. Because the
        cap evicts low-reward entries, these are the worst of the *retained* ones.
        """
        if n <= 0:
            return []
        entries = self._entries.get(specialist_id, [])
        return sorted(entries, key=lambda e: e.reward)[:n]

    def count(self, specialist_id: str) -> int:
        """Return how many entries are held for ``specialist_id`` (0 if unknown)."""
        return len(self._entries.get(specialist_id, []))

    def avg_reward(self, specialist_id: str) -> float:
        """Return the mean reward of retained entries, or 0.0 if there are none."""
        entries = self._entries.get(specialist_id, [])
        if not entries:
            return 0.0
        return sum(e.reward for e in entries) / len(entries)

    def all_specialist_ids(self) -> list[str]:
        """Return the ids of every specialist that has a memory list."""
        return list(self._entries.keys())

    def save(self) -> None:
        """Persist all entries to ``path`` as JSON.

        The data is first written to a sibling ``<name>.tmp`` file, which is then
        moved over the target with ``Path.replace`` (an atomic rename). A crash
        mid-write therefore cannot leave a truncated file. ``_load`` would reject
        such a file, and all memory would be lost on the next restart.
        """
        self._path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            sid: [asdict(e) for e in entries]
            for sid, entries in self._entries.items()
        }
        tmp_path = self._path.with_name(self._path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        tmp_path.replace(self._path)

    def _load(self) -> None:
        """Populate memory from ``path`` (best-effort, all-or-nothing).

        Entries are built into a temporary dict. They are installed only if the
        whole file parses, so a malformed file leaves memory untouched rather
        than partially loaded. Loaded lists are capped exactly as ``record``
        would cap them.
        """
        try:
            with open(self._path, encoding="utf-8") as f:
                data = json.load(f)
            loaded = {
                sid: self._trim([MemoryEntry(**e) for e in entries])
                for sid, entries in data.items()
            }
        except Exception as exc:  # deliberate: a bad file must never break __init__
            print(f"[SpecialistMemory] Could not load {self._path}: {exc}")
            return
        self._entries = loaded