#!/usr/bin/env python3
"""Test script for Zenith-7B model."""

import sys
import unittest
from pathlib import Path

import torch

# Make the project packages (configs/, models/, data/) importable when this
# script is run directly from its own directory.
sys.path.append(str(Path(__file__).parent))

from configs.zenith_config import get_7b_config  # noqa: E402
from models.zenith_model import ZenithForCausalLM, ZenithModel  # noqa: E402,F401
from data.advanced_tokenizer import AdvancedTokenizer  # noqa: E402


class TestZenith7B(unittest.TestCase):
    """Test suite for Zenith-7B model.

    A single model instance is built once in ``setUpClass`` and shared by
    every test, so any test that changes model state (training mode,
    accumulated gradients) must restore it before returning.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared config, model and tokenizer once for the suite."""
        # Fixed seed so random inputs and sampled generations are reproducible.
        torch.manual_seed(0)
        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cls.config = get_7b_config()
        cls.config.vocab_size = 32000  # Test vocab size

        # The tests below use `.logits`, `.loss` (via `labels=`) and
        # `.generate()`, so they need the causal-LM wrapper. The original
        # constructed `ZenithModel` here while importing `ZenithForCausalLM`
        # and never using it.
        # NOTE(review): this still uses the full 7B config -- nothing here
        # makes it a "small test model". Consider shrinking layer/hidden
        # sizes on the config for CI runs (field names not visible here).
        cls.model = ZenithForCausalLM(cls.config)
        cls.model.to(cls.device)
        cls.model.eval()

        cls.tokenizer = AdvancedTokenizer(vocab_size=32000)

    def _random_ids(self, batch_size, seq_len):
        """Return random token ids in [0, vocab_size) of shape (batch_size, seq_len)."""
        return torch.randint(
            0, self.config.vocab_size, (batch_size, seq_len), device=self.device
        )

    def _full_mask(self, batch_size, seq_len):
        """Return an all-ones (no padding) attention mask of shape (batch_size, seq_len)."""
        return torch.ones(batch_size, seq_len, device=self.device)

    def test_model_creation(self):
        """Test model can be created."""
        self.assertIsNotNone(self.model)
        self.assertTrue(hasattr(self.model, "transformer"))

    def test_forward_pass(self):
        """Test forward pass returns logits of shape (batch, seq, vocab)."""
        batch_size = 2
        seq_len = 32
        input_ids = self._random_ids(batch_size, seq_len)
        attention_mask = self._full_mask(batch_size, seq_len)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        self.assertIsNotNone(outputs.logits)
        self.assertEqual(outputs.logits.shape[0], batch_size)
        self.assertEqual(outputs.logits.shape[1], seq_len)
        self.assertEqual(outputs.logits.shape[2], self.config.vocab_size)

    def test_moe_activation(self):
        """Test MoE layers are present when the config enables experts."""
        if getattr(self.config, "num_experts", 1) <= 1:
            self.skipTest("MoE disabled in config (num_experts <= 1)")
        # Any submodule exposing `num_experts` is treated as an MoE layer.
        moe_layers = [m for m in self.model.modules() if hasattr(m, "num_experts")]
        self.assertGreater(len(moe_layers), 0)

    def test_generation(self):
        """Test text generation appends new tokens to the prompt."""
        prompt = "Hello, world!"
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        torch.manual_seed(0)  # reproducible sampling
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=20,
                temperature=0.8,
                do_sample=True,
            )

        # Compare token counts rather than decoded text lengths: an untrained
        # model may sample special tokens that `skip_special_tokens` strips,
        # which made the old string-length comparison flaky.
        self.assertGreater(outputs.shape[-1], input_ids.shape[-1])
        generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        self.assertIsInstance(generated, str)

    def test_loss_computation(self):
        """Test that passing labels yields a finite loss."""
        batch_size = 2
        seq_len = 32
        input_ids = self._random_ids(batch_size, seq_len)
        labels = self._random_ids(batch_size, seq_len)
        attention_mask = self._full_mask(batch_size, seq_len)

        # No backward pass here, so skip building the autograd graph.
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )

        self.assertIsNotNone(outputs.loss)
        self.assertTrue(torch.isfinite(outputs.loss).all().item())

    def test_multi_task_outputs(self):
        """Test multi-task (emotion/frustration) heads produce logits when the EQ adapter is enabled."""
        if not getattr(self.config, "use_eq_adapter", False):
            self.skipTest("EQ adapter disabled in config")

        batch_size = 2
        seq_len = 32
        input_ids = self._random_ids(batch_size, seq_len)
        attention_mask = self._full_mask(batch_size, seq_len)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # getattr with a default fails cleanly both when the field is missing
        # and when it is present but None (the old `hasattr(...) or ...`
        # accepted None and raised AttributeError when missing).
        self.assertIsNotNone(getattr(outputs, "emotion_logits", None))
        self.assertIsNotNone(getattr(outputs, "frustration_logits", None))

    def test_gradient_flow(self):
        """Test gradients reach trainable parameters after backward()."""
        batch_size = 1
        seq_len = 16
        input_ids = self._random_ids(batch_size, seq_len)
        labels = self._random_ids(batch_size, seq_len)

        self.model.train()
        try:
            outputs = self.model(input_ids=input_ids, labels=labels)
            outputs.loss.backward()

            has_grad = any(
                p.grad is not None
                for p in self.model.parameters()
                if p.requires_grad
            )
            self.assertTrue(has_grad)
        finally:
            # The model is shared across tests: drop the gradients and go back
            # to eval mode so later tests see the state setUpClass created.
            self.model.zero_grad(set_to_none=True)
            self.model.eval()


def run_tests():
    """Run all tests, print a summary, and return True if the suite passed."""
    print("=" * 60)
    print("Zenith-7B Model Test Suite")
    print("=" * 60)

    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(TestZenith7B)

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    print("\n" + "=" * 60)
    print("Test Summary:")
    print(f"  Tests run: {result.testsRun}")
    print(f"  Failures: {len(result.failures)}")
    print(f"  Errors: {len(result.errors)}")
    print(f"  Skipped: {len(result.skipped)}")
    print(f"  Success: {result.wasSuccessful()}")
    print("=" * 60)

    return result.wasSuccessful()


if __name__ == "__main__":
    success = run_tests()
    sys.exit(0 if success else 1)