# ARBS/tests/test_moegraph_topk.py
# CLIWorks's picture
# Upload folder using huggingface_hub
# d8bc908 verified
"""Tests for MoEGraph top-k parallel expert routing."""
import os, torch
from arbitor.components import MoEGraph
from arbitor.config import HIDDEN_DIM
def test_moegraph_topk_params():
    """Verify the attributes of a MoEGraph built with ``top_k=8``.

    The requested ``top_k`` must be stored on the module, ``W_gate`` must
    hold 256 entries, and ``codebook_up`` must be populated.
    """
    graph = MoEGraph(top_k=8)

    assert graph.top_k == 8

    gate_count = len(graph.W_gate)
    assert gate_count == 256

    up_projection = graph.codebook_up
    assert up_projection is not None
def test_moegraph_topk_forward_shape():
    """Check that a top-8 MoEGraph forward pass yields a well-formed result.

    Feeds a random ``(B, T, HIDDEN_DIM)`` input together with random integer
    ids and asserts that the first return value keeps the input shape and is
    finite everywhere, and that the second return value is finite and
    non-negative.
    """
    # Seed the global torch RNG so the random inputs (and any torch-based
    # random initialisation inside MoEGraph) are reproducible when this test
    # fails; the routing-logic test in this file uses the same seed.
    torch.manual_seed(42)

    mg = MoEGraph(top_k=8)
    B, T = 1, 4
    x = torch.randn(B, T, HIDDEN_DIM)
    # Random integer ids in [0, 1000); presumably VQ code indices -- confirm
    # the valid range against MoEGraph.
    vq = torch.randint(0, 1000, (B, T))

    out, ponder = mg(x, vq)

    assert out.shape == (B, T, HIDDEN_DIM)
    assert torch.isfinite(out).all()
    assert torch.isfinite(ponder)
    assert ponder >= 0
def test_moegraph_top1_vs_topk_shape():
    """Check that ``top_k=1`` and ``top_k=8`` produce same-shaped outputs.

    Both modules receive the same random input and the same random integer
    ids; only the shapes of their first return values are compared.
    """
    # Seed the global torch RNG so the random inputs (and any torch-based
    # random initialisation inside MoEGraph) are reproducible when this test
    # fails; the routing-logic test in this file uses the same seed.
    torch.manual_seed(42)

    mg1 = MoEGraph(top_k=1)
    mg8 = MoEGraph(top_k=8)
    B, T = 1, 4
    x = torch.randn(B, T, HIDDEN_DIM)
    # Random integer ids in [0, 1000); presumably VQ code indices -- confirm
    # the valid range against MoEGraph.
    vq = torch.randint(0, 1000, (B, T))

    # The second return value is not under test here.
    out1, _ = mg1(x, vq)
    out8, _ = mg8(x, vq)

    assert out1.shape == out8.shape
def test_moegraph_topk_no_dead_code():
    """Verify the MoEGraph that ARBModel builds with VQ and graph enabled.

    The model is constructed with image and audio disabled; its ``moegraph``
    attribute must report ``top_k == 4``, ``num_experts == 256`` and
    ``cb_dim == 768``.
    """
    # Imported inside the test, as in the original, so ARBModel is only
    # loaded when this test runs.
    from arbitor.main import ARBModel

    arb = ARBModel(
        enable_image=False,
        enable_audio=False,
        enable_vq=True,
        enable_graph=True,
    )
    graph = arb.moegraph

    # Checked in the same order as before, one attribute at a time.
    expected_attrs = (("top_k", 4), ("num_experts", 256), ("cb_dim", 768))
    for attr_name, wanted in expected_attrs:
        assert getattr(graph, attr_name) == wanted
def test_moegraph_topk_routing_logic():
    """Exercise the top-k + softmax routing arithmetic on plain tensors.

    Draws seeded random scores of shape ``(2, 4, 256)``, keeps the 8 largest
    per position, and turns them into weights with a softmax at temperature
    0.1. The weights and indices must both have shape ``(2, 4, 8)``, each
    weight vector must sum to one, and the 8 selected indices at every
    position must be distinct.
    """
    torch.manual_seed(42)
    batch, seq_len, num_experts, top_k = 2, 4, 256, 8
    temperature = 0.1

    logits = torch.randn(batch, seq_len, num_experts)
    top_scores, top_idx = torch.topk(logits, top_k, dim=-1)
    weights = torch.softmax(top_scores / temperature, dim=-1)

    routed_shape = (batch, seq_len, top_k)
    assert weights.shape == routed_shape
    assert top_idx.shape == routed_shape

    weight_totals = weights.sum(dim=-1)
    assert torch.allclose(weight_totals, torch.ones(batch, seq_len))

    # Flatten the (batch, seq_len) grid; rows are visited in the same
    # batch-major order as a nested loop over batch then position.
    for chosen in top_idx.reshape(-1, top_k):
        assert chosen.unique().numel() == top_k