Spaces:
Sleeping
Sleeping
File size: 1,676 Bytes
1df0e33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import unittest
import torch
import sys
from pathlib import Path
# Add project root to path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from aetheris.config import AetherisConfig
from aetheris.model import HybridMambaMoE
class TestHybridMambaMoE(unittest.TestCase):
    """Smoke tests for the forward pass of ``HybridMambaMoE`` on a tiny config."""

    # Batch dimensions shared by every test; kept small so the suite stays fast.
    BATCH_SIZE = 2
    SEQ_LEN = 16

    def setUp(self):
        """Build a small model and place it on CUDA when available, else CPU."""
        # Seed the RNG so weight initialisation and the sampled token ids are
        # the same on every run; unseeded tests are not reproducible.
        torch.manual_seed(0)
        self.config = AetherisConfig(
            vocab_size=100,
            d_model=32,
            n_layer=4,
            num_experts=2,
            top_k=1,
            d_ff=64,
            ssm_d_state=8,
            ssm_expand=2,
            max_seq_len=64,
        )
        self.model = HybridMambaMoE(self.config)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def _random_input_ids(self):
        """Return random token ids of shape ``(BATCH_SIZE, SEQ_LEN)`` on ``self.device``."""
        # Allocate directly on the target device instead of creating the
        # tensor on CPU and copying it over.
        return torch.randint(
            0,
            self.config.vocab_size,
            (self.BATCH_SIZE, self.SEQ_LEN),
            device=self.device,
        )

    def test_forward_pass(self):
        """Without labels, the output has logits shaped (batch, seq, vocab)."""
        input_ids = self._random_input_ids()
        # Only the output shape is inspected, so no autograd graph is needed.
        with torch.no_grad():
            output = self.model(input_ids)
        self.assertIn('logits', output)
        self.assertEqual(
            output['logits'].shape,
            (self.BATCH_SIZE, self.SEQ_LEN, self.config.vocab_size),
        )

    def test_forward_pass_with_labels(self):
        """With labels, the output has loss terms and logits, and loss > 0."""
        input_ids = self._random_input_ids()
        # The input ids are reused unchanged as the labels.
        labels = input_ids.clone()
        output = self.model(input_ids, labels=labels)
        for key in ('loss', 'ce_loss', 'aux_loss', 'logits'):
            self.assertIn(key, output)
        # assertGreater reports the offending value on failure, unlike
        # assertTrue(loss > 0), which only reports that the check was false.
        self.assertGreater(float(output['loss']), 0.0)
if __name__ == '__main__':
    # Run the tests in this module when the file is executed as a script.
    unittest.main()
|