| """Tests for MoEGraph top-k parallel expert routing.""" |
| import os, torch |
| from arbitor.components import MoEGraph |
| from arbitor.config import HIDDEN_DIM |
|
|
|
|
def test_moegraph_topk_params():
    """Check the attributes of a MoEGraph constructed with ``top_k=8``.

    Asserts that the requested ``top_k`` is stored on the instance, that
    ``W_gate`` has length 256, and that ``codebook_up`` is not ``None``.
    """
    requested_k = 8
    graph = MoEGraph(top_k=requested_k)

    assert graph.top_k == requested_k
    assert len(graph.W_gate) == 256
    assert graph.codebook_up is not None
|
|
|
|
def test_moegraph_topk_forward_shape():
    """Run one forward pass with ``top_k=8`` and sanity-check both results.

    The first result must have the same ``(batch, seq, HIDDEN_DIM)`` shape as
    the input and contain only finite values; the second result (``ponder``)
    must be finite and non-negative.
    """
    graph = MoEGraph(top_k=8)

    batch, seq_len = 1, 4
    expected_shape = (batch, seq_len, HIDDEN_DIM)
    hidden = torch.randn(*expected_shape)
    # Integer ids drawn from [0, 1000), one per (batch, position).
    codes = torch.randint(low=0, high=1000, size=(batch, seq_len))

    routed, ponder = graph(hidden, codes)

    assert routed.shape == expected_shape
    assert torch.isfinite(routed).all()
    assert torch.isfinite(ponder)
    assert ponder >= 0
|
|
|
|
def test_moegraph_top1_vs_topk_shape():
    """Check that ``top_k=1`` and ``top_k=8`` graphs agree on output shape.

    Both graphs receive the same input tensors; only the shapes of their
    first return values are compared.
    """
    single_expert = MoEGraph(top_k=1)
    multi_expert = MoEGraph(top_k=8)

    batch, seq_len = 1, 4
    hidden = torch.randn(batch, seq_len, HIDDEN_DIM)
    codes = torch.randint(0, 1000, (batch, seq_len))

    routed_top1, _ponder_top1 = single_expert(hidden, codes)
    routed_top8, _ponder_top8 = multi_expert(hidden, codes)

    assert routed_top1.shape == routed_top8.shape
|
|
|
|
def test_moegraph_topk_no_dead_code():
    """Check the configuration of the ``moegraph`` attached to an ``ARBModel``.

    Builds a model with image and audio disabled and VQ and graph enabled,
    then asserts ``top_k == 4``, ``num_experts == 256`` and ``cb_dim == 768``
    on ``model.moegraph``.
    """
    from arbitor.main import ARBModel

    model_flags = {
        "enable_image": False,
        "enable_audio": False,
        "enable_vq": True,
        "enable_graph": True,
    }
    graph = ARBModel(**model_flags).moegraph

    # Checked in this order so the first failing attribute is reported first.
    expected_config = (("top_k", 4), ("num_experts", 256), ("cb_dim", 768))
    for attr_name, expected_value in expected_config:
        assert getattr(graph, attr_name) == expected_value
|
|
|
|
def test_moegraph_topk_routing_logic():
    """Check a top-k selection followed by a scaled softmax, using torch only.

    With a fixed seed, random scores of shape ``(2, 4, 256)`` are reduced to
    their 8 largest entries along the last axis. The selected scores are
    divided by 0.1 and passed through a softmax. The test asserts that the
    weights and indices have shape ``(2, 4, 8)``, that each weight vector
    sums to one, and that the 8 selected indices at every position are
    distinct.
    """
    torch.manual_seed(42)
    batch, seq_len, num_experts, top_k = 2, 4, 256, 8

    logits = torch.randn(batch, seq_len, num_experts)
    top_scores, expert_idx = torch.topk(logits, k=top_k, dim=-1)
    gate = torch.softmax(top_scores / 0.1, dim=-1)

    routed_shape = (batch, seq_len, top_k)
    assert gate.shape == routed_shape
    assert expert_idx.shape == routed_shape
    assert torch.allclose(gate.sum(dim=-1), torch.ones(batch, seq_len))

    # One row per (batch, position): every selected expert id must be unique.
    for chosen in expert_idx.reshape(batch * seq_len, top_k):
        assert len(torch.unique(chosen)) == top_k
|
|