# Zenith-7b-V1 / test_all_models_eq.py
# Uploaded by Zandy-Wandy ("Upload Zenith-7B model", commit 1ea8a03, verified)
#!/usr/bin/env python3
"""Test EQ engine implementation for all Zenith models."""
import sys
import torch
def test_model(model_name, config_module, model_module):
"""Test a specific model configuration."""
print(f"\n{'='*60}")
print(f"Testing {model_name}...")
print(f"{'='*60}")
try:
# Create config with all EQ features enabled
config = config_module.ZenithConfig(
use_eq_adapter=True,
use_eq_attention_bias=True,
use_eq_gated_ffn=True,
use_eq_recurrence=True,
eq_consistency_weight=0.02,
eq_state_dim=256,
num_layers=2, # Small for testing
hidden_size=512 if hasattr(config_module.ZenithConfig, 'hidden_size') else 3072,
num_heads=8,
head_dim=64,
intermediate_size=2048 if hasattr(config_module.ZenithConfig, 'intermediate_size') else 8192
)
# Create model
model = model_module.ZenithModel(config)
print(f"[OK] Model created successfully")
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
# Test forward pass
batch_size = 1
seq_len = 8
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
# Training mode to test consistency loss
model.train()
outputs = model(input_ids=input_ids, labels=input_ids)
print(f"[OK] Forward pass successful")
print(f" Logits shape: {outputs.logits.shape}")
if outputs.loss is not None:
print(f" Loss: {outputs.loss.item():.4f}")
# Test inference mode
model.eval()
with torch.no_grad():
outputs = model(input_ids=input_ids)
print(f"[OK] Inference successful")
print(f" Logits shape: {outputs.logits.shape}")
print(f"[SUCCESS] {model_name} EQ Engine is FULLY FUNCTIONAL")
return True
except Exception as e:
print(f"[FAIL] {model_name} failed:")
print(f" Error: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
return False
def main():
    """Run the EQ-engine smoke test for every Zenith model variant.

    Imports each variant's ``configs``/``modeling_zenith`` modules, runs
    ``test_model`` on it, and prints a pass/fail summary.

    Returns:
        0 when every variant passes, 1 otherwise (suitable for sys.exit).
    """
    import importlib

    print("Testing EQ Engine Implementation for All Zenith Models")
    print("=" * 60)

    # (results key, display name, package holding configs/modeling_zenith).
    # Data-driven instead of four copy-pasted try/import stanzas, so adding a
    # variant is a one-line change and the blocks cannot drift apart.
    model_specs = [
        ("7B", "Zenith-7B", "Zenith.V1_7B"),
        ("28B", "Zenith-28B-p300", "Zenith.V1_Tenstorrent_Blackhole_p300_28B"),
        ("32B", "Zenith-32B-p300", "Zenith.V1_Tenstorrent_Blackhole_p300_32B"),
        ("70B", "Zenith-70B-p300", "Zenith.V1_Tenstorrent_Blackhole_p300_70B"),
    ]

    results = {}
    for key, display_name, package in model_specs:
        try:
            config_module = importlib.import_module(f"{package}.configs")
            model_module = importlib.import_module(f"{package}.modeling_zenith")
            results[key] = test_model(display_name, config_module, model_module)
        except Exception as e:
            # Import failures are reported but do not abort the other variants.
            print(f"[FAIL] {key} model import error: {e}")
            results[key] = False

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for model_name, success in results.items():
        status = "[PASS]" if success else "[FAIL]"
        print(f"{status} {model_name}")

    if all(results.values()):
        print("\n[SUCCESS] All models have functional EQ Engine implementation!")
        return 0
    print("\n[WARNING] Some models failed. Please review errors above.")
    return 1
# Script entry point: exit status mirrors main()'s 0/1 result.
if __name__ == "__main__":
    raise SystemExit(main())