"""
SpawnMemory — tracks which specialist descriptions worked for which tasks.
Used to condition future spawn prompts on past successes.
This is retrieval-augmented generation for specialist design.
Path is configurable via environment.spawn_memory_path in training_config.yaml.
"""
from __future__ import annotations

import json
import logging
from dataclasses import dataclass, asdict
from pathlib import Path

import numpy as np


@dataclass
class SpawnRecord:
    """
    One remembered spawn: the task, the specialist that was spawned for it,
    and the outcome of the episode.

    Records are written to disk with ``json.dumps(asdict(record))``, so every
    field has to be JSON-serialisable. ``__post_init__`` converts NumPy arrays
    and NumPy scalars to plain Python values so that an embedding handed over
    straight from NumPy does not fail at write time. Values that are already
    plain Python objects (including ``None``) are left untouched.
    """

    task_embedding: list[float]  # 384-dim stored as list for JSON serialisation
    task_description: str
    specialist_id: str
    specialist_role: str
    specialist_desc: str
    episode_reward: float  # terminal reward of the episode that triggered the spawn
    pre_spawn_sim: float  # NOTE(review): presumably a similarity score before the spawn — confirm
    post_spawn_sim: float  # NOTE(review): presumably a similarity score after the spawn — confirm
    episode_idx: int

    def __post_init__(self) -> None:
        """Replace NumPy containers/scalars with their plain Python equivalents."""
        emb = self.task_embedding
        if isinstance(emb, np.ndarray):
            # Flatten so a (1, dim) array becomes a flat list of floats.
            self.task_embedding = emb.ravel().tolist()
        elif isinstance(emb, (list, tuple)):
            # e.g. list(ndarray) yields np.float32 items, which json cannot encode.
            self.task_embedding = [
                x.item() if isinstance(x, np.generic) else x for x in emb
            ]

        for name in ("episode_reward", "pre_spawn_sim", "post_spawn_sim", "episode_idx"):
            value = getattr(self, name)
            if isinstance(value, np.generic):
                setattr(self, name, value.item())
class SpawnMemory:
    """
    File-backed JSONL memory of past spawns with cosine-similarity retrieval.
    Capped at max_entries; lowest-reward records are evicted when full.

    The whole store is held in memory and rewritten to disk on every
    ``record`` call, one JSON object per line.
    """

    def __init__(self, path: str, max_entries: int = 500):
        """
        Open (or create) the memory stored at ``path``.

        Args:
            path: Location of the JSONL file. Missing parent directories are
                created.
            max_entries: Maximum number of records kept after each ``record``.
        """
        self._path = Path(path)
        self.max_entries = max_entries
        self._path.parent.mkdir(parents=True, exist_ok=True)
        self._records: list[SpawnRecord] = self._load()

    def _load(self) -> list[SpawnRecord]:
        """
        Read every record from the backing file.

        Loading is best-effort: blank lines are ignored, and lines that are
        not valid JSON or do not match the ``SpawnRecord`` fields are skipped.
        Skipped lines are counted and reported once via ``logging`` instead of
        vanishing silently.

        Returns:
            The records that could be parsed, in file order. Empty if the file
            does not exist.
        """
        if not self._path.exists():
            return []

        records: list[SpawnRecord] = []
        skipped = 0
        for line in self._path.read_text(encoding="utf-8").splitlines():
            if not line.strip():
                continue
            try:
                records.append(SpawnRecord(**json.loads(line)))
            except (ValueError, TypeError):
                # ValueError covers json.JSONDecodeError; TypeError covers a
                # payload that is not a mapping of the expected field names.
                skipped += 1

        if skipped:
            logging.getLogger(__name__).warning(
                "SpawnMemory: skipped %d unreadable line(s) in %s", skipped, self._path
            )
        return records

    def record(self, rec: SpawnRecord) -> None:
        """
        Add ``rec`` to the memory and persist the whole store.

        When the store would exceed ``max_entries``, records are ordered by
        ``episode_reward`` (highest first) and only the first ``max_entries``
        are kept, so the new record itself may be the one evicted.

        The new contents are serialised and written to a sibling temporary
        file, which then replaces the store. The in-memory list is updated
        only after that succeeds, so a record that cannot be serialised
        (``json.dumps`` raises ``TypeError``) leaves both the file and the
        in-memory state unchanged.
        """
        kept = [*self._records, rec]
        if len(kept) > self.max_entries:
            kept.sort(key=lambda r: r.episode_reward, reverse=True)
            # max(..., 0): a negative cap must not turn into a negative slice index.
            kept = kept[: max(self.max_entries, 0)]

        # Serialise everything before touching the file.
        payload = "".join(json.dumps(asdict(r)) + "\n" for r in kept)

        # Same directory as the store so the final rename stays on one filesystem.
        tmp_path = self._path.with_name(self._path.name + ".tmp")
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(self._path)

        self._records = kept

    def retrieve_similar(
        self,
        task_embedding: np.ndarray,
        top_k: int = 3,
        min_reward: float = 0.0,
    ) -> list[SpawnRecord]:
        """
        Return top_k past spawns whose task was most similar to the current
        task, filtered to those that produced >= min_reward.

        Args:
            task_embedding: Embedding of the current task. Any array-like is
                accepted and flattened, so a ``(1, dim)`` array or a plain
                list works as well as a 1-D array.
            top_k: Maximum number of records to return; values <= 0 yield an
                empty list.
            min_reward: Records with ``episode_reward`` below this are ignored.

        Returns:
            Matching records ordered by cosine similarity, most similar first.
            Ties keep their order in the store. Stored records whose embedding
            length differs from the query are skipped (and reported through
            ``logging``) because no similarity can be computed for them.
        """
        if top_k <= 0 or not self._records:
            return []

        query = np.asarray(task_embedding, dtype=float).ravel()
        unit_query = query / (np.linalg.norm(query) + 1e-8)

        matched: list[SpawnRecord] = []
        vectors = []
        mismatched = 0
        for rec in self._records:
            if not rec.episode_reward >= min_reward:
                continue
            vec = np.asarray(rec.task_embedding, dtype=float).ravel()
            if vec.size != query.size:
                mismatched += 1
                continue
            matched.append(rec)
            vectors.append(vec)

        if mismatched:
            logging.getLogger(__name__).warning(
                "SpawnMemory: ignored %d record(s) whose embedding size differs "
                "from the query size %d",
                mismatched,
                query.size,
            )
        if not matched:
            return []

        # One matrix product instead of a Python-level dot product per record.
        matrix = np.stack(vectors)
        sims = (matrix @ unit_query) / (np.linalg.norm(matrix, axis=1) + 1e-8)

        # Stable sort on the negated scores: descending similarity, ties in store order.
        order = np.argsort(-sims, kind="stable")[:top_k]
        return [matched[i] for i in order]

    @property
    def size(self) -> int:
        """Number of records currently held in memory."""
        return len(self._records)
|