import unittest
import torch
import sys
from pathlib import Path

# Add project root to path
sys.path.append(str(Path(__file__).resolve().parent.parent))

from aetheris.modules.expert import Expert
from aetheris.modules.moe import SparseMoELayer
from aetheris.config import AetherisConfig


class TestOverflow(unittest.TestCase):
    def setUp(self):
        self.config = AetherisConfig(
            vocab_size=100,
            d_model=128,
            n_layer=2,
            num_experts=2,
            top_k=1,
            d_ff=512,  # Large enough to potentially cause issues
            ssm_d_state=16,
            ssm_expand=2,
            max_seq_len=64,
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def test_expert_overflow_protection(self):
        """Test that Expert handles large inputs without producing NaNs in float16."""
        expert = Expert(self.config.d_model, self.config.d_ff).to(self.device)
        # Manually cast weights to float16 to simulate a mixed-precision training environment.
        expert.half()

        # Create a large float16 input that would normally overflow the intermediate
        # layers: float16 tops out at ~65504, and if w1 projects the activations up,
        # they can easily exceed that.
        large_input = torch.ones(1, self.config.d_model, dtype=torch.float16).to(self.device) * 100.0
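
        # Quick sanity check of the float16 ceiling this test leans on (an
        # illustration, not part of the Expert contract): casting a value past
        # ~65504 down from float32 saturates to inf.
        self.assertTrue(torch.isinf(torch.tensor(70000.0).half()).item())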

        # Force the weights to be large to guarantee overflow if protection isn't working.
        with torch.no_grad():
            expert.w1.weight.fill_(10.0)
            expert.w2.weight.fill_(0.1)

        # Each product is 100 * 10 = 1000, summed over d_model (128) -> 128000.
        # That summation happens inside the matrix multiplication, so if the
        # matmul's internal accumulation is float16, it overflows.
        output = expert(large_input)

        self.assertFalse(torch.isnan(output).any(), "Output contains NaNs")
        self.assertFalse(torch.isinf(output).any(), "Output contains Infs")
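
        # Worked arithmetic behind the scenario above (a sketch of the
        # unprotected failure mode, independent of the Expert internals):
        # the raw accumulation 100 * 10 * 128 = 128000 does not fit in
        # float16, so casting it down saturates to inf.
        raw = torch.tensor(100.0) * 10.0 * self.config.d_model
        self.assertTrue(torch.isinf(raw.half()).item())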

    def test_moe_accumulation_stability(self):
        """Test that the MoE layer handles accumulation in float32."""
        moe = SparseMoELayer(self.config).to(self.device)
        moe.half()
        x = torch.randn(2, 10, self.config.d_model, dtype=torch.float16).to(self.device)

        # Pass through
        output, loss = moe(x)

        self.assertFalse(torch.isnan(output).any(), "MoE output contains NaNs")
        self.assertEqual(output.dtype, torch.float16)
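
        # Extra hedged check: if the router's auxiliary load-balancing loss is
        # returned as a tensor (as the unpacking above assumes), it should be
        # finite as well.
        if torch.is_tensor(loss):
            self.assertFalse(torch.isnan(loss).any(), "MoE aux loss contains NaNs")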


if __name__ == '__main__':
    unittest.main()