|
|
| """Test EQ engine implementation for all Zenith models."""
|
|
|
| import sys
|
| import torch
|
|
|
def test_model(model_name, config_module, model_module, *, batch_size=1, seq_len=8):
    """Smoke-test one Zenith variant with every EQ-engine feature switched on.

    Builds a small configuration (2 layers, 512 hidden units) of
    ``config_module.ZenithConfig`` with all ``use_eq_*`` flags enabled,
    instantiates ``model_module.ZenithModel`` from it, then runs one
    training-mode forward pass (with ``labels``, so a loss may be produced)
    and one ``torch.no_grad()`` inference pass on random token ids. Progress
    is printed to stdout.

    Args:
        model_name: Human-readable label, used only in printed output.
        config_module: Module exposing a ``ZenithConfig`` class.
        model_module: Module exposing a ``ZenithModel`` class.
        batch_size: Number of sequences in the random input batch.
        seq_len: Number of token ids per random input sequence.

    Returns:
        True if construction and both forward passes completed without
        raising; False otherwise (the exception and traceback are printed).
    """
    import traceback

    print(f"\n{'=' * 60}")
    print(f"Testing {model_name}...")
    print(f"{'=' * 60}")

    try:
        num_heads = 8
        head_dim = 64
        # Fixed, mutually consistent sizes. The keywords are passed to
        # ZenithConfig unconditionally, so probing the class with hasattr()
        # cannot make them optional. hasattr() is also False for configs that
        # only assign fields in __init__, which would swap in far larger
        # fallback sizes and inflate this "tiny" smoke-test model.
        hidden_size = num_heads * head_dim
        config = config_module.ZenithConfig(
            use_eq_adapter=True,
            use_eq_attention_bias=True,
            use_eq_gated_ffn=True,
            use_eq_recurrence=True,
            eq_consistency_weight=0.02,
            eq_state_dim=256,
            num_layers=2,
            hidden_size=hidden_size,
            num_heads=num_heads,
            head_dim=head_dim,
            intermediate_size=4 * hidden_size,
        )

        model = model_module.ZenithModel(config)
        print("[OK] Model created successfully")
        print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

        input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

        # Training-mode pass: labels are the inputs themselves, which is
        # enough to exercise the loss path (including any EQ consistency term).
        model.train()
        outputs = model(input_ids=input_ids, labels=input_ids)
        print("[OK] Forward pass successful")
        print(f" Logits shape: {outputs.logits.shape}")
        if outputs.loss is not None:
            print(f" Loss: {outputs.loss.item():.4f}")

        # Inference pass without gradient tracking.
        model.eval()
        with torch.no_grad():
            outputs = model(input_ids=input_ids)
        print("[OK] Inference successful")
        print(f" Logits shape: {outputs.logits.shape}")

        print(f"[SUCCESS] {model_name} EQ Engine is FULLY FUNCTIONAL")
        return True

    except Exception as e:  # test boundary: report any failure and keep going
        print(f"[FAIL] {model_name} failed:")
        print(f" Error: {type(e).__name__}: {e}")
        traceback.print_exc()
        return False


# Despite its ``test_`` prefix this is a helper invoked by main(), not a pytest
# test: pytest would try to resolve its positional parameters as fixtures.
test_model.__test__ = False
|
|
|
# (results key, display name, package) for every Zenith variant main() checks.
_MODEL_VARIANTS = (
    ("7B", "Zenith-7B", "Zenith.V1_7B"),
    ("28B", "Zenith-28B-p300", "Zenith.V1_Tenstorrent_Blackhole_p300_28B"),
    ("32B", "Zenith-32B-p300", "Zenith.V1_Tenstorrent_Blackhole_p300_32B"),
    ("70B", "Zenith-70B-p300", "Zenith.V1_Tenstorrent_Blackhole_p300_70B"),
)


def main():
    """Run the EQ-engine smoke test on every Zenith variant and print a summary.

    For each entry in ``_MODEL_VARIANTS`` the ``configs`` and
    ``modeling_zenith`` submodules of the variant's package are imported and
    handed to ``test_model``. A variant whose modules fail to import is
    recorded as failed and the remaining variants are still tested.

    Returns:
        0 if every variant passed, 1 otherwise (suitable for ``sys.exit``).
    """
    import importlib

    print("Testing EQ Engine Implementation for All Zenith Models")
    print("=" * 60)

    results = {}
    for key, display_name, package in _MODEL_VARIANTS:
        try:
            configs = importlib.import_module(f"{package}.configs")
            modeling = importlib.import_module(f"{package}.modeling_zenith")
        except Exception as e:  # any import-time failure disqualifies only this variant
            print(f"[FAIL] {key} model import error: {e}")
            results[key] = False
        else:
            results[key] = test_model(display_name, configs, modeling)

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for key, success in results.items():
        status = "[PASS]" if success else "[FAIL]"
        print(f"{status} {key}")

    if all(results.values()):
        print("\n[SUCCESS] All models have functional EQ Engine implementation!")
        return 0

    print("\n[WARNING] Some models failed. Please review errors above.")
    return 1
|
|
|
if __name__ == "__main__":
    # Use main()'s 0/1 result as the process exit status.
    raise SystemExit(main())