File size: 2,554 Bytes
1df0e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import unittest
import torch
import sys
from pathlib import Path

# Add project root to path
sys.path.append(str(Path(__file__).resolve().parent.parent))

from aetheris.modules.expert import Expert
from aetheris.modules.moe import SparseMoELayer
from aetheris.config import AetherisConfig

class TestOverflow(unittest.TestCase):
    """Numerical-stability checks for Expert and SparseMoELayer in float16.

    Both tests cast the modules to half precision to mimic a mixed-precision
    training environment and then verify the forward pass stays finite.
    """

    def setUp(self):
        # Small model config; d_ff is sized well above d_model so the expert's
        # up-projection can overflow float16 if no protection is in place.
        self.config = AetherisConfig(
            vocab_size=100,
            d_model=128,
            n_layer=2,
            num_experts=2,
            top_k=1,
            d_ff=512,  # Large enough to potentially cause issues
            ssm_d_state=16,
            ssm_expand=2,
            max_seq_len=64
        )
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

    def test_expert_overflow_protection(self):
        """Test if Expert handles large inputs without producing NaNs in float16"""
        net = Expert(self.config.d_model, self.config.d_ff).to(self.device)
        # Simulate mixed-precision training by casting all weights to half.
        net.half()

        # float16 tops out around ~65504; a large constant input combined with
        # large w1 weights below is designed to blow past that in the
        # intermediate projection unless the module guards against it.
        x = torch.full(
            (1, self.config.d_model), 100.0, dtype=torch.float16
        ).to(self.device)

        # Inflate the weights so overflow is guaranteed absent protection:
        # each dot product sums 128 terms of 100 * 10 = 1000 -> 128000,
        # which exceeds the float16 range if accumulation happens in half.
        with torch.no_grad():
            net.w1.weight.fill_(10.0)
            net.w2.weight.fill_(0.1)

        out = net(x)

        self.assertFalse(torch.isnan(out).any(), "Output contains NaNs")
        self.assertFalse(torch.isinf(out).any(), "Output contains Infs")

    def test_moe_accumulation_stability(self):
        """Test if MoE layer handles accumulation in float32"""
        layer = SparseMoELayer(self.config).to(self.device)
        layer.half()

        batch = torch.randn(2, 10, self.config.d_model, dtype=torch.float16).to(self.device)

        # Forward pass; output must stay finite and remain in half precision.
        result, aux_loss = layer(batch)

        self.assertFalse(torch.isnan(result).any(), "MoE Output contains NaNs")
        self.assertEqual(result.dtype, torch.float16)

# Allow running this file directly: discovers and runs the tests above.
if __name__ == '__main__':
    unittest.main()