| import json |
| import tempfile |
| import pytest |
| import torch |
| import torch.nn as nn |
|
|
|
|
| |
| |
| |
|
|
def make_attn_cfg():
    """Return a fresh, small attention config dict shared by the attention tests."""
    cfg = {
        "dim": 64,
        "n_heads": 4,
        "n_kv_heads": 2,
        "head_dim": 16,
        "seq_len": 128,
        "rope_theta": 10000.0,
    }
    return cfg
|
|
|
|
def test_attention_layer_output_shape():
    """The attention layer must preserve the (batch, seq, dim) shape of its input."""
    from src.council.attention import AttentionLayer

    cfg = make_attn_cfg()
    x = torch.randn(2, 10, 64)
    out = AttentionLayer(cfg)(x)
    assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}"
|
|
|
|
def test_attention_layer_residual():
    """Output should differ from input (attention actually transforms)."""
    from src.council.attention import AttentionLayer

    attn = AttentionLayer(make_attn_cfg())
    inputs = torch.randn(1, 8, 64)
    transformed = attn(inputs)
    # If the layer were an identity (e.g. a broken residual path), these would match.
    unchanged = torch.allclose(transformed, inputs)
    assert not unchanged, "AttentionLayer output identical to input — residual not working"
|
|
|
|
| |
| |
| |
|
|
def test_expert_dispatch_weights_per_token():
    """Each masked token must receive a (n_masked, 1) weight, not a scalar."""
    from src.council.sentinel import Sentinel
    from src.council.base_expert import BaseExpert

    dim, n_experts, n_activated = 64, 4, 2
    sentinel = Sentinel(dim, n_experts, n_activated)
    expert = BaseExpert(dim, 128, "test", "test")

    tokens = torch.randn(8, dim)
    weights, indices = sentinel(tokens)

    for expert_id in range(n_experts):
        # Tokens routed to this expert in any of their activated slots.
        mask = (indices == expert_id).any(dim=-1)
        if not mask.any():
            continue
        # Collapse per-slot weights to a single per-token column vector.
        slot_hits = (indices[mask] == expert_id).float()
        expert_weights = (weights[mask] * slot_hits).sum(dim=-1, keepdim=True)
        expert_out = expert(tokens[mask])
        assert expert_weights.shape == (mask.sum(), 1), \
            f"Expert weights shape {expert_weights.shape} != ({mask.sum()}, 1)"
        assert expert_out.shape == (mask.sum(), dim), \
            f"Expert out shape {expert_out.shape} != ({mask.sum()}, {dim})"
        # The weighted combination must broadcast without raising.
        _ = expert_out * expert_weights
|
|
|
|
| |
| |
| |
|
|
def test_json_library_store_and_recall():
    """A stored memory is retrievable by keyword and keeps its 8-char id."""
    from src.memory.json_library import JSONLibrary

    with tempfile.TemporaryDirectory() as tmp:
        library = JSONLibrary(tmp)
        memory_id = library.store({"fact": "Paris is the capital of France"}, "important_facts")
        assert isinstance(memory_id, str)
        assert len(memory_id) == 8

        hits = library.recall("Paris")
        assert len(hits) == 1
        assert hits[0]["id"] == memory_id
|
|
|
|
def test_json_library_access_count_written_to_disk():
    """recall() must persist access metadata to disk, not just mutate in memory."""
    from src.memory.json_library import JSONLibrary

    with tempfile.TemporaryDirectory() as tmp:
        library = JSONLibrary(tmp)
        library.store({"fact": "test"}, "important_facts")
        library.recall("test")

        # Re-read the raw JSON file to verify the write-back happened.
        category_file = library._get_category_path("important_facts")
        with open(category_file) as fh:
            entries = json.load(fh)
        entry = entries[0]
        assert entry["access_count"] == 1, "access_count not written back to disk"
        assert entry["last_accessed"] is not None
|
|
|
|
def test_json_library_store_cap():
    """Storing past the per-category cap must not grow the file unbounded."""
    from src.memory.json_library import JSONLibrary

    with tempfile.TemporaryDirectory() as tmp:
        library = JSONLibrary(tmp)
        cap = JSONLibrary.MAX_ENTRIES_PER_CATEGORY
        # Deliberately overshoot the cap to trigger whatever eviction applies.
        for idx in range(cap + 10):
            library.store({"i": idx}, "important_facts")

        with open(library._get_category_path("important_facts")) as fh:
            data = json.load(fh)
        assert len(data) <= cap, f"Store exceeded cap: {len(data)} > {cap}"
|
|
|
|
| |
| |
| |
|
|
def test_grpo_reward_correct_answer():
    """A response containing the right final answer earns at least the base reward."""
    from src.training.grpo import GRPOTrainer

    # __new__ skips __init__ so compute_reward can be tested without model setup.
    trainer = GRPOTrainer.__new__(GRPOTrainer)
    response = "Step 1: 2+2=4\nThe answer is 42"
    score = trainer.compute_reward(response, "42")
    assert score >= 2.0, "Correct answer should yield reward >= 2.0"
|
|
|
|
def test_grpo_reward_wrong_answer():
    """An incorrect final answer must earn exactly zero reward."""
    from src.training.grpo import GRPOTrainer

    # __new__ skips __init__ so compute_reward can be tested without model setup.
    trainer = GRPOTrainer.__new__(GRPOTrainer)
    score = trainer.compute_reward("The answer is 99", "42")
    assert score == 0.0, "Wrong answer should yield 0 reward"
|
|
|
|
def test_grpo_reward_no_proxies():
    """Long verbose wrong answers should not earn reward."""
    from src.training.grpo import GRPOTrainer

    # __new__ skips __init__ so compute_reward can be tested without model setup.
    trainer = GRPOTrainer.__new__(GRPOTrainer)
    # 50 filler words followed by a wrong numeric answer.
    long_wrong = " ".join(["word"] * 50 + ["999"])
    score = trainer.compute_reward(long_wrong, "42")
    assert score == 0.0, "Proxy metrics (length/diversity) should not grant reward"
|
|
|
|
def test_grpo_reward_chain_of_thought_bonus():
    """A correct answer reached via visible working should beat the base reward."""
    from src.training.grpo import GRPOTrainer

    # __new__ skips __init__ so compute_reward can be tested without model setup.
    trainer = GRPOTrainer.__new__(GRPOTrainer)
    cot = "x = 6\ny = 7\nx * y = 42"
    score = trainer.compute_reward(cot, "42")
    assert score > 2.0, "Correct answer with chain-of-thought should exceed base reward"
|
|