# Zenith-28b-p300-V1 / test_model.py
# Uploaded by Zandy-Wandy: "Upload Zenith-28b-V1-Tenstorrent-Blackhole-p300 model"
# (commit 8944ef7, verified)
#!/usr/bin/env python3
"""Test script for Zenith-28B-p300 model"""
import torch
import unittest
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent))
from configs.zenith_config import get_28b_p300_config
from models.zenith_model import ZenithForCausalLM, ZenithModel
from data.advanced_tokenizer import AdvancedTokenizer
class TestZenith28B(unittest.TestCase):
    """Test suite for the Zenith-28B-p300 model.

    Builds one reduced-size model per class (smaller hidden size / fewer
    layers) so the suite is runnable without p300 hardware, then checks
    forward pass, loss, generation, gradient flow, and the p300-specific
    configuration flags.
    """

    @classmethod
    def setUpClass(cls):
        """Set up shared fixtures: device, config, down-scaled model, tokenizer."""
        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cls.config = get_28b_p300_config()
        cls.config.vocab_size = 32000  # test vocab size; inputs are sampled from this range
        # Create a small test model (reduced size so tests run quickly on CPU).
        test_config = get_28b_p300_config()
        test_config.hidden_size = 1024  # Smaller for testing
        test_config.num_layers = 12
        test_config.num_heads = 16
        test_config.intermediate_size = 4096
        # BUG FIX: the model's vocab size must match cls.config.vocab_size,
        # which every test uses to sample input ids. Without this line the
        # model is built with the full config's default vocab, so sampled ids
        # could index past (or under-use) the embedding table.
        test_config.vocab_size = cls.config.vocab_size
        # NOTE(review): ZenithForCausalLM is imported but unused, yet the tests
        # expect .logits/.loss/.generate on the model — confirm ZenithModel
        # really exposes the causal-LM interface, or switch to ZenithForCausalLM.
        cls.model = ZenithModel(test_config)
        cls.model.to(cls.device)
        cls.model.eval()
        cls.tokenizer = AdvancedTokenizer(vocab_size=32000)

    def test_model_creation(self):
        """Model instantiates and exposes its transformer backbone."""
        self.assertIsNotNone(self.model)
        self.assertTrue(hasattr(self.model, 'transformer'))

    def test_forward_pass(self):
        """Forward pass returns logits shaped (batch, seq, ...)."""
        batch_size = 2
        seq_len = 64
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
        attention_mask = torch.ones(batch_size, seq_len).to(self.device)
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        self.assertIsNotNone(outputs.logits)
        self.assertEqual(outputs.logits.shape[0], batch_size)
        self.assertEqual(outputs.logits.shape[1], seq_len)

    def test_p300_optimizations(self):
        """p300/Tenstorrent-specific flags carry their expected values."""
        self.assertTrue(self.config.use_tenstorrent_optimizations)
        self.assertEqual(self.config.tensor_parallel_size, 8)
        self.assertEqual(self.config.pipeline_parallel_size, 4)
        self.assertTrue(self.config.noc_optimization)
        self.assertTrue(self.config.use_ring_attention)
        self.assertEqual(self.config.max_seq_len, 32768)

    def test_moe_configuration(self):
        """MoE settings: multiple experts, top-2 routing, positive balancing weight."""
        self.assertGreater(self.config.num_experts, 1)
        self.assertEqual(self.config.moe_top_k, 2)
        self.assertGreater(self.config.moe_load_balancing_weight, 0)

    def test_generation(self):
        """Sampling-based generation yields a string longer than the prompt."""
        prompt = "Explain the concept of recursion."
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=50,
                temperature=0.7,
                do_sample=True
            )
        generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        self.assertIsInstance(generated, str)
        self.assertGreater(len(generated), len(prompt))

    def test_loss_computation(self):
        """Passing labels produces a finite scalar loss."""
        batch_size = 2
        seq_len = 64
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
        labels = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
        attention_mask = torch.ones(batch_size, seq_len).to(self.device)
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        self.assertIsNotNone(outputs.loss)
        self.assertTrue(torch.isfinite(outputs.loss))

    def test_ring_attention_parameters(self):
        """Ring attention is enabled with the expected chunk/overlap sizes."""
        self.assertTrue(self.config.use_ring_attention)
        self.assertEqual(self.config.ring_attention_chunk_size, 8192)
        self.assertEqual(self.config.ring_attention_overlap, 2048)

    def test_eq_adapter(self):
        """EQ adapter is enabled with the expected hidden size."""
        self.assertTrue(self.config.use_eq_adapter)
        self.assertEqual(self.config.eq_adapter_hidden_size, 64)

    def test_gradient_flow(self):
        """backward() populates gradients on at least one trainable parameter."""
        batch_size = 1
        seq_len = 32
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
        labels = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
        self.model.train()
        try:
            outputs = self.model(input_ids=input_ids, labels=labels)
            loss = outputs.loss
            loss.backward()
            has_grad = any(p.grad is not None for p in self.model.parameters() if p.requires_grad)
            self.assertTrue(has_grad)
        finally:
            # BUG FIX: the model is shared at class level. Restore eval mode and
            # clear gradients so this test cannot leak train-mode behavior
            # (dropout etc.) or stale grads into tests that run after it.
            self.model.zero_grad(set_to_none=True)
            self.model.eval()
def run_tests():
    """Run the Zenith-28B test suite and print a summary.

    Returns:
        bool: True when every test passed, False otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("Zenith-28B-p300 Model Test Suite")
    print(banner)

    # Load the suite directly from the TestCase and run it verbosely.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestZenith28B)
    outcome = unittest.TextTestRunner(verbosity=2).run(suite)

    print("\n" + banner)
    print("Test Summary:")
    print(f" Tests run: {outcome.testsRun}")
    print(f" Failures: {len(outcome.failures)}")
    print(f" Errors: {len(outcome.errors)}")
    print(f" Success: {outcome.wasSuccessful()}")
    print(banner)
    return outcome.wasSuccessful()
if __name__ == "__main__":
    # Exit 0 on a fully green suite, 1 otherwise (CI-friendly).
    sys.exit(0 if run_tests() else 1)