File size: 5,375 Bytes
73400c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import json
import tempfile
import pytest
import torch
import torch.nn as nn


# ---------------------------------------------------------------------------
# AttentionLayer
# ---------------------------------------------------------------------------

def make_attn_cfg():
    """Build the minimal attention-layer config dict shared by the tests below."""
    cfg = dict(
        dim=64,
        n_heads=4,
        n_kv_heads=2,
        head_dim=16,
        seq_len=128,
        rope_theta=10000.0,
    )
    return cfg


def test_attention_layer_output_shape():
    """Attention must preserve the (batch, seq, dim) shape of its input."""
    from src.council.attention import AttentionLayer

    attn = AttentionLayer(make_attn_cfg())
    inputs = torch.randn(2, 10, 64)  # (batch, seq, dim)
    result = attn(inputs)
    assert result.shape == inputs.shape, f"Expected {inputs.shape}, got {result.shape}"


def test_attention_layer_residual():
    """Output should differ from input (attention actually transforms)."""
    from src.council.attention import AttentionLayer

    attn = AttentionLayer(make_attn_cfg())
    inp = torch.randn(1, 8, 64)
    result = attn(inp)
    unchanged = torch.allclose(result, inp)
    assert not unchanged, "AttentionLayer output identical to input — residual not working"


# ---------------------------------------------------------------------------
# Expert dispatch weights
# ---------------------------------------------------------------------------

def test_expert_dispatch_weights_per_token():
    """Each masked token must receive a (n_masked, 1) weight, not a scalar."""
    from src.council.sentinel import Sentinel
    from src.council.base_expert import BaseExpert

    dim = 64
    n_experts = 4
    n_activated = 2
    sentinel = Sentinel(dim, n_experts, n_activated)
    expert = BaseExpert(dim, 128, "test", "test")

    tokens = torch.randn(8, dim)  # 8 tokens flat
    weights, indices = sentinel(tokens)

    for expert_id in range(n_experts):
        # Tokens that routed to this expert in any of their activated slots.
        mask = (indices == expert_id).any(dim=-1)
        if not mask.any():
            continue
        # Collapse the per-slot weights down to one column per routed token.
        slot_hits = (indices[mask] == expert_id).float()
        expert_weights = (weights[mask] * slot_hits).sum(dim=-1, keepdim=True)
        expert_out = expert(tokens[mask])
        n_masked = mask.sum()
        assert expert_weights.shape == (n_masked, 1), \
            f"Expert weights shape {expert_weights.shape} != ({n_masked}, 1)"
        assert expert_out.shape == (n_masked, dim), \
            f"Expert out shape {expert_out.shape} != ({n_masked}, {dim})"
        # Weighted contribution must be broadcastable
        _ = expert_out * expert_weights  # should not raise


# ---------------------------------------------------------------------------
# JSONLibrary
# ---------------------------------------------------------------------------

def test_json_library_store_and_recall():
    """A stored entry must come back from recall under the same 8-char id."""
    from src.memory.json_library import JSONLibrary

    with tempfile.TemporaryDirectory() as workdir:
        library = JSONLibrary(workdir)
        memory_id = library.store({"fact": "Paris is the capital of France"}, "important_facts")
        assert isinstance(memory_id, str)
        assert len(memory_id) == 8

        hits = library.recall("Paris")
        assert len(hits) == 1
        assert hits[0]["id"] == memory_id


def test_json_library_access_count_written_to_disk():
    """Recall must persist the bumped access metadata, not just mutate memory."""
    from src.memory.json_library import JSONLibrary

    with tempfile.TemporaryDirectory() as workdir:
        library = JSONLibrary(workdir)
        library.store({"fact": "test"}, "important_facts")
        library.recall("test")

        category_file = library._get_category_path("important_facts")
        with open(category_file) as fh:
            on_disk = json.load(fh)
        assert on_disk[0]["access_count"] == 1, "access_count not written back to disk"
        assert on_disk[0]["last_accessed"] is not None


def test_json_library_store_cap():
    """Storing past the per-category cap must not grow the file unbounded."""
    from src.memory.json_library import JSONLibrary

    with tempfile.TemporaryDirectory() as workdir:
        library = JSONLibrary(workdir)
        limit = JSONLibrary.MAX_ENTRIES_PER_CATEGORY
        for n in range(limit + 10):
            library.store({"i": n}, "important_facts")

        category_file = library._get_category_path("important_facts")
        with open(category_file) as fh:
            on_disk = json.load(fh)
        assert len(on_disk) <= limit, f"Store exceeded cap: {len(on_disk)} > {limit}"


# ---------------------------------------------------------------------------
# GRPO reward
# ---------------------------------------------------------------------------

def test_grpo_reward_correct_answer():
    """A response containing the right final answer earns the base reward."""
    from src.training.grpo import GRPOTrainer

    trainer = GRPOTrainer.__new__(GRPOTrainer)  # skip __init__
    response = "Step 1: 2+2=4\nThe answer is 42"
    reward = trainer.compute_reward(response, "42")
    assert reward >= 2.0, "Correct answer should yield reward >= 2.0"


def test_grpo_reward_wrong_answer():
    """A wrong final answer earns nothing."""
    from src.training.grpo import GRPOTrainer

    trainer = GRPOTrainer.__new__(GRPOTrainer)
    assert trainer.compute_reward("The answer is 99", "42") == 0.0, \
        "Wrong answer should yield 0 reward"


def test_grpo_reward_no_proxies():
    """Long verbose wrong answers should not earn reward."""
    from src.training.grpo import GRPOTrainer

    trainer = GRPOTrainer.__new__(GRPOTrainer)
    verbose_wrong = " ".join(["word"] * 50) + " 999"
    assert trainer.compute_reward(verbose_wrong, "42") == 0.0, \
        "Proxy metrics (length/diversity) should not grant reward"


def test_grpo_reward_chain_of_thought_bonus():
    """Multi-step working plus the right answer should beat the base reward."""
    from src.training.grpo import GRPOTrainer

    trainer = GRPOTrainer.__new__(GRPOTrainer)
    working = "x = 6\ny = 7\nx * y = 42"
    assert trainer.compute_reward(working, "42") > 2.0, \
        "Correct answer with chain-of-thought should exceed base reward"