Vedisasi's picture
Upload folder using huggingface_hub
54c5666 verified
"""Unit tests for Mixture of Experts components"""
import pytest
import torch
import torch.nn as nn
from src.models.moe_advanced import (
ExpertConfig,
Expert,
TopKRouter,
MoELayer,
LoadBalancingLoss
)
class TestExpertConfig:
    """Checks on ``ExpertConfig`` defaults and keyword overrides."""

    def test_default_config(self):
        """A config built with no arguments exposes the expected defaults."""
        cfg = ExpertConfig()
        expected_defaults = {
            "num_knowledge_experts": 8,
            "num_skill_experts": 8,
            "num_meta_experts": 4,
            "top_k": 2,
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(cfg, field_name) == expected

    def test_custom_config(self):
        """Keyword arguments passed to the constructor are stored unchanged."""
        overrides = {
            "num_knowledge_experts": 16,
            "top_k": 4,
            "expert_capacity": 256,
        }
        cfg = ExpertConfig(**overrides)
        for field_name, expected in overrides.items():
            assert getattr(cfg, field_name) == expected
class TestExpert:
    """Checks on a single ``Expert`` feed-forward module."""

    def test_initialization(self):
        """The two linear layers have the weight shapes implied by the dims."""
        hidden, inner = 768, 3072
        module = Expert(hidden_dim=hidden, intermediate_dim=inner)

        assert module.fc1.weight.shape == (inner, hidden)
        assert module.fc2.weight.shape == (hidden, inner)

    def test_forward_pass(self):
        """A forward pass keeps the input shape and yields finite values."""
        module = Expert(hidden_dim=768, intermediate_dim=3072)
        inputs = torch.randn(4, 16, 768)

        result = module(inputs)

        assert result.shape == inputs.shape
        # NaN and Inf are checked separately so a failure names the culprit.
        has_nan = torch.isnan(result).any()
        assert not has_nan
        has_inf = torch.isinf(result).any()
        assert not has_inf
class TestTopKRouter:
    """Test Top-K routing mechanism"""

    @pytest.fixture
    def router(self):
        """Provide a fresh ``TopKRouter`` (hidden_dim=768, 8 experts, top_k=2)."""
        return TopKRouter(hidden_dim=768, num_experts=8, top_k=2)

    def test_initialization(self, router):
        """Test router initialization"""
        assert router.num_experts == 8
        assert router.top_k == 2
        # The gate is expected to carry one weight row per expert.
        assert router.gate.weight.shape == (8, 768)

    def test_forward_pass(self, router):
        """Test router forward pass

        Checks that the router returns a weight tensor and an index tensor
        of shape (4, 16, 2) for a (4, 16, 768) input, that the weights of
        each token sum to 1, and that every index lies in ``[0, 8)``.
        """
        x = torch.randn(4, 16, 768)
        expert_weights, expert_indices = router(x)

        # The shape assertions already fail on a wrong rank or size, so the
        # result shape is not unpacked into separate (unused) locals.
        assert expert_weights.shape == (4, 16, 2)
        assert expert_indices.shape == (4, 16, 2)

        # Check weights sum to 1
        assert torch.allclose(
            expert_weights.sum(dim=-1), torch.ones(4, 16), atol=1e-5
        )

        # Check indices are valid
        assert (expert_indices >= 0).all()
        assert (expert_indices < 8).all()

    def test_routing_distribution(self, router):
        """Test that routing distributes across experts"""
        x = torch.randn(32, 128, 768)  # Larger batch
        _, expert_indices = router(x)

        # Collect the distinct expert ids that were selected at least once
        unique_experts = torch.unique(expert_indices)

        # Should use multiple experts (not just one)
        assert len(unique_experts) > 2
class TestMoELayer:
    """Checks on the assembled ``MoELayer``."""

    @pytest.fixture
    def config(self):
        """Provide an ``ExpertConfig`` with 8 knowledge and 0 skill experts."""
        settings = {
            "num_knowledge_experts": 8,
            "num_skill_experts": 0,  # Only knowledge experts for simplicity
            "top_k": 2,
            "expert_capacity": 128,
        }
        return ExpertConfig(**settings)

    def test_initialization(self, config):
        """The layer builds 8 knowledge experts and a router over 8 experts."""
        layer = MoELayer(hidden_dim=768, config=config)

        assert len(layer.knowledge_experts) == 8
        assert layer.router.num_experts == 8

    def test_forward_pass(self, config):
        """The forward pass returns a same-shape output and an auxiliary loss."""
        layer = MoELayer(hidden_dim=768, config=config)
        inputs = torch.randn(2, 16, 768)

        result, balance_loss = layer(inputs)

        assert result.shape == inputs.shape
        assert not torch.isnan(result).any()
        # Load balancing loss should be non-negative
        assert balance_loss >= 0

    def test_load_balancing(self, config):
        """On a large batch the auxiliary loss is strictly between 0 and 1."""
        layer = MoELayer(hidden_dim=768, config=config)
        inputs = torch.randn(32, 128, 768)  # Large batch

        _, balance_loss = layer(inputs)

        # Load balancing loss should be present ...
        assert balance_loss > 0
        # ... and stay within a reasonable range
        assert balance_loss < 1.0

    def test_expert_capacity(self, config):
        """A small expert capacity still yields a full-shape output."""
        # Shrink the capacity on the fixture's config before building the layer.
        config.expert_capacity = 16
        layer = MoELayer(hidden_dim=768, config=config)
        inputs = torch.randn(64, 128, 768)  # Large input

        result, _ = layer(inputs)

        # Should still produce output without errors
        assert result.shape == inputs.shape
class TestLoadBalancingLoss:
    """Checks on the value returned by ``LoadBalancingLoss``."""

    def test_balanced_routing(self):
        """Uniform counts and probabilities give a loss below 0.01."""
        n_experts = 8
        # Each expert gets 100 tokens and a routing probability of 1/8.
        counts = torch.ones(n_experts) * 100
        probs = torch.ones(n_experts) * 0.125

        criterion = LoadBalancingLoss()
        value = criterion(counts, probs, num_experts=8)

        # Should be close to 0 for perfect balance
        assert value < 0.01

    def test_imbalanced_routing(self):
        """Routing skewed towards the first expert gives a loss above 0.1."""
        # Imbalanced: most tokens go to the first expert
        counts = torch.tensor([700., 100., 100., 0., 0., 0., 0., 0.])
        probs = torch.tensor([0.7, 0.1, 0.1, 0.1, 0., 0., 0., 0.])

        criterion = LoadBalancingLoss()
        value = criterion(counts, probs, num_experts=8)

        # Should have significant loss
        assert value > 0.1
class TestMoEIntegration:
    """End-to-end checks that exercise ``MoELayer`` as a whole."""

    def test_gradient_flow(self):
        """Backpropagation reaches both the input and the router gate."""
        cfg = ExpertConfig(num_knowledge_experts=4, top_k=2)
        layer = MoELayer(hidden_dim=128, config=cfg)
        inputs = torch.randn(2, 8, 128, requires_grad=True)

        result, balance_loss = layer(inputs)

        # Fold the output and auxiliary loss into one scalar and backpropagate
        objective = result.sum() + balance_loss
        objective.backward()

        # Gradients must exist on the input and on the gate weights
        input_grad = inputs.grad
        assert input_grad is not None
        assert layer.router.gate.weight.grad is not None

        # Gradients must be free of NaN and Inf
        assert not torch.isnan(inputs.grad).any()
        assert not torch.isinf(inputs.grad).any()

    def test_different_batch_sizes(self):
        """One layer instance handles batch sizes 1, 4, 16 and 32."""
        cfg = ExpertConfig(num_knowledge_experts=8, top_k=2)
        layer = MoELayer(hidden_dim=256, config=cfg)

        for n_samples in (1, 4, 16, 32):
            inputs = torch.randn(n_samples, 32, 256)
            result, _balance_loss = layer(inputs)

            assert result.shape == inputs.shape
            assert not torch.isnan(result).any()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    # pytest.main() returns the run's exit code; raise it as the process exit
    # status so a failing run is visible to shells and CI instead of exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))