#!/usr/bin/env python3
"""Test EQ engine implementation for all Zenith models.

For each Zenith variant, builds a small model with every EQ feature enabled,
runs a training-mode forward pass (with labels) and a no-grad inference pass,
and prints a PASS/FAIL summary. The process exit code is 0 only if every
variant passed.
"""
import importlib
import sys
import traceback

import torch

# (result key, display name, package path) for every model variant under test.
# Each package is expected to provide `configs` and `modeling_zenith` submodules.
_MODELS = (
    ("7B", "Zenith-7B", "Zenith.V1_7B"),
    ("28B", "Zenith-28B-p300", "Zenith.V1_Tenstorrent_Blackhole_p300_28B"),
    ("32B", "Zenith-32B-p300", "Zenith.V1_Tenstorrent_Blackhole_p300_32B"),
    ("70B", "Zenith-70B-p300", "Zenith.V1_Tenstorrent_Blackhole_p300_70B"),
)


def test_model(model_name, config_module, model_module, *, seed=0):
    """Build a small model with all EQ features enabled and smoke-test it.

    Runs one training-mode forward pass with ``labels`` (so a loss is
    computed) and one ``torch.no_grad()`` inference pass, printing progress.

    Args:
        model_name: Human-readable label, used only in printed output.
        config_module: Module exposing a ``ZenithConfig`` class.
        model_module: Module exposing a ``ZenithModel`` class built from the config.
        seed: Passed to ``torch.manual_seed`` so the random weights and
            input ids are reproducible between runs.

    Returns:
        True if every step succeeded, False otherwise. On failure the error
        and traceback are printed instead of raised.
    """
    print(f"\n{'=' * 60}")
    print(f"Testing {model_name}...")
    print(f"{'=' * 60}")

    try:
        torch.manual_seed(seed)
        config_cls = config_module.ZenithConfig

        # NOTE(review): hasattr() on the class only sees class-level attributes
        # (e.g. dataclass field defaults). A config class that sets these only
        # inside __init__ gets the large 3072/8192 fallback sizes instead of the
        # small test sizes. Confirm that fallback is intended.
        hidden_size = 512 if hasattr(config_cls, "hidden_size") else 3072
        intermediate_size = (
            2048 if hasattr(config_cls, "intermediate_size") else 8192
        )

        # Create config with all EQ features enabled.
        config = config_cls(
            use_eq_adapter=True,
            use_eq_attention_bias=True,
            use_eq_gated_ffn=True,
            use_eq_recurrence=True,
            eq_consistency_weight=0.02,
            eq_state_dim=256,
            num_layers=2,  # Small for testing
            hidden_size=hidden_size,
            num_heads=8,
            head_dim=64,
            intermediate_size=intermediate_size,
        )

        model = model_module.ZenithModel(config)
        print("[OK] Model created successfully")
        print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

        batch_size = 1
        seq_len = 8
        input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

        # Training mode with labels so the loss path (including the EQ
        # consistency term, if the model adds one) is exercised.
        model.train()
        outputs = model(input_ids=input_ids, labels=input_ids)
        print("[OK] Forward pass successful")
        print(f" Logits shape: {outputs.logits.shape}")
        if outputs.loss is not None:
            print(f" Loss: {outputs.loss.item():.4f}")
            # A NaN/inf loss means the model is not functional even though the
            # forward pass returned; fail through the normal error path below.
            if not torch.isfinite(outputs.loss).all():
                raise RuntimeError("training loss is not finite")

        # Inference mode.
        model.eval()
        with torch.no_grad():
            outputs = model(input_ids=input_ids)
        print("[OK] Inference successful")
        print(f" Logits shape: {outputs.logits.shape}")

        print(f"[SUCCESS] {model_name} EQ Engine is FULLY FUNCTIONAL")
        return True

    except Exception as e:
        # Top-level boundary for one model: report and let main() continue.
        print(f"[FAIL] {model_name} failed:")
        print(f" Error: {type(e).__name__}: {e}")
        traceback.print_exc()
        return False


# Keep pytest from collecting this helper as a test if the file is picked up by
# test discovery; its parameters would otherwise be treated as missing fixtures.
test_model.__test__ = False


def main():
    """Smoke-test every Zenith variant listed in ``_MODELS``.

    Returns:
        Process exit code: 0 if every variant imported and passed, else 1.
    """
    print("Testing EQ Engine Implementation for All Zenith Models")
    print("=" * 60)

    results = {}
    for key, display_name, package in _MODELS:
        try:
            configs = importlib.import_module(f"{package}.configs")
            modeling = importlib.import_module(f"{package}.modeling_zenith")
        except Exception as e:
            # Import-time failures of any kind mark the variant as failed
            # without aborting the remaining variants.
            print(f"[FAIL] {key} model import error: {e}")
            results[key] = False
        else:
            results[key] = test_model(display_name, configs, modeling)

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for key, passed in results.items():
        status = "[PASS]" if passed else "[FAIL]"
        print(f"{status} {key}")

    if all(results.values()):
        print("\n[SUCCESS] All models have functional EQ Engine implementation!")
        return 0
    print("\n[WARNING] Some models failed. Please review errors above.")
    return 1


if __name__ == "__main__":
    sys.exit(main())