import json
import tempfile
import pytest
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# AttentionLayer
# ---------------------------------------------------------------------------
def make_attn_cfg():
    """Return a small attention-layer config dict shared by the attention tests."""
    return dict(
        dim=64,
        n_heads=4,
        n_kv_heads=2,
        head_dim=16,
        seq_len=128,
        rope_theta=10000.0,
    )
def test_attention_layer_output_shape():
    """The layer must preserve the (batch, seq, dim) shape of its input."""
    from src.council.attention import AttentionLayer

    attn = AttentionLayer(make_attn_cfg())
    inputs = torch.randn(2, 10, 64)  # (batch, seq, dim)
    result = attn(inputs)
    assert result.shape == inputs.shape, f"Expected {inputs.shape}, got {result.shape}"
def test_attention_layer_residual():
    """Output should differ from input (attention actually transforms)."""
    from src.council.attention import AttentionLayer

    attn = AttentionLayer(make_attn_cfg())
    sample = torch.randn(1, 8, 64)
    transformed = attn(sample)
    unchanged = torch.allclose(transformed, sample)
    assert not unchanged, "AttentionLayer output identical to input — residual not working"
# ---------------------------------------------------------------------------
# Expert dispatch weights
# ---------------------------------------------------------------------------
def test_expert_dispatch_weights_per_token():
    """Each masked token must receive a (n_masked, 1) weight, not a scalar."""
    from src.council.sentinel import Sentinel
    from src.council.base_expert import BaseExpert

    dim = 64
    n_experts = 4
    n_activated = 2
    sentinel = Sentinel(dim, n_experts, n_activated)
    expert = BaseExpert(dim, 128, "test", "test")
    tokens = torch.randn(8, dim)  # 8 tokens flat
    weights, indices = sentinel(tokens)
    for expert_id in range(n_experts):
        selected = (indices == expert_id).any(dim=-1)
        if not selected.any():
            continue
        n_selected = selected.sum()
        # Collapse the top-k axis: keep only the weight slot routed to this expert.
        gate = (indices[selected] == expert_id).float()
        expert_weights = (weights[selected] * gate).sum(dim=-1, keepdim=True)
        expert_out = expert(tokens[selected])
        assert expert_weights.shape == (n_selected, 1), \
            f"Expert weights shape {expert_weights.shape} != ({n_selected}, 1)"
        assert expert_out.shape == (n_selected, dim), \
            f"Expert out shape {expert_out.shape} != ({n_selected}, {dim})"
        # Weighted contribution must be broadcastable
        _ = expert_out * expert_weights  # should not raise
# ---------------------------------------------------------------------------
# JSONLibrary
# ---------------------------------------------------------------------------
def test_json_library_store_and_recall():
    """A stored entry gets an 8-char string id and is retrievable by keyword."""
    from src.memory.json_library import JSONLibrary

    with tempfile.TemporaryDirectory() as workdir:
        library = JSONLibrary(workdir)
        memory_id = library.store(
            {"fact": "Paris is the capital of France"}, "important_facts"
        )
        assert isinstance(memory_id, str)
        assert len(memory_id) == 8
        hits = library.recall("Paris")
        assert len(hits) == 1
        assert hits[0]["id"] == memory_id
def test_json_library_access_count_written_to_disk():
    """A recall must persist the updated access metadata, not just mutate memory."""
    from src.memory.json_library import JSONLibrary

    with tempfile.TemporaryDirectory() as workdir:
        library = JSONLibrary(workdir)
        library.store({"fact": "test"}, "important_facts")
        library.recall("test")
        category_file = library._get_category_path("important_facts")
        with open(category_file) as fh:
            on_disk = json.load(fh)
        entry = on_disk[0]
        assert entry["access_count"] == 1, "access_count not written back to disk"
        assert entry["last_accessed"] is not None
def test_json_library_store_cap():
    """Storing past MAX_ENTRIES_PER_CATEGORY must not grow the file beyond the cap."""
    from src.memory.json_library import JSONLibrary

    with tempfile.TemporaryDirectory() as workdir:
        library = JSONLibrary(workdir)
        limit = JSONLibrary.MAX_ENTRIES_PER_CATEGORY
        # Deliberately overshoot the cap.
        for n in range(limit + 10):
            library.store({"i": n}, "important_facts")
        category_file = library._get_category_path("important_facts")
        with open(category_file) as fh:
            stored = json.load(fh)
        assert len(stored) <= limit, f"Store exceeded cap: {len(stored)} > {limit}"
# ---------------------------------------------------------------------------
# GRPO reward
# ---------------------------------------------------------------------------
def test_grpo_reward_correct_answer():
    """A response containing the correct final answer earns at least the base reward."""
    from src.training.grpo import GRPOTrainer

    trainer = GRPOTrainer.__new__(GRPOTrainer)  # skip __init__
    response = "Step 1: 2+2=4\nThe answer is 42"
    score = trainer.compute_reward(response, "42")
    assert score >= 2.0, "Correct answer should yield reward >= 2.0"
def test_grpo_reward_wrong_answer():
    """An incorrect final answer must earn nothing at all."""
    from src.training.grpo import GRPOTrainer

    trainer = GRPOTrainer.__new__(GRPOTrainer)
    score = trainer.compute_reward("The answer is 99", "42")
    assert score == 0.0, "Wrong answer should yield 0 reward"
def test_grpo_reward_no_proxies():
    """Long verbose wrong answers should not earn reward."""
    from src.training.grpo import GRPOTrainer

    trainer = GRPOTrainer.__new__(GRPOTrainer)
    # 50 filler words followed by a wrong numeric answer.
    verbose_wrong = "{} 999".format(" ".join(["word"] * 50))
    score = trainer.compute_reward(verbose_wrong, "42")
    assert score == 0.0, "Proxy metrics (length/diversity) should not grant reward"
def test_grpo_reward_chain_of_thought_bonus():
    """Multi-step reasoning plus the right answer should beat the base reward."""
    from src.training.grpo import GRPOTrainer

    trainer = GRPOTrainer.__new__(GRPOTrainer)
    reasoning = "\n".join(["x = 6", "y = 7", "x * y = 42"])
    score = trainer.compute_reward(reasoning, "42")
    assert score > 2.0, "Correct answer with chain-of-thought should exceed base reward"
|