|
|
| """Test script for Zenith-28B-p300 model"""
|
|
|
| import torch
|
| import unittest
|
| from pathlib import Path
|
| import sys
|
|
|
| sys.path.append(str(Path(__file__).parent))
|
|
|
| from configs.zenith_config import get_28b_p300_config
|
| from models.zenith_model import ZenithForCausalLM, ZenithModel
|
| from data.advanced_tokenizer import AdvancedTokenizer
|
|
|
|
|
class TestZenith28B(unittest.TestCase):
    """Test suite for Zenith-28B-p300 model.

    ``cls.config`` holds the full 28B p300 configuration and is used only for
    configuration assertions; the model actually exercised is built from a
    deliberately shrunken copy of that config so forward/backward passes are
    cheap enough to run on a single device.
    """

    @classmethod
    def setUpClass(cls):
        """Set up shared fixtures: device, full config, small model, tokenizer."""
        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cls.config = get_28b_p300_config()
        cls.config.vocab_size = 32000

        # Shrunken configuration so model construction and forward passes
        # stay fast; only the architecture sizes are reduced.
        test_config = get_28b_p300_config()
        test_config.hidden_size = 1024
        test_config.num_layers = 12
        test_config.num_heads = 16
        test_config.intermediate_size = 4096
        # BUG FIX: keep the test model's vocabulary in sync with the size the
        # tests sample token ids from (cls.config.vocab_size). Without this,
        # a differing config default would let torch.randint produce ids past
        # the embedding table and crash unrelated tests.
        test_config.vocab_size = cls.config.vocab_size

        cls.model = ZenithModel(test_config)
        cls.model.to(cls.device)
        cls.model.eval()

        cls.tokenizer = AdvancedTokenizer(vocab_size=32000)

    def _random_ids(self, batch_size, seq_len):
        """Return a (batch_size, seq_len) tensor of random valid token ids on the test device."""
        return torch.randint(
            0, self.config.vocab_size, (batch_size, seq_len)
        ).to(self.device)

    def test_model_creation(self):
        """Test model can be created."""
        self.assertIsNotNone(self.model)
        self.assertTrue(hasattr(self.model, 'transformer'))

    def test_forward_pass(self):
        """Test forward pass works."""
        batch_size = 2
        seq_len = 64
        input_ids = self._random_ids(batch_size, seq_len)
        attention_mask = torch.ones(batch_size, seq_len).to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        self.assertIsNotNone(outputs.logits)
        self.assertEqual(outputs.logits.shape[0], batch_size)
        self.assertEqual(outputs.logits.shape[1], seq_len)

    def test_p300_optimizations(self):
        """Test p300-specific optimizations are configured."""
        self.assertTrue(self.config.use_tenstorrent_optimizations)
        self.assertEqual(self.config.tensor_parallel_size, 8)
        self.assertEqual(self.config.pipeline_parallel_size, 4)
        self.assertTrue(self.config.noc_optimization)
        self.assertTrue(self.config.use_ring_attention)
        self.assertEqual(self.config.max_seq_len, 32768)

    def test_moe_configuration(self):
        """Test MoE configuration for 28B."""
        self.assertGreater(self.config.num_experts, 1)
        self.assertEqual(self.config.moe_top_k, 2)
        self.assertGreater(self.config.moe_load_balancing_weight, 0)

    def test_generation(self):
        """Test text generation."""
        prompt = "Explain the concept of recursion."
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=50,
                temperature=0.7,
                do_sample=True
            )

        generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        self.assertIsInstance(generated, str)
        self.assertGreater(len(generated), len(prompt))

    def test_loss_computation(self):
        """Test loss computation."""
        batch_size = 2
        seq_len = 64
        input_ids = self._random_ids(batch_size, seq_len)
        labels = self._random_ids(batch_size, seq_len)
        attention_mask = torch.ones(batch_size, seq_len).to(self.device)

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        self.assertIsNotNone(outputs.loss)
        self.assertTrue(torch.isfinite(outputs.loss))

    def test_ring_attention_parameters(self):
        """Test ring attention is properly configured."""
        self.assertTrue(self.config.use_ring_attention)
        self.assertEqual(self.config.ring_attention_chunk_size, 8192)
        self.assertEqual(self.config.ring_attention_overlap, 2048)

    def test_eq_adapter(self):
        """Test EQ adapter presence."""
        self.assertTrue(self.config.use_eq_adapter)
        self.assertEqual(self.config.eq_adapter_hidden_size, 64)

    def test_gradient_flow(self):
        """Test gradients flow correctly."""
        batch_size = 1
        seq_len = 32
        input_ids = self._random_ids(batch_size, seq_len)
        labels = self._random_ids(batch_size, seq_len)

        self.model.train()
        try:
            outputs = self.model(input_ids=input_ids, labels=labels)
            loss = outputs.loss
            loss.backward()

            has_grad = any(
                p.grad is not None for p in self.model.parameters() if p.requires_grad
            )
            self.assertTrue(has_grad)
        finally:
            # BUG FIX: restore the shared class-level fixture. Leaving the
            # model in train mode makes later tests nondeterministic (dropout
            # active), and stale .grad tensors would leak into any subsequent
            # backward pass.
            self.model.zero_grad(set_to_none=True)
            self.model.eval()
|
|
|
|
|
def run_tests():
    """Execute the Zenith-28B test suite and print a pass/fail summary.

    Returns:
        bool: True when every test in the suite passed.
    """
    banner = "=" * 60
    print(banner)
    print("Zenith-28B-p300 Model Test Suite")
    print(banner)

    # Build the suite from the test case class and run it verbosely.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestZenith28B)
    result = unittest.TextTestRunner(verbosity=2).run(suite)

    # Summarize the run so CI logs show the totals at a glance.
    print("\n" + banner)
    print("Test Summary:")
    print(f" Tests run: {result.testsRun}")
    print(f" Failures: {len(result.failures)}")
    print(f" Errors: {len(result.errors)}")
    print(f" Success: {result.wasSuccessful()}")
    print(banner)

    return result.wasSuccessful()
|
|
|
|
|
if __name__ == "__main__":
    # Exit status mirrors the suite result: 0 on success, 1 on any
    # failure/error, so CI pipelines can gate on this script directly.
    sys.exit(0 if run_tests() else 1)