File size: 7,167 Bytes
02ff91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""Tests for SpecialistMemory, ResolutionBandit, and SpawnMemory."""
import numpy as np
import pytest

from agents.specialist_memory import SpecialistMemory
from agents.resolution_memory import ResolutionBandit, ResolutionOutcome
from training.spawn_memory import SpawnMemory, SpawnRecord


# ── SpecialistMemory ──────────────────────────────────────────────────────────

def test_specialist_memory_record_and_retrieve(tmp_path):
    """Two recorded examples are counted and come back as rewards 0.8 then 0.5."""
    store = SpecialistMemory(path=str(tmp_path / "mem.json"))
    samples = (
        ("build an API", "Here is the API design.", 0.8),
        ("write tests", "Here are the tests.", 0.5),
    )
    for task, output, reward in samples:
        store.record("spec_a", task, output, reward=reward)

    assert store.count("spec_a") == 2

    ranked = store.get_top_examples("spec_a", n=2)
    first, second = ranked[0], ranked[1]
    assert first.reward == 0.8
    assert second.reward == 0.5


def test_specialist_memory_eviction(tmp_path):
    """Recording past the per-specialist cap leaves only the top-reward entries."""
    store = SpecialistMemory(path=str(tmp_path / "mem.json"))
    # Lower the cap on this instance so that seven records overflow it.
    store.MAX_PER_SPECIALIST = 5
    for idx in range(7):
        store.record("spec_b", f"task {idx}", f"output {idx}", reward=float(idx))

    # Two of the seven entries must have been dropped.
    assert store.count("spec_b") == 5

    # The survivors are expected to carry rewards 2 through 6.
    kept = set()
    for example in store.get_top_examples("spec_b", n=5):
        kept.add(example.reward)
    assert kept == {2.0, 3.0, 4.0, 5.0, 6.0}


def test_specialist_memory_top_examples_sorted(tmp_path):
    """get_top_examples returns all recorded entries in descending reward order.

    The whole returned sequence is compared rather than only its first and
    last elements, so a mis-ordered middle entry (or a missing one) fails
    the test as well.
    """
    mem = SpecialistMemory(path=str(tmp_path / "mem.json"))
    rewards = [0.3, 0.9, 0.1, 0.7]
    for reward in rewards:
        mem.record("spec_c", "task", "output", reward=reward)
    top = mem.get_top_examples("spec_c", n=4)
    # Expected order: 0.9, 0.7, 0.3, 0.1.
    assert [e.reward for e in top] == sorted(rewards, reverse=True)


def test_specialist_memory_avg_reward(tmp_path):
    """avg_reward over rewards 0.4 and 0.6 equals 0.5 within 1e-6."""
    store = SpecialistMemory(path=str(tmp_path / "mem.json"))
    for value in (0.4, 0.6):
        store.record("spec_d", "t", "o", reward=value)
    mean = store.avg_reward("spec_d")
    assert abs(mean - 0.5) < 1e-6


def test_specialist_memory_empty_specialist(tmp_path):
    """An id with no records yields count 0, average 0.0 and an empty list."""
    store = SpecialistMemory(path=str(tmp_path / "mem.json"))
    unknown = "nobody"
    assert store.count(unknown) == 0
    assert store.avg_reward(unknown) == 0.0
    assert store.get_top_examples(unknown) == []


# ── ResolutionBandit ──────────────────────────────────────────────────────────

# Template table handed to every ResolutionBandit in the tests below.
# Outer keys are the type names passed to select_template(); inner keys are
# the template keys the tests expect it to choose between.
_TEMPLATES = {
    "technical": {
        "standard": "Use {a}.",
        "defer_to_a": "Defer to {a}.",
    },
    "factual": {
        "recency": "Use recent claim from {a}.",
    },
}


def test_resolution_bandit_returns_valid_key(tmp_path):
    """select_template returns one of the keys configured for the given type."""
    settings = {
        "resolution_bandit_epsilon": 0.0,
        "resolution_bandit_min_samples": 1,
    }
    bandit = ResolutionBandit(
        templates=_TEMPLATES,
        config=settings,
        memory_path=str(tmp_path / "res.jsonl"),
    )
    chosen = bandit.select_template("technical")
    assert chosen in _TEMPLATES["technical"]


def test_resolution_bandit_exploits_best_arm(tmp_path):
    """With epsilon 0.0 and enough samples, the higher-scoring key is selected."""
    settings = {
        "resolution_bandit_epsilon": 0.0,
        "resolution_bandit_min_samples": 2,
    }
    bandit = ResolutionBandit(
        templates=_TEMPLATES,
        config=settings,
        memory_path=str(tmp_path / "res.jsonl"),
    )
    # Three rounds; each round feeds 0.9 to defer_to_a and then 0.1 to standard.
    per_round = (("defer_to_a", 0.9), ("standard", 0.1))
    for _round in range(3):
        for template_key, delta in per_round:
            bandit.record_outcome(
                ResolutionOutcome("technical", template_key, delta, 0)
            )
    assert bandit.select_template("technical") == "defer_to_a"


def test_resolution_bandit_random_when_insufficient_samples(tmp_path):
    """Below min_samples, select_template must still return a configured key."""
    settings = {
        "resolution_bandit_epsilon": 0.0,
        "resolution_bandit_min_samples": 10,
    }
    bandit = ResolutionBandit(
        templates=_TEMPLATES,
        config=settings,
        memory_path=str(tmp_path / "res.jsonl"),
    )
    # Only two outcomes are recorded - fewer than the configured min_samples
    # of 10 - so the test only requires that the returned key is a valid one.
    for delta in (0.8, 0.7):
        bandit.record_outcome(ResolutionOutcome("technical", "standard", delta, 0))
    picked = bandit.select_template("technical")
    assert picked in _TEMPLATES["technical"]


def test_resolution_bandit_arm_means(tmp_path):
    """arm_means reports 0.5 for an arm that recorded outcomes 0.4 and 0.6."""
    bandit = ResolutionBandit(
        templates=_TEMPLATES,
        config={},
        memory_path=str(tmp_path / "res.jsonl"),
    )
    for delta in (0.4, 0.6):
        bandit.record_outcome(ResolutionOutcome("technical", "standard", delta, 0))
    standard_mean = bandit.arm_means()["technical"]["standard"]
    assert abs(standard_mean - 0.5) < 1e-6


def test_resolution_bandit_unknown_type_returns_default(tmp_path):
    """A type name absent from the template table yields the key "default"."""
    store_file = tmp_path / "res.jsonl"
    bandit = ResolutionBandit(
        templates=_TEMPLATES,
        config={},
        memory_path=str(store_file),
    )
    fallback = bandit.select_template("nonexistent_type")
    assert fallback == "default"


# ── SpawnMemory ───────────────────────────────────────────────────────────────

def _make_record(task_emb, reward=0.5, sid="spec_x"):
    """Build a SpawnRecord with fixed filler fields for the SpawnMemory tests.

    Args:
        task_emb: Embedding array; stored as the result of its ``tolist()``.
        reward: Value stored as ``episode_reward``.
        sid: Value stored as ``specialist_id``.

    Returns:
        The constructed SpawnRecord.
    """
    # Only the embedding, reward and id vary between tests; the rest is constant.
    fields = {
        "task_embedding": task_emb.tolist(),
        "task_description": "test task",
        "specialist_id": sid,
        "specialist_role": "Test Role",
        "specialist_desc": "A test specialist.",
        "episode_reward": reward,
        "pre_spawn_sim": 0.3,
        "post_spawn_sim": 0.7,
        "episode_idx": 0,
    }
    return SpawnRecord(**fields)


def test_spawn_memory_record_and_size(tmp_path):
    """Recording a single SpawnRecord makes ``size`` report 1."""
    store = SpawnMemory(path=str(tmp_path / "spawn.jsonl"))
    random_vec = np.random.rand(384)
    store.record(_make_record(random_vec.astype(np.float32)))
    assert store.size == 1


def test_spawn_memory_retrieve_similar_ordering(tmp_path):
    """retrieve_similar ranks the record closest to the query first.

    Two records with equal reward are stored: one whose embedding is the
    all-ones base vector plus small noise, and one that is a one-hot vector.
    Querying with the normalised base vector must return the near-duplicate
    first.
    """
    mem = SpawnMemory(path=str(tmp_path / "spawn.jsonl"))
    # Seeded generator so the perturbation is identical on every run instead of
    # depending on the global NumPy random state.
    rng = np.random.default_rng(0)
    base = np.ones(384, dtype=np.float32)
    # Near-duplicate of base: noise is at most 0.01 per component.
    similar_emb = base + rng.random(384, dtype=np.float32) * 0.01
    # One-hot vector. It is not exactly orthogonal to the all-ones base (their
    # cosine is 1/sqrt(384), about 0.05), but it is far less aligned than
    # similar_emb. NOTE(review): presumably ranking is by cosine similarity -
    # confirm against SpawnMemory.retrieve_similar.
    orthogonal_emb = np.zeros(384, dtype=np.float32)
    orthogonal_emb[0] = 1.0
    mem.record(_make_record(similar_emb, reward=0.5, sid="similar"))
    mem.record(_make_record(orthogonal_emb, reward=0.5, sid="orthogonal"))
    results = mem.retrieve_similar(base / np.linalg.norm(base), top_k=2)
    assert results[0].specialist_id == "similar"


def test_spawn_memory_min_reward_filter(tmp_path):
    """With min_reward=0.5 the 0.8-reward record is returned and the 0.1 one is not."""
    store = SpawnMemory(path=str(tmp_path / "spawn.jsonl"))
    vec = np.ones(384, dtype=np.float32)
    # Same embedding for both records, so only the reward distinguishes them.
    for value, label in ((0.1, "low"), (0.8, "high")):
        store.record(_make_record(vec, reward=value, sid=label))
    query = vec / np.linalg.norm(vec)
    hits = store.retrieve_similar(query, top_k=5, min_reward=0.5)
    returned_ids = [hit.specialist_id for hit in hits]
    assert "high" in returned_ids
    assert "low" not in returned_ids


def test_spawn_memory_eviction_keeps_highest_reward(tmp_path):
    """With max_entries=3, four records leave the three highest rewards."""
    store = SpawnMemory(path=str(tmp_path / "spawn.jsonl"), max_entries=3)
    vec = np.ones(384, dtype=np.float32)
    for value in (0.1, 0.9, 0.5, 0.8):
        store.record(_make_record(vec, reward=value))
    assert store.size == 3
    # Inspect the private _records attribute to see which entries survived.
    surviving = set()
    for rec in store._records:
        surviving.add(rec.episode_reward)
    assert surviving == {0.9, 0.8, 0.5}