"""Tests for AttentionLayer, expert dispatch weights, JSONLibrary and GRPO rewards."""

import json
import tempfile

import pytest
import torch
import torch.nn as nn


# ---------------------------------------------------------------------------
# AttentionLayer
# ---------------------------------------------------------------------------

def make_attn_cfg():
    """Return the small config dict used to build AttentionLayer in these tests."""
    return {
        "dim": 64,
        "n_heads": 4,
        "n_kv_heads": 2,
        "head_dim": 16,
        "seq_len": 128,
        "rope_theta": 10000.0,
    }


def test_attention_layer_output_shape():
    """The layer's output must have exactly the same shape as its input."""
    from src.council.attention import AttentionLayer

    layer = AttentionLayer(make_attn_cfg())
    x = torch.randn(2, 10, 64)  # (batch, seq, dim)
    out = layer(x)
    assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}"


def test_attention_layer_residual():
    """Output should differ from input (attention actually transforms)."""
    from src.council.attention import AttentionLayer

    layer = AttentionLayer(make_attn_cfg())
    x = torch.randn(1, 8, 64)
    out = layer(x)
    assert not torch.allclose(out, x), "AttentionLayer output identical to input — residual not working"


# ---------------------------------------------------------------------------
# Expert dispatch weights
# ---------------------------------------------------------------------------

def test_expert_dispatch_weights_per_token():
    """Each masked token must receive a (n_masked, 1) weight, not a scalar."""
    from src.council.sentinel import Sentinel
    from src.council.base_expert import BaseExpert

    dim, n_experts, n_activated = 64, 4, 2
    sentinel = Sentinel(dim, n_experts, n_activated)
    expert = BaseExpert(dim, 128, "test", "test")

    x = torch.randn(8, dim)  # 8 tokens flat
    weights, indices = sentinel(x)

    for i in range(n_experts):
        # Rows of `indices` that mention expert i in any slot.
        mask = (indices == i).any(dim=-1)
        if not mask.any():
            continue

        # Plain int so the shape comparisons and failure messages do not
        # depend on 0-d tensor truthiness / repr.
        n_masked = int(mask.sum())

        # Zero out the slots that belong to other experts, then collapse the
        # slot axis while keeping it as a trailing dimension of size 1.
        expert_weights = (weights[mask] * (indices[mask] == i).float()).sum(dim=-1, keepdim=True)
        expert_out = expert(x[mask])

        assert expert_weights.shape == (n_masked, 1), \
            f"Expert weights shape {expert_weights.shape} != ({n_masked}, 1)"
        assert expert_out.shape == (n_masked, dim), \
            f"Expert out shape {expert_out.shape} != ({n_masked}, {dim})"
        # Weighted contribution must be broadcastable
        _ = expert_out * expert_weights  # should not raise


# ---------------------------------------------------------------------------
# JSONLibrary
# ---------------------------------------------------------------------------

def test_json_library_store_and_recall():
    """store() returns an 8-character string id and recall() finds that entry."""
    from src.memory.json_library import JSONLibrary

    with tempfile.TemporaryDirectory() as tmp:
        lib = JSONLibrary(tmp)
        mid = lib.store({"fact": "Paris is the capital of France"}, "important_facts")
        assert isinstance(mid, str) and len(mid) == 8
        results = lib.recall("Paris")
        assert len(results) == 1
        assert results[0]["id"] == mid


def test_json_library_access_count_written_to_disk():
    """After one recall, the category file on disk records the access."""
    from src.memory.json_library import JSONLibrary

    with tempfile.TemporaryDirectory() as tmp:
        lib = JSONLibrary(tmp)
        lib.store({"fact": "test"}, "important_facts")
        lib.recall("test")

        # Read the file directly so the check bypasses any in-memory state.
        path = lib._get_category_path("important_facts")
        with open(path) as f:
            data = json.load(f)
        assert data[0]["access_count"] == 1, "access_count not written back to disk"
        assert data[0]["last_accessed"] is not None


def test_json_library_store_cap():
    """Storing more than MAX_ENTRIES_PER_CATEGORY entries must not exceed the cap on disk."""
    from src.memory.json_library import JSONLibrary

    with tempfile.TemporaryDirectory() as tmp:
        lib = JSONLibrary(tmp)
        cap = JSONLibrary.MAX_ENTRIES_PER_CATEGORY
        for i in range(cap + 10):
            lib.store({"i": i}, "important_facts")

        path = lib._get_category_path("important_facts")
        with open(path) as f:
            data = json.load(f)
        assert len(data) <= cap, f"Store exceeded cap: {len(data)} > {cap}"


# ---------------------------------------------------------------------------
# GRPO reward
# ---------------------------------------------------------------------------

def test_grpo_reward_correct_answer():
    """A response ending in the expected answer must score at least 2.0."""
    from src.training.grpo import GRPOTrainer

    trainer = GRPOTrainer.__new__(GRPOTrainer)  # skip __init__
    reward = trainer.compute_reward("Step 1: 2+2=4\nThe answer is 42", "42")
    assert reward >= 2.0, "Correct answer should yield reward >= 2.0"


def test_grpo_reward_wrong_answer():
    """A response stating a different number than expected must score exactly 0."""
    from src.training.grpo import GRPOTrainer

    trainer = GRPOTrainer.__new__(GRPOTrainer)  # skip __init__
    reward = trainer.compute_reward("The answer is 99", "42")
    assert reward == 0.0, "Wrong answer should yield 0 reward"


def test_grpo_reward_no_proxies():
    """Long verbose wrong answers should not earn reward."""
    from src.training.grpo import GRPOTrainer

    trainer = GRPOTrainer.__new__(GRPOTrainer)  # skip __init__
    long_wrong = " ".join(["word"] * 50) + " 999"
    reward = trainer.compute_reward(long_wrong, "42")
    assert reward == 0.0, "Proxy metrics (length/diversity) should not grant reward"


def test_grpo_reward_chain_of_thought_bonus():
    """A multi-line derivation ending in the expected answer must score above 2.0."""
    from src.training.grpo import GRPOTrainer

    trainer = GRPOTrainer.__new__(GRPOTrainer)  # skip __init__
    cot = "x = 6\ny = 7\nx * y = 42"
    reward = trainer.compute_reward(cot, "42")
    assert reward > 2.0, "Correct answer with chain-of-thought should exceed base reward"