#!/usr/bin/env python3
"""Test script for Zenith-7B model"""
import torch
import unittest
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent))
from configs.zenith_config import get_7b_config
from models.zenith_model import ZenithForCausalLM, ZenithModel
from data.advanced_tokenizer import AdvancedTokenizer
class TestZenith7B(unittest.TestCase):
    """Test suite for the Zenith-7B model.

    A single model/tokenizer pair is built once in setUpClass and shared by
    all tests; per-test state (train/eval mode, gradients) is reset so tests
    stay independent of execution order.
    """

    @classmethod
    def setUpClass(cls):
        """Build shared fixtures: device, config, model, and tokenizer."""
        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cls.config = get_7b_config()
        cls.config.vocab_size = 32000  # Test vocab size
        # NOTE(review): despite the original "small test model" comment this
        # instantiates the full 7B config — confirm get_7b_config() is cheap
        # enough for CI, or shrink layers/hidden sizes here.
        cls.model = ZenithModel(cls.config)
        cls.model.to(cls.device)
        cls.model.eval()
        # Create tokenizer
        cls.tokenizer = AdvancedTokenizer(vocab_size=32000)

    def setUp(self):
        """Reset shared-model state before each test.

        test_gradient_flow switches the shared model to train mode; the
        original suite leaked that mode into later tests. Restoring eval
        here keeps tests order-independent.
        """
        self.model.eval()

    def test_model_creation(self):
        """Test model can be created."""
        self.assertIsNotNone(self.model)
        self.assertTrue(hasattr(self.model, 'transformer'))

    def test_forward_pass(self):
        """Forward pass returns logits of shape (batch, seq, vocab)."""
        batch_size = 2
        seq_len = 32
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
        attention_mask = torch.ones(batch_size, seq_len).to(self.device)
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        self.assertIsNotNone(outputs.logits)
        self.assertEqual(outputs.logits.shape[0], batch_size)
        self.assertEqual(outputs.logits.shape[1], seq_len)
        self.assertEqual(outputs.logits.shape[2], self.config.vocab_size)

    def test_moe_activation(self):
        """Test MoE layers are active when configured."""
        if self.config.num_experts > 1:
            # Any submodule exposing num_experts is treated as an MoE layer.
            moe_layers = [m for m in self.model.modules() if hasattr(m, 'num_experts')]
            self.assertGreater(len(moe_layers), 0)

    def test_generation(self):
        """Test text generation."""
        # NOTE(review): generate() is called on ZenithModel, but the file also
        # imports ZenithForCausalLM — verify the base model actually exposes
        # generate(), or switch this fixture to the causal-LM wrapper.
        prompt = "Hello, world!"
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=20,
                temperature=0.8,
                do_sample=True
            )
        generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        self.assertIsInstance(generated, str)
        self.assertGreater(len(generated), len(prompt))

    def test_loss_computation(self):
        """Test loss computation with labels."""
        batch_size = 2
        seq_len = 32
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
        labels = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
        attention_mask = torch.ones(batch_size, seq_len).to(self.device)
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        self.assertIsNotNone(outputs.loss)
        self.assertTrue(torch.isfinite(outputs.loss))

    def test_multi_task_outputs(self):
        """Test multi-task learning outputs when EQ adapter is enabled."""
        if self.config.use_eq_adapter:
            batch_size = 2
            seq_len = 32
            input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
            attention_mask = torch.ones(batch_size, seq_len).to(self.device)
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            # Fixed: the original `hasattr(...) or outputs.X is not None` passed
            # when the attribute existed but was None, and raised AttributeError
            # (instead of failing cleanly) when it was missing. getattr with a
            # None default checks existence AND non-None in one assertion.
            self.assertIsNotNone(getattr(outputs, 'emotion_logits', None))
            self.assertIsNotNone(getattr(outputs, 'frustration_logits', None))

    def test_gradient_flow(self):
        """Test gradients flow correctly."""
        batch_size = 1
        seq_len = 16
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
        labels = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
        self.model.train()
        # Restore eval mode even if the test fails, so later tests see the
        # shared model in its expected state.
        self.addCleanup(self.model.eval)
        # Clear any gradients left by earlier tests so has_grad below reflects
        # only this backward pass.
        self.model.zero_grad(set_to_none=True)
        outputs = self.model(input_ids=input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()
        # Check that gradients exist
        has_grad = any(p.grad is not None for p in self.model.parameters() if p.requires_grad)
        self.assertTrue(has_grad)
def run_tests():
    """Execute the Zenith-7B suite with a verbose runner and print a summary.

    Returns:
        bool: True when every test passed, False otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("Zenith-7B Model Test Suite")
    print(banner)

    # Build and run the suite in one pass.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestZenith7B)
    result = unittest.TextTestRunner(verbosity=2).run(suite)

    # Human-readable summary footer.
    print("\n" + banner)
    print("Test Summary:")
    print(f" Tests run: {result.testsRun}")
    print(f" Failures: {len(result.failures)}")
    print(f" Errors: {len(result.errors)}")
    print(f" Success: {result.wasSuccessful()}")
    print(banner)
    return result.wasSuccessful()
if __name__ == "__main__":
    # Exit non-zero on any failure so CI pipelines detect broken tests.
    # (Fixed: a stray trailing "|" artifact on the sys.exit line was a
    # syntax error in the original.)
    success = run_tests()
    sys.exit(0 if success else 1)