# SHOREKEEPER — tests/test_core.py
# From commit 73400c8 ("Restructure to src/ layout with attention,
# per-layer MoE, and working chat"), author: geoore
import json
import tempfile
import pytest
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# AttentionLayer
# ---------------------------------------------------------------------------
def make_attn_cfg():
    """Small attention config shared by the AttentionLayer tests."""
    return dict(
        dim=64,
        n_heads=4,
        n_kv_heads=2,
        head_dim=16,
        seq_len=128,
        rope_theta=10000.0,
    )
def test_attention_layer_output_shape():
    """The attention layer must return a tensor shaped like its input."""
    from src.council.attention import AttentionLayer
    attn = AttentionLayer(make_attn_cfg())
    batch = torch.randn(2, 10, 64)  # (batch, seq, dim)
    result = attn(batch)
    assert result.shape == batch.shape, f"Expected {batch.shape}, got {result.shape}"
def test_attention_layer_residual():
    """Output should differ from input (attention actually transforms)."""
    from src.council.attention import AttentionLayer
    attn = AttentionLayer(make_attn_cfg())
    inp = torch.randn(1, 8, 64)
    result = attn(inp)
    unchanged = torch.allclose(result, inp)
    assert not unchanged, "AttentionLayer output identical to input — residual not working"
# ---------------------------------------------------------------------------
# Expert dispatch weights
# ---------------------------------------------------------------------------
def test_expert_dispatch_weights_per_token():
    """Each masked token must receive a (n_masked, 1) weight, not a scalar."""
    from src.council.sentinel import Sentinel
    from src.council.base_expert import BaseExpert
    dim, n_experts, n_activated = 64, 4, 2
    router = Sentinel(dim, n_experts, n_activated)
    expert = BaseExpert(dim, 128, "test", "test")
    tokens = torch.randn(8, dim)  # 8 tokens, already flattened
    weights, indices = router(tokens)
    for expert_id in range(n_experts):
        mask = (indices == expert_id).any(dim=-1)
        if not mask.any():
            # this expert was never selected by the router for any token
            continue
        selected = (indices[mask] == expert_id).float()
        expert_weights = (weights[mask] * selected).sum(dim=-1, keepdim=True)
        expert_out = expert(tokens[mask])
        assert expert_weights.shape == (mask.sum(), 1), \
            f"Expert weights shape {expert_weights.shape} != ({mask.sum()}, 1)"
        assert expert_out.shape == (mask.sum(), dim), \
            f"Expert out shape {expert_out.shape} != ({mask.sum()}, {dim})"
        # Weighted contribution must be broadcastable
        _ = expert_out * expert_weights  # should not raise
# ---------------------------------------------------------------------------
# JSONLibrary
# ---------------------------------------------------------------------------
def test_json_library_store_and_recall():
    """A stored memory gets an 8-char string id and is findable via recall()."""
    from src.memory.json_library import JSONLibrary
    with tempfile.TemporaryDirectory() as tmp:
        library = JSONLibrary(tmp)
        memory_id = library.store({"fact": "Paris is the capital of France"}, "important_facts")
        assert isinstance(memory_id, str)
        assert len(memory_id) == 8
        hits = library.recall("Paris")
        assert len(hits) == 1
        assert hits[0]["id"] == memory_id
def test_json_library_access_count_written_to_disk():
    """recall() must persist its access bookkeeping, not just mutate memory."""
    from src.memory.json_library import JSONLibrary
    with tempfile.TemporaryDirectory() as tmp:
        library = JSONLibrary(tmp)
        library.store({"fact": "test"}, "important_facts")
        library.recall("test")
        category_file = library._get_category_path("important_facts")
        with open(category_file) as fh:
            entries = json.load(fh)
        first = entries[0]
        assert first["access_count"] == 1, "access_count not written back to disk"
        assert first["last_accessed"] is not None
def test_json_library_store_cap():
    """Storing past MAX_ENTRIES_PER_CATEGORY must not grow the file unboundedly."""
    from src.memory.json_library import JSONLibrary
    with tempfile.TemporaryDirectory() as tmp:
        library = JSONLibrary(tmp)
        limit = JSONLibrary.MAX_ENTRIES_PER_CATEGORY
        for n in range(limit + 10):
            library.store({"i": n}, "important_facts")
        with open(library._get_category_path("important_facts")) as fh:
            entries = json.load(fh)
        assert len(entries) <= limit, f"Store exceeded cap: {len(entries)} > {limit}"
# ---------------------------------------------------------------------------
# GRPO reward
# ---------------------------------------------------------------------------
def test_grpo_reward_correct_answer():
    """A response ending in the right final answer earns at least the base reward."""
    from src.training.grpo import GRPOTrainer
    trainer = GRPOTrainer.__new__(GRPOTrainer)  # skip __init__
    score = trainer.compute_reward("Step 1: 2+2=4\nThe answer is 42", "42")
    assert score >= 2.0, "Correct answer should yield reward >= 2.0"
def test_grpo_reward_wrong_answer():
    """An incorrect final answer earns nothing."""
    from src.training.grpo import GRPOTrainer
    trainer = GRPOTrainer.__new__(GRPOTrainer)  # skip __init__
    score = trainer.compute_reward("The answer is 99", "42")
    assert score == 0.0, "Wrong answer should yield 0 reward"
def test_grpo_reward_no_proxies():
    """Long verbose wrong answers should not earn reward."""
    from src.training.grpo import GRPOTrainer
    trainer = GRPOTrainer.__new__(GRPOTrainer)  # skip __init__
    verbose_but_wrong = " ".join(["word"] * 50) + " 999"
    score = trainer.compute_reward(verbose_but_wrong, "42")
    assert score == 0.0, "Proxy metrics (length/diversity) should not grant reward"
def test_grpo_reward_chain_of_thought_bonus():
    """Showing intermediate work plus the right answer beats the base reward."""
    from src.training.grpo import GRPOTrainer
    trainer = GRPOTrainer.__new__(GRPOTrainer)  # skip __init__
    worked_solution = "x = 6\ny = 7\nx * y = 42"
    score = trainer.compute_reward(worked_solution, "42")
    assert score > 2.0, "Correct answer with chain-of-thought should exceed base reward"