#!/usr/bin/env python3
"""Test script for the Zenith-28B-p300 model.

Runs a unittest suite against a size-reduced ZenithModel instance to verify
model construction, forward/loss computation, generation, gradient flow, and
the p300-specific configuration flags.
"""
import torch
import unittest
from pathlib import Path
import sys

# Make sibling packages (configs/, models/, data/) importable when run directly.
sys.path.append(str(Path(__file__).parent))

from configs.zenith_config import get_28b_p300_config
from models.zenith_model import ZenithForCausalLM, ZenithModel
from data.advanced_tokenizer import AdvancedTokenizer


class TestZenith28B(unittest.TestCase):
    """Test suite for Zenith-28B-p300 model."""

    @classmethod
    def setUpClass(cls):
        """Set up fixtures shared by every test: device, configs, model, tokenizer."""
        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Config object used by tests to read expected hyperparameters.
        cls.config = get_28b_p300_config()
        cls.config.vocab_size = 32000  # Test vocab size

        # Create a small test model (reduced size so tests run quickly).
        test_config = get_28b_p300_config()
        test_config.hidden_size = 1024  # Smaller for testing
        test_config.num_layers = 12
        test_config.num_heads = 16
        test_config.intermediate_size = 4096
        # BUG FIX: the model's vocab must match the vocab the tests sample ids
        # from (cls.config.vocab_size == 32000). Without this, random ids in
        # [0, 32000) can index past the embedding table whenever the production
        # config declares a smaller vocabulary.
        test_config.vocab_size = 32000

        cls.model = ZenithModel(test_config)
        cls.model.to(cls.device)
        cls.model.eval()

        cls.tokenizer = AdvancedTokenizer(vocab_size=32000)

    def test_model_creation(self):
        """Model instantiates and exposes its transformer stack."""
        self.assertIsNotNone(self.model)
        self.assertTrue(hasattr(self.model, 'transformer'))

    def test_forward_pass(self):
        """Forward pass returns logits shaped (batch, seq, ...)."""
        batch_size = 2
        seq_len = 64

        input_ids = torch.randint(
            0, self.config.vocab_size, (batch_size, seq_len)
        ).to(self.device)
        attention_mask = torch.ones(batch_size, seq_len).to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        self.assertIsNotNone(outputs.logits)
        self.assertEqual(outputs.logits.shape[0], batch_size)
        self.assertEqual(outputs.logits.shape[1], seq_len)

    def test_p300_optimizations(self):
        """p300-specific optimizations are configured as expected."""
        self.assertTrue(self.config.use_tenstorrent_optimizations)
        self.assertEqual(self.config.tensor_parallel_size, 8)
        self.assertEqual(self.config.pipeline_parallel_size, 4)
        self.assertTrue(self.config.noc_optimization)
        self.assertTrue(self.config.use_ring_attention)
        self.assertEqual(self.config.max_seq_len, 32768)

    def test_moe_configuration(self):
        """MoE settings match the 28B recipe."""
        self.assertGreater(self.config.num_experts, 1)
        self.assertEqual(self.config.moe_top_k, 2)
        self.assertGreater(self.config.moe_load_balancing_weight, 0)

    def test_generation(self):
        """Sampling-based generation produces a longer string than the prompt."""
        prompt = "Explain the concept of recursion."
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=50,
                temperature=0.7,
                do_sample=True
            )

        generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        self.assertIsInstance(generated, str)
        self.assertGreater(len(generated), len(prompt))

    def test_loss_computation(self):
        """Supplying labels yields a finite scalar loss."""
        batch_size = 2
        seq_len = 64

        input_ids = torch.randint(
            0, self.config.vocab_size, (batch_size, seq_len)
        ).to(self.device)
        labels = torch.randint(
            0, self.config.vocab_size, (batch_size, seq_len)
        ).to(self.device)
        attention_mask = torch.ones(batch_size, seq_len).to(self.device)

        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )
        self.assertIsNotNone(outputs.loss)
        self.assertTrue(torch.isfinite(outputs.loss))

    def test_ring_attention_parameters(self):
        """Ring attention chunking parameters are set."""
        self.assertTrue(self.config.use_ring_attention)
        self.assertEqual(self.config.ring_attention_chunk_size, 8192)
        self.assertEqual(self.config.ring_attention_overlap, 2048)

    def test_eq_adapter(self):
        """EQ adapter is enabled with the expected hidden size."""
        self.assertTrue(self.config.use_eq_adapter)
        self.assertEqual(self.config.eq_adapter_hidden_size, 64)

    def test_gradient_flow(self):
        """Backward pass populates gradients on trainable parameters."""
        batch_size = 1
        seq_len = 32

        input_ids = torch.randint(
            0, self.config.vocab_size, (batch_size, seq_len)
        ).to(self.device)
        labels = torch.randint(
            0, self.config.vocab_size, (batch_size, seq_len)
        ).to(self.device)

        self.model.train()
        try:
            outputs = self.model(input_ids=input_ids, labels=labels)
            loss = outputs.loss
            loss.backward()

            has_grad = any(
                p.grad is not None
                for p in self.model.parameters()
                if p.requires_grad
            )
            self.assertTrue(has_grad)
        finally:
            # BUG FIX: the model is shared via setUpClass, and unittest runs
            # methods alphabetically — without this cleanup, train-mode state
            # (dropout etc.) and stale .grad tensors leak into every test that
            # runs after this one.
            self.model.zero_grad(set_to_none=True)
            self.model.eval()


def run_tests():
    """Run the full suite and print a summary; return True on success."""
    print("=" * 60)
    print("Zenith-28B-p300 Model Test Suite")
    print("=" * 60)

    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(TestZenith28B)
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    print("\n" + "=" * 60)
    print("Test Summary:")
    print(f"  Tests run: {result.testsRun}")
    print(f"  Failures: {len(result.failures)}")
    print(f"  Errors: {len(result.errors)}")
    print(f"  Success: {result.wasSuccessful()}")
    print("=" * 60)

    return result.wasSuccessful()


if __name__ == "__main__":
    success = run_tests()
    sys.exit(0 if success else 1)