Aetheris-Inference / tests /test_overflow.py
Pomilon
Deploy Aetheris to HF Space
1df0e33
import unittest
import torch
import sys
from pathlib import Path
# Add project root to path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from aetheris.modules.expert import Expert
from aetheris.modules.moe import SparseMoELayer
from aetheris.config import AetherisConfig
class TestOverflow(unittest.TestCase):
def setUp(self):
self.config = AetherisConfig(
vocab_size=100,
d_model=128,
n_layer=2,
num_experts=2,
top_k=1,
d_ff=512, # Large enough to potentially cause issues
ssm_d_state=16,
ssm_expand=2,
max_seq_len=64
)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def test_expert_overflow_protection(self):
"""Test if Expert handles large inputs without producing NaNs in float16"""
expert = Expert(self.config.d_model, self.config.d_ff).to(self.device)
# Manually cast weights to float16 to simulate mixed precision training environment
expert.half()
# Create a large input in float16 that would normally cause overflow in intermediate layers
# The limit of float16 is ~65504.
# If w1 projects this up, it can easily exceed that.
large_input = torch.ones(1, self.config.d_model, dtype=torch.float16).to(self.device) * 100.0
# Force weights to be large to guarantee overflow if protection isn't working
with torch.no_grad():
expert.w1.weight.fill_(10.0)
expert.w2.weight.fill_(0.1)
# 100 * 10 = 1000. Sum over d_model(128) -> 128000.
# This summation happens in the matrix multiplication.
# If the matmul internal accumulation is float16, it effectively overflows.
output = expert(large_input)
self.assertFalse(torch.isnan(output).any(), "Output contains NaNs")
self.assertFalse(torch.isinf(output).any(), "Output contains Infs")
def test_moe_accumulation_stability(self):
"""Test if MoE layer handles accumulation in float32"""
moe = SparseMoELayer(self.config).to(self.device)
moe.half()
x = torch.randn(2, 10, self.config.d_model, dtype=torch.float16).to(self.device)
# Pass through
output, loss = moe(x)
self.assertFalse(torch.isnan(output).any(), "MoE Output contains NaNs")
self.assertEqual(output.dtype, torch.float16)
if __name__ == '__main__':
unittest.main()