File size: 4,510 Bytes
1ea8a03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
"""Test EQ engine implementation for all Zenith models."""

import importlib
import sys

import torch

def test_model(model_name, config_module, model_module):
    """Test a specific model configuration."""
    print(f"\n{'='*60}")
    print(f"Testing {model_name}...")
    print(f"{'='*60}")
    
    try:
        # Create config with all EQ features enabled
        config = config_module.ZenithConfig(
            use_eq_adapter=True,
            use_eq_attention_bias=True,
            use_eq_gated_ffn=True,
            use_eq_recurrence=True,
            eq_consistency_weight=0.02,
            eq_state_dim=256,
            num_layers=2,  # Small for testing
            hidden_size=512 if hasattr(config_module.ZenithConfig, 'hidden_size') else 3072,
            num_heads=8,
            head_dim=64,
            intermediate_size=2048 if hasattr(config_module.ZenithConfig, 'intermediate_size') else 8192
        )
        
        # Create model
        model = model_module.ZenithModel(config)
        print(f"[OK] Model created successfully")
        print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")
        
        # Test forward pass
        batch_size = 1
        seq_len = 8
        input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
        
        # Training mode to test consistency loss
        model.train()
        outputs = model(input_ids=input_ids, labels=input_ids)
        
        print(f"[OK] Forward pass successful")
        print(f"  Logits shape: {outputs.logits.shape}")
        if outputs.loss is not None:
            print(f"  Loss: {outputs.loss.item():.4f}")
        
        # Test inference mode
        model.eval()
        with torch.no_grad():
            outputs = model(input_ids=input_ids)
            print(f"[OK] Inference successful")
            print(f"  Logits shape: {outputs.logits.shape}")
        
        print(f"[SUCCESS] {model_name} EQ Engine is FULLY FUNCTIONAL")
        return True
        
    except Exception as e:
        print(f"[FAIL] {model_name} failed:")
        print(f"  Error: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return False

def main():
    """Import each Zenith variant and smoke-test its EQ engine.

    The four variants differ only in package path, result key, and display
    name, so they are driven from one table instead of four copy-pasted
    try/except blocks (the original's duplication). Import failures are
    reported and recorded as failures without aborting the other variants.

    Returns:
        0 when every model passes, 1 otherwise (suitable for sys.exit).
    """
    print("Testing EQ Engine Implementation for All Zenith Models")
    print("="*60)

    results = {}

    # (result key, package path, display name) for every model variant.
    variants = [
        ("7B", "Zenith.V1_7B", "Zenith-7B"),
        ("28B", "Zenith.V1_Tenstorrent_Blackhole_p300_28B", "Zenith-28B-p300"),
        ("32B", "Zenith.V1_Tenstorrent_Blackhole_p300_32B", "Zenith-32B-p300"),
        ("70B", "Zenith.V1_Tenstorrent_Blackhole_p300_70B", "Zenith-70B-p300"),
    ]
    for key, package, label in variants:
        try:
            configs = importlib.import_module(f"{package}.configs")
            modeling = importlib.import_module(f"{package}.modeling_zenith")
            results[key] = test_model(label, configs, modeling)
        except Exception as e:
            # Missing/broken package: record the failure and keep going.
            print(f"[FAIL] {key} model import error: {e}")
            results[key] = False

    # Summary
    print("\n" + "="*60)
    print("SUMMARY")
    print("="*60)
    for model_name, success in results.items():
        status = "[PASS]" if success else "[FAIL]"
        print(f"{status} {model_name}")

    all_passed = all(results.values())
    if all_passed:
        print("\n[SUCCESS] All models have functional EQ Engine implementation!")
        return 0
    else:
        print("\n[WARNING] Some models failed. Please review errors above.")
        return 1

# Script entry point: propagate main()'s status code (0 = all models
# passed, 1 = at least one failure) to the shell via the exit code.
if __name__ == "__main__":
    sys.exit(main())