"""Tests for MoEGraph top-k parallel expert routing."""

import os

import torch

from arbitor.components import MoEGraph
from arbitor.config import HIDDEN_DIM

# Number of entries expected in MoEGraph.W_gate and used as the expert count
# for the standalone routing-math check below.
NUM_EXPERTS = 256
# Exclusive upper bound for the random VQ token ids passed to MoEGraph.
# NOTE(review): assumes MoEGraph accepts ids in [0, 1000) — confirm against
# the codebook size configured in arbitor.
VQ_ID_HIGH = 1000


def _random_inputs(batch, seq_len):
    """Return a random ``(x, vq)`` input pair for a MoEGraph forward call.

    Args:
        batch: Batch size B.
        seq_len: Sequence length T.

    Returns:
        ``x`` — float tensor of shape ``(B, T, HIDDEN_DIM)`` from ``randn``;
        ``vq`` — int64 tensor of shape ``(B, T)`` with ids in ``[0, VQ_ID_HIGH)``.
    """
    x = torch.randn(batch, seq_len, HIDDEN_DIM)
    vq = torch.randint(0, VQ_ID_HIGH, (batch, seq_len))
    return x, vq


def test_moegraph_topk_params():
    """A MoEGraph built with ``top_k=8`` exposes k, its gate table and codebook_up."""
    # Seeded so any failure during randomized module init is reproducible.
    torch.manual_seed(0)
    mg = MoEGraph(top_k=8)
    assert mg.top_k == 8
    assert len(mg.W_gate) == NUM_EXPERTS
    assert mg.codebook_up is not None


def test_moegraph_topk_forward_shape():
    """A top-8 forward pass keeps ``(B, T, HIDDEN_DIM)`` and returns a sane ponder.

    The output must be finite everywhere, and the ponder value must be finite
    and non-negative.
    """
    torch.manual_seed(0)
    mg = MoEGraph(top_k=8)
    B, T = 1, 4
    x, vq = _random_inputs(B, T)
    out, ponder = mg(x, vq)
    assert out.shape == (B, T, HIDDEN_DIM)
    assert torch.isfinite(out).all()
    # .all() keeps these checks valid whether ponder is a 0-dim tensor or a
    # multi-element one (a bare `assert tensor` raises for >1 element).
    assert torch.isfinite(ponder).all()
    assert (ponder >= 0).all()


def test_moegraph_top1_vs_topk_shape():
    """``top_k=1`` and ``top_k=8`` yield same-shaped outputs for identical inputs."""
    torch.manual_seed(0)
    mg1 = MoEGraph(top_k=1)
    mg8 = MoEGraph(top_k=8)
    B, T = 1, 4
    x, vq = _random_inputs(B, T)
    out1, _ = mg1(x, vq)
    out8, _ = mg8(x, vq)
    assert out1.shape == out8.shape == (B, T, HIDDEN_DIM)


def test_moegraph_topk_no_dead_code():
    """ARBModel with VQ and graph enabled builds a MoEGraph with the expected config.

    Checks ``model.moegraph``: ``top_k == 4``, ``NUM_EXPERTS`` experts and a
    codebook dimension (``cb_dim``) of 768.
    """
    # Imported inside the test so the other tests only depend on
    # arbitor.components / arbitor.config, not on arbitor.main.
    from arbitor.main import ARBModel

    model = ARBModel(
        enable_image=False,
        enable_audio=False,
        enable_vq=True,
        enable_graph=True,
    )
    mg = model.moegraph
    assert mg.top_k == 4
    assert mg.num_experts == NUM_EXPERTS
    assert mg.cb_dim == 768


def test_moegraph_topk_routing_logic():
    """Check the top-k + temperature-softmax routing math on random scores.

    NOTE(review): this exercises ``torch.topk`` / ``torch.softmax`` directly
    rather than calling MoEGraph, so it validates the routing recipe, not the
    module's implementation of it.
    """
    torch.manual_seed(42)
    B, T, E, k = 2, 4, NUM_EXPERTS, 8
    scores = torch.randn(B, T, E)
    scores_topk, idx = scores.topk(k=k, dim=-1)
    # Temperature 0.1 sharpens the distribution over the k selected experts.
    weights = torch.softmax(scores_topk / 0.1, dim=-1)
    assert weights.shape == (B, T, k)
    assert idx.shape == (B, T, k)
    assert torch.allclose(weights.sum(dim=-1), torch.ones(B, T))
    # Each token must pick k distinct experts. After sorting each row, a
    # duplicate would show up as two equal neighbouring indices.
    sorted_idx, _ = idx.sort(dim=-1)
    assert (sorted_idx[..., 1:] != sorted_idx[..., :-1]).all()