File size: 3,268 Bytes
5038ae1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
test_shapes.py — assert every module output shape is correct.

Run: pytest tests/test_shapes.py -v
"""
import torch
import pytest

from tmt.model.config import TMTConfig
from tmt.model.embedding import TokenEmbedding, TemporalPositionEncoder
from tmt.model.mesh import MeshBuilder
from tmt.model.attention import MeshAttention
from tmt.model.ffn import DualStreamFFN
from tmt.model.exit_gate import ExitGate
from tmt.model.memory import MemoryAnchorCross
from tmt.model.layers import TMTLayer


# Batch size and sequence length are kept small so the tests run fast.
B = 2
S = 32

# Tiny model configuration used by the fixtures and tests in this file.
CFG = TMTConfig(
    vocab_size=1000, d_model=64,
    n_heads=4, n_layers=2,
    max_seq_len=64, graph_k=4,
    ffn_stream_dim=32, memory_anchors=4,
)


@pytest.fixture
def token_ids():
    """Return a (B, S) tensor of random integer ids drawn from [0, CFG.vocab_size)."""
    shape = (B, S)
    return torch.randint(0, CFG.vocab_size, shape)


@pytest.fixture
def x():
    """Return a random normal tensor of shape (B, S, CFG.d_model)."""
    hidden = torch.randn(B, S, CFG.d_model)
    return hidden


@pytest.fixture
def graph(x):
    """Return whatever MeshBuilder produces for the `x` fixture.

    `x` is flattened to (B * S, CFG.d_model) before being handed to the
    builder together with the batch size and sequence length.
    """
    builder = MeshBuilder(CFG.graph_k)
    flattened = x.reshape(B * S, CFG.d_model)
    return builder(flattened, B, S)


def test_token_embedding(token_ids):
    """TokenEmbedding(CFG) applied to (B, S) ids must yield shape (B, S, d_model)."""
    expected_shape = (B, S, CFG.d_model)
    out = TokenEmbedding(CFG)(token_ids)
    assert out.shape == expected_shape, f"Expected ({B}, {S}, {CFG.d_model}), got {out.shape}"


def test_temporal_position_encoder(x):
    """Both TemporalPositionEncoder outputs must have shape (B, S, d_model).

    The second output (decay) is additionally required to stay within the
    closed interval [0, 1].
    """
    encoder = TemporalPositionEncoder(CFG)
    encoded, decay = encoder(x)
    expected_shape = (B, S, CFG.d_model)
    assert encoded.shape == expected_shape
    assert decay.shape == expected_shape
    # Bounds are inclusive: the smallest value may be 0 and the largest 1.
    assert decay.min() >= 0.0
    assert decay.max() <= 1.0


def test_mesh_builder(x):
    """Sanity-check the edge list MeshBuilder returns for the flattened input."""
    node_count = B * S
    builder = MeshBuilder(CFG.graph_k)
    edge_index, edge_weight = builder(x.reshape(node_count, CFG.d_model), B, S)
    # edge_index must be laid out as (2, E) with at least one edge; the exact
    # number of edges is not pinned by this test.
    assert edge_index.shape[0] == 2
    assert edge_index.shape[1] > 0
    # One weight per edge.
    assert edge_weight.shape[0] == edge_index.shape[1]
    # Every endpoint must be a valid flattened node id in [0, B * S).
    assert edge_index.max() < node_count
    assert edge_index.min() >= 0


def test_mesh_attention(x, graph):
    """MeshAttention(CFG) must return a tensor of shape (B, S, d_model)."""
    edges, weights = graph
    attention = MeshAttention(CFG)
    attended = attention(x, edges, weights)
    assert attended.shape == (B, S, CFG.d_model)


def test_dual_stream_ffn(x):
    """DualStreamFFN(CFG) must return a tensor of shape (B, S, d_model)."""
    result = DualStreamFFN(CFG)(x)
    assert result.shape == (B, S, CFG.d_model)


def test_exit_gate(x):
    """Check shapes, mask dtype and confidence bounds of ExitGate's outputs.

    The gate is called with an all-False (B, S) boolean exit mask.
    """
    gate = ExitGate(CFG)
    all_false_mask = torch.zeros(B, S, dtype=torch.bool)
    gated, updated_mask, confidence = gate(x, all_false_mask)
    per_token_shape = (B, S)
    assert gated.shape == (B, S, CFG.d_model)
    assert updated_mask.shape == per_token_shape
    assert updated_mask.dtype == torch.bool
    assert confidence.shape == per_token_shape
    # Confidence values must lie in the closed interval [0, 1].
    assert confidence.min() >= 0.0
    assert confidence.max() <= 1.0


def test_memory_anchor_cross(x):
    """Check the shapes of both MemoryAnchorCross outputs.

    The first output must be (B, S, d_model); the second must be
    (memory_anchors, d_model).
    """
    module = MemoryAnchorCross(CFG)
    attended, memory = module(x)
    assert attended.shape == (B, S, CFG.d_model)
    assert memory.shape == (CFG.memory_anchors, CFG.d_model)


def test_tmt_layer(x, graph):
    """Check the shapes of the four values returned by TMTLayer (layer_idx=0).

    The layer is called with the graph fixture's edges and an all-False
    (B, S) boolean exit mask.
    """
    edges, weights = graph
    tmt_layer = TMTLayer(CFG, layer_idx=0)
    all_false_mask = torch.zeros(B, S, dtype=torch.bool)
    outputs = tmt_layer(x, edges, weights, all_false_mask)
    hidden, updated_mask, confidence, memory = outputs
    per_token_shape = (B, S)
    assert hidden.shape == (B, S, CFG.d_model)
    assert updated_mask.shape == per_token_shape
    assert confidence.shape == per_token_shape
    assert memory.shape == (CFG.memory_anchors, CFG.d_model)