| |
|
| | """Test script for Zenith-7B model"""
|
| |
|
| | import torch
|
| | import unittest
|
| | from pathlib import Path
|
| | import sys
|
| |
|
| | sys.path.append(str(Path(__file__).parent))
|
| |
|
| | from configs.zenith_config import get_7b_config
|
| | from models.zenith_model import ZenithForCausalLM, ZenithModel
|
| | from data.advanced_tokenizer import AdvancedTokenizer
|
| |
|
| |
|
class TestZenith7B(unittest.TestCase):
    """Test suite for the Zenith-7B model.

    Heavy fixtures (model, tokenizer) are built once in ``setUpClass`` and
    shared read-only across all tests. ``test_gradient_flow`` is the only
    test that mutates the shared model (train mode + backward); it restores
    eval mode and clears gradients in a ``finally`` so test order does not
    affect results.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared device/config/model/tokenizer fixtures once."""
        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cls.config = get_7b_config()
        # Pin the vocab explicitly so the model head and the tokenizer
        # constructed below agree on size.
        cls.config.vocab_size = 32000

        # NOTE(review): ZenithForCausalLM is imported at module level but
        # never used, while the loss/generate tests below run against
        # ZenithModel. This assumes ZenithModel already exposes the causal-LM
        # interface (loss, .generate) — confirm, otherwise this fixture
        # should be ZenithForCausalLM(cls.config).
        cls.model = ZenithModel(cls.config)
        cls.model.to(cls.device)
        cls.model.eval()

        cls.tokenizer = AdvancedTokenizer(vocab_size=32000)

    def _random_batch(self, batch_size, seq_len, with_labels=False):
        """Return (input_ids, attention_mask[, labels]) random tensors on the
        fixture device, shaped (batch_size, seq_len)."""
        input_ids = torch.randint(
            0, self.config.vocab_size, (batch_size, seq_len)
        ).to(self.device)
        attention_mask = torch.ones(batch_size, seq_len).to(self.device)
        if not with_labels:
            return input_ids, attention_mask
        labels = torch.randint(
            0, self.config.vocab_size, (batch_size, seq_len)
        ).to(self.device)
        return input_ids, attention_mask, labels

    def test_model_creation(self):
        """Model fixture exists and exposes its transformer backbone."""
        self.assertIsNotNone(self.model)
        self.assertTrue(hasattr(self.model, 'transformer'))

    def test_forward_pass(self):
        """Forward pass returns logits of shape (batch, seq, vocab)."""
        batch_size, seq_len = 2, 32
        input_ids, attention_mask = self._random_batch(batch_size, seq_len)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        self.assertIsNotNone(outputs.logits)
        self.assertEqual(outputs.logits.shape[0], batch_size)
        self.assertEqual(outputs.logits.shape[1], seq_len)
        self.assertEqual(outputs.logits.shape[2], self.config.vocab_size)

    def test_moe_activation(self):
        """MoE layers are present when the config requests multiple experts."""
        if self.config.num_experts <= 1:
            # Skip (rather than silently pass) so the report shows the
            # feature was not exercised.
            self.skipTest("MoE disabled: config.num_experts <= 1")

        moe_layers = [m for m in self.model.modules() if hasattr(m, 'num_experts')]
        self.assertGreater(len(moe_layers), 0)

    def test_generation(self):
        """Sampling-based generation produces a longer string than the prompt."""
        prompt = "Hello, world!"
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=20,
                temperature=0.8,
                do_sample=True
            )

        generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        self.assertIsInstance(generated, str)
        # Assumes decode round-trips the prompt prefix verbatim — if the
        # tokenizer normalizes text this length check could flake; confirm.
        self.assertGreater(len(generated), len(prompt))

    def test_loss_computation(self):
        """Supplying labels yields a finite scalar loss."""
        input_ids, attention_mask, labels = self._random_batch(2, 32, with_labels=True)

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        self.assertIsNotNone(outputs.loss)
        self.assertTrue(torch.isfinite(outputs.loss))

    def test_multi_task_outputs(self):
        """Multi-task heads emit outputs when the EQ adapter is enabled."""
        if not self.config.use_eq_adapter:
            self.skipTest("EQ adapter disabled: config.use_eq_adapter is falsy")

        input_ids, attention_mask = self._random_batch(2, 32)

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # Bug fix: the original asserted `hasattr(...) or attr is not None`,
        # which passes whenever the attribute merely exists (even as None)
        # and raises AttributeError instead of failing when it does not.
        # Both conditions must hold: attribute present AND non-None.
        self.assertTrue(hasattr(outputs, 'emotion_logits'))
        self.assertIsNotNone(outputs.emotion_logits)
        self.assertTrue(hasattr(outputs, 'frustration_logits'))
        self.assertIsNotNone(outputs.frustration_logits)

    def test_gradient_flow(self):
        """Backward pass populates gradients on trainable parameters."""
        batch_size, seq_len = 1, 16
        input_ids = torch.randint(
            0, self.config.vocab_size, (batch_size, seq_len)
        ).to(self.device)
        labels = torch.randint(
            0, self.config.vocab_size, (batch_size, seq_len)
        ).to(self.device)

        self.model.train()
        try:
            outputs = self.model(input_ids=input_ids, labels=labels)
            outputs.loss.backward()

            has_grad = any(
                p.grad is not None for p in self.model.parameters() if p.requires_grad
            )
            self.assertTrue(has_grad)
        finally:
            # Restore the shared fixture: other tests assume eval mode and
            # no stale gradients (the original left both behind, making the
            # suite order-dependent).
            self.model.zero_grad(set_to_none=True)
            self.model.eval()
|
| |
|
| |
|
def run_tests():
    """Execute the Zenith-7B suite, print a summary banner, and return
    True when every test passed."""
    banner = "=" * 60
    print(banner)
    print("Zenith-7B Model Test Suite")
    print(banner)

    suite = unittest.TestLoader().loadTestsFromTestCase(TestZenith7B)

    result = unittest.TextTestRunner(verbosity=2).run(suite)

    print("\n" + banner)
    print("Test Summary:")
    print(f" Tests run: {result.testsRun}")
    print(f" Failures: {len(result.failures)}")
    print(f" Errors: {len(result.errors)}")
    print(f" Success: {result.wasSuccessful()}")
    print(banner)

    return result.wasSuccessful()
|
| |
|
| |
|
if __name__ == "__main__":
    # Propagate the suite outcome as the process exit code (0 = success).
    sys.exit(0 if run_tests() else 1)