File size: 4,208 Bytes
42dd387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python3
"""
Simple BitTransformerLM Test - No Interactive Input
"""

import sys
import torch
import torch.nn.functional as F

# Add paths for imports
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text

def _build_model():
    """Instantiate BitTransformerLM with the exact breakthrough config, tuned for inference."""
    return BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,  # activation checkpointing only helps training; disable for inference
        use_autocast=False,    # full precision for deterministic inference
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )


def _scalar(value):
    """Collapse a tensor telemetry value to a Python float; pass non-tensors through."""
    return value.mean().item() if torch.is_tensor(value) else value


def _print_telemetry(telemetry):
    """Print the K/C/S telemetry triple, if the model returned any telemetry."""
    if not telemetry:
        return
    k_val = _scalar(telemetry.get('negentropy_logits', 0))
    c_val = _scalar(telemetry.get('lz_complexity_logits', 0))
    s_val = _scalar(telemetry.get('symbiosis_score', 0))
    print(f"πŸ“Š Telemetry: K={k_val:.3f}, C={c_val:.3f}, S={s_val:.3f}")


def _generate_bits(model, seed_bits, n_bits, temperature, max_context):
    """Autoregressively sample ``n_bits`` new bits after ``seed_bits``.

    The context fed to the model is truncated to the last ``max_context`` bits
    so the sequence stays within the model's max_seq_len.  Returns only the
    newly generated bits (the seed is stripped).
    """
    bits = list(seed_bits)
    for _ in range(n_bits):
        context = torch.tensor(bits[-max_context:], dtype=torch.long).unsqueeze(0)
        logits, _ = model(context)
        # Temperature-scaled sampling over the 2-way bit distribution.
        probs = F.softmax(logits[0, -1, :] / temperature, dim=-1)
        bits.append(torch.multinomial(probs, 1).item())
    return bits[len(seed_bits):]


def test_breakthrough_model(
    checkpoint_path='/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
    prompts=None,
    gen_bits=18,
    temperature=0.8,
    max_context=500,
):
    """Simple non-interactive smoke test of the breakthrough model.

    Loads the checkpoint, runs a forward pass on each prompt, reports the
    next-bit distribution and telemetry, then samples a short continuation
    and attempts to decode it back to text.

    Args:
        checkpoint_path: Path to the ``torch.save`` checkpoint to load.
        prompts: Optional list of prompt strings; defaults to the built-in set.
        gen_bits: Number of bits to sample per prompt (18 bits = 2 characters).
        temperature: Softmax temperature used when sampling bits.
        max_context: Maximum number of trailing bits fed back to the model.
    """
    print("πŸš€ Loading breakthrough BitTransformerLM...")

    model = _build_model()

    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"βœ… Model loaded! Loss: {checkpoint['loss']:.6f}")

    if prompts is None:
        prompts = [
            "Hello",
            "Hi there",
            "What is your name?",
            "The weather is",
        ]

    for prompt in prompts:
        print(f"\nπŸ€– Testing: '{prompt}'")

        # Convert the prompt to its bit-level representation.
        input_bits = text_to_bits(prompt)
        input_tensor = torch.tensor(input_bits, dtype=torch.long).unsqueeze(0)

        print(f"πŸ“ Input: {len(input_bits)} bits")

        with torch.no_grad():
            try:
                # Single forward pass for the next-bit distribution + telemetry.
                logits, telemetry = model(input_tensor)
                next_probs = F.softmax(logits[0, -1, :], dim=-1)
                print(f"🎯 Next bit probs: [0]={next_probs[0]:.3f}, [1]={next_probs[1]:.3f}")

                _print_telemetry(telemetry)

                # Sample a short continuation and try to decode it.
                generated_only = _generate_bits(
                    model, input_bits, gen_bits, temperature, max_context
                )
                try:
                    generated_text = bits_to_text(generated_only)
                    print(f"✨ Generated: '{generated_text}'")
                except Exception as e:
                    print(f"πŸ”§ Decode failed: {e}")
                    print(f"Raw bits: {generated_only}")

            except Exception as e:
                # Broad catch is deliberate: keep the smoke test running across prompts.
                print(f"❌ Model error: {e}")

# Run the smoke test when executed directly as a script (no interactive input needed).
if __name__ == "__main__":
    test_breakthrough_model()