finalRLEnv / tests /test_memory.py
garvitsachdeva's picture
SpindleFlow RL — periodic push + log persistence
02ff91f
"""Tests for SpecialistMemory, ResolutionBandit, and SpawnMemory."""
import numpy as np
import pytest
from agents.specialist_memory import SpecialistMemory
from agents.resolution_memory import ResolutionBandit, ResolutionOutcome
from training.spawn_memory import SpawnMemory, SpawnRecord
# ── SpecialistMemory ──────────────────────────────────────────────────────────
def test_specialist_memory_record_and_retrieve(tmp_path):
    """Check that recorded examples are counted and come back best-reward first."""
    store_file = tmp_path / "mem.json"
    mem = SpecialistMemory(path=str(store_file))

    # (task, output, reward) triples recorded for the same specialist.
    entries = (
        ("build an API", "Here is the API design.", 0.8),
        ("write tests", "Here are the tests.", 0.5),
    )
    for task, output, reward in entries:
        mem.record("spec_a", task, output, reward=reward)

    assert mem.count("spec_a") == 2

    top = mem.get_top_examples("spec_a", n=2)
    # The higher reward is expected at index 0, the lower one at index 1.
    for position, expected in enumerate((0.8, 0.5)):
        assert top[position].reward == expected
def test_specialist_memory_eviction(tmp_path):
    """Check that going over the per-specialist cap drops the lowest rewards."""
    mem = SpecialistMemory(path=str(tmp_path / "mem.json"))
    # Lower the cap on this instance so seven records overflow it by two.
    mem.MAX_PER_SPECIALIST = 5

    for step in range(7):
        mem.record("spec_b", f"task {step}", f"output {step}", reward=float(step))

    # Only five entries may remain after the overflow.
    assert mem.count("spec_b") == 5

    # The survivors must be exactly the five highest rewards (2.0 through 6.0).
    kept = set()
    for entry in mem.get_top_examples("spec_b", n=5):
        kept.add(entry.reward)
    assert kept == {2.0, 3.0, 4.0, 5.0, 6.0}
def test_specialist_memory_top_examples_sorted(tmp_path):
    """Check that get_top_examples orders entries from highest to lowest reward."""
    mem = SpecialistMemory(path=str(tmp_path / "mem.json"))

    # Deliberately unsorted insertion order.
    unsorted_rewards = (0.3, 0.9, 0.1, 0.7)
    for value in unsorted_rewards:
        mem.record("spec_c", "task", "output", reward=value)

    ranked = mem.get_top_examples("spec_c", n=4)
    best = ranked[0]
    assert best.reward == 0.9
    worst = ranked[-1]
    assert worst.reward == 0.1
def test_specialist_memory_avg_reward(tmp_path):
    """Check that avg_reward is the mean of the recorded rewards."""
    mem = SpecialistMemory(path=str(tmp_path / "mem.json"))
    for value in (0.4, 0.6):
        mem.record("spec_d", "t", "o", reward=value)

    # Mean of 0.4 and 0.6 is 0.5; compare with a tolerance, not exact equality.
    deviation = abs(mem.avg_reward("spec_d") - 0.5)
    assert deviation < 1e-6
def test_specialist_memory_empty_specialist(tmp_path):
    """Check the zero/empty answers for a specialist that was never recorded."""
    mem = SpecialistMemory(path=str(tmp_path / "mem.json"))
    unknown = "nobody"

    assert mem.count(unknown) == 0
    assert mem.avg_reward(unknown) == 0.0
    assert mem.get_top_examples(unknown) == []
# ── ResolutionBandit ──────────────────────────────────────────────────────────
_TEMPLATES = {
"technical": {"standard": "Use {a}.", "defer_to_a": "Defer to {a}."},
"factual": {"recency": "Use recent claim from {a}."},
}
def test_resolution_bandit_returns_valid_key(tmp_path):
    """Check that select_template returns one of the keys defined for the type."""
    settings = {
        "resolution_bandit_epsilon": 0.0,
        "resolution_bandit_min_samples": 1,
    }
    log_file = tmp_path / "res.jsonl"
    bandit = ResolutionBandit(
        templates=_TEMPLATES,
        config=settings,
        memory_path=str(log_file),
    )

    chosen = bandit.select_template("technical")
    assert chosen in _TEMPLATES["technical"]
def test_resolution_bandit_exploits_best_arm(tmp_path):
    """Check that with epsilon 0.0 the arm with the higher recorded value wins."""
    settings = {
        "resolution_bandit_epsilon": 0.0,
        "resolution_bandit_min_samples": 2,
    }
    bandit = ResolutionBandit(
        templates=_TEMPLATES,
        config=settings,
        memory_path=str(tmp_path / "res.jsonl"),
    )

    # Three rounds, each recording a high value for defer_to_a and then a low
    # value for standard, so both arms exceed the configured min_samples of 2.
    seeded_arms = (("defer_to_a", 0.9), ("standard", 0.1))
    for _round in range(3):
        for template_key, value in seeded_arms:
            bandit.record_outcome(
                ResolutionOutcome("technical", template_key, value, 0)
            )

    assert bandit.select_template("technical") == "defer_to_a"
def test_resolution_bandit_random_when_insufficient_samples(tmp_path):
    """Check that a valid key is still returned when samples are below min_samples."""
    settings = {
        "resolution_bandit_epsilon": 0.0,
        "resolution_bandit_min_samples": 10,
    }
    bandit = ResolutionBandit(
        templates=_TEMPLATES,
        config=settings,
        memory_path=str(tmp_path / "res.jsonl"),
    )

    # Only 2 samples — below min_samples of 10, so should still return a valid key
    for value in (0.8, 0.7):
        bandit.record_outcome(ResolutionOutcome("technical", "standard", value, 0))

    chosen = bandit.select_template("technical")
    assert chosen in _TEMPLATES["technical"]
def test_resolution_bandit_arm_means(tmp_path):
    """Check that arm_means reports the average of the values recorded for an arm."""
    bandit = ResolutionBandit(
        templates=_TEMPLATES,
        config={},
        memory_path=str(tmp_path / "res.jsonl"),
    )
    for value in (0.4, 0.6):
        bandit.record_outcome(ResolutionOutcome("technical", "standard", value, 0))

    # Mean of 0.4 and 0.6 is 0.5; compare with a tolerance, not exact equality.
    standard_mean = bandit.arm_means()["technical"]["standard"]
    assert abs(standard_mean - 0.5) < 1e-6
def test_resolution_bandit_unknown_type_returns_default(tmp_path):
    """Check that a type absent from the templates yields the key "default"."""
    log_file = tmp_path / "res.jsonl"
    bandit = ResolutionBandit(
        templates=_TEMPLATES,
        config={},
        memory_path=str(log_file),
    )

    chosen = bandit.select_template("nonexistent_type")
    assert chosen == "default"
# ── SpawnMemory ───────────────────────────────────────────────────────────────
def _make_record(task_emb, reward=0.5, sid="spec_x"):
    """Build a SpawnRecord test fixture.

    Only the embedding, episode reward and specialist id vary between tests;
    every other field is a fixed placeholder.

    Args:
        task_emb: Array-like exposing ``.tolist()`` (the tests pass NumPy
            arrays); stored as a plain list.
        reward: Value stored as ``episode_reward``.
        sid: Value stored as ``specialist_id``.

    Returns:
        The constructed ``SpawnRecord``.
    """
    fields = {
        "task_embedding": task_emb.tolist(),
        "task_description": "test task",
        "specialist_id": sid,
        "specialist_role": "Test Role",
        "specialist_desc": "A test specialist.",
        "episode_reward": reward,
        "pre_spawn_sim": 0.3,
        "post_spawn_sim": 0.7,
        "episode_idx": 0,
    }
    return SpawnRecord(**fields)
def test_spawn_memory_record_and_size(tmp_path):
    """Check that recording one spawn makes the memory report a size of 1."""
    mem = SpawnMemory(path=str(tmp_path / "spawn.jsonl"))

    # Use a locally seeded generator instead of the global NumPy RNG so the
    # embedding is identical on every run and a failure can be reproduced.
    rng = np.random.default_rng(0)
    emb = rng.random(384, dtype=np.float32)

    mem.record(_make_record(emb))
    assert mem.size == 1
def test_spawn_memory_retrieve_similar_ordering(tmp_path):
    """Check that retrieve_similar ranks the closest embedding first."""
    mem = SpawnMemory(path=str(tmp_path / "spawn.jsonl"))
    base = np.ones(384, dtype=np.float32)

    # Use a locally seeded generator instead of the global NumPy RNG so the
    # perturbation is identical on every run and a failure can be reproduced.
    rng = np.random.default_rng(0)

    # Record two spawns: one that is base plus small noise, and one unit
    # vector along the first axis that shares little direction with base.
    similar_emb = base + rng.random(384, dtype=np.float32) * 0.01
    orthogonal_emb = np.zeros(384, dtype=np.float32)
    orthogonal_emb[0] = 1.0
    mem.record(_make_record(similar_emb, reward=0.5, sid="similar"))
    mem.record(_make_record(orthogonal_emb, reward=0.5, sid="orthogonal"))

    # Query with the unit-normalised base vector.
    results = mem.retrieve_similar(base / np.linalg.norm(base), top_k=2)
    assert results[0].specialist_id == "similar"
def test_spawn_memory_min_reward_filter(tmp_path):
    """Check that min_reward excludes records whose reward is below it."""
    mem = SpawnMemory(path=str(tmp_path / "spawn.jsonl"))
    emb = np.ones(384, dtype=np.float32)

    # Same embedding for both records, so only the reward tells them apart.
    for reward, sid in ((0.1, "low"), (0.8, "high")):
        mem.record(_make_record(emb, reward=reward, sid=sid))

    query = emb / np.linalg.norm(emb)
    results = mem.retrieve_similar(query, top_k=5, min_reward=0.5)

    ids = []
    for hit in results:
        ids.append(hit.specialist_id)
    assert "high" in ids
    assert "low" not in ids
def test_spawn_memory_eviction_keeps_highest_reward(tmp_path):
    """Check that exceeding max_entries keeps only the highest-reward records."""
    store_file = tmp_path / "spawn.jsonl"
    mem = SpawnMemory(path=str(store_file), max_entries=3)
    emb = np.ones(384, dtype=np.float32)

    # Four records against a cap of three: the 0.1 record should be dropped.
    for value in (0.1, 0.9, 0.5, 0.8):
        mem.record(_make_record(emb, reward=value))

    assert mem.size == 3

    # Reads the private record list directly to see which rewards survived.
    surviving = set()
    for rec in mem._records:
        surviving.add(rec.episode_reward)
    assert surviving == {0.9, 0.8, 0.5}