File size: 3,928 Bytes
f64dfb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python3
"""
Raw BitTransformerLM Generation - Bypass Parity
"""

import sys
import torch
import torch.nn.functional as F

sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits

def load_model():
    """Construct the BitTransformerLM, restore the best checkpoint, and return it.

    Returns:
        tuple: (model, loss) — the model switched to eval mode, and the
        training loss recorded in the checkpoint.
    """
    net = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,
        use_autocast=False,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )

    # NOTE(review): torch.load unpickles arbitrary objects — only point this
    # at trusted checkpoint files.
    ckpt = torch.load(
        '/data/BitTransformerLM/checkpoints/checkpoint_best.pt',
        map_location='cpu',
    )
    net.load_state_dict(ckpt['model_state_dict'])
    net.eval()

    return net, ckpt['loss']

def bits_to_ascii_raw(bits):
    """Decode a flat bit list into ASCII text with no parity checking.

    Bytes are assembled MSB-first, 8 bits at a time.  Printable ASCII
    (32–126), newline (10), and carriage return (13) decode normally;
    any other byte value becomes the U+FFFD replacement character.

    Args:
        bits: sequence of 0/1 ints; a trailing partial byte is zero-padded.

    Returns:
        str: one character per (possibly padded) byte.
    """
    leftover = len(bits) % 8
    if leftover:
        # Zero-pad the tail so the final partial byte still decodes.
        bits = bits + [0] * (8 - leftover)

    decoded = []
    for start in range(0, len(bits), 8):
        # Shift-accumulate the byte, most significant bit first.
        value = 0
        for bit in bits[start:start + 8]:
            value = (value << 1) | bit

        if 32 <= value <= 126 or value in (10, 13):
            decoded.append(chr(value))
        else:
            decoded.append('�')  # replacement for non-printable

    return ''.join(decoded)

def generate_raw(model, prompt, num_bits=72):  # default: 9 bytes worth
    """Autoregressively sample `num_bits` bits after `prompt`, decode as raw ASCII.

    Args:
        model: a BitTransformerLM returning (logits, telemetry) per forward call.
        prompt: seed text, converted to bits via text_to_bits.
        num_bits: number of new bits to sample (default 72 = 9 bytes).

    Returns:
        str: the decoded text of only the newly generated bits.
    """
    print(f"\n🎯 Generating {num_bits} bits from: '{prompt}'")

    input_bits = text_to_bits(prompt)
    print(f"Input: {len(input_bits)} bits")

    generated_bits = input_bits.copy()

    # Bug fix: `telemetry` was previously assigned only inside the loop, so
    # the report below raised NameError when num_bits <= 0.
    telemetry = None
    # Hoisted loop invariant; lower temperature => more coherent output.
    temperature = 0.6

    with torch.no_grad():
        for i in range(num_bits):
            # Keep at most the last 400 bits of context (model max_seq_len
            # is 512).  A negative-start slice already handles short lists.
            context_bits = generated_bits[-400:]
            context_tensor = torch.tensor(context_bits, dtype=torch.long).unsqueeze(0)

            logits, telemetry = model(context_tensor)
            next_bit_logits = logits[0, -1, :] / temperature
            probs = F.softmax(next_bit_logits, dim=-1)
            next_bit = torch.multinomial(probs, 1).item()

            generated_bits.append(next_bit)

            # Progress update every 16 bits (2 bytes).
            if (i + 1) % 16 == 0:
                generated_only = generated_bits[len(input_bits):]
                partial_text = bits_to_ascii_raw(generated_only)
                print(f"  {i+1:2d} bits: '{partial_text}'")

    # Final decode of only the newly sampled bits.
    generated_only = generated_bits[len(input_bits):]
    final_text = bits_to_ascii_raw(generated_only)

    print(f"✨ Final: '{prompt}' + '{final_text}'")

    # Report the telemetry from the last forward pass, if any was produced.
    if telemetry:
        k = telemetry.get('negentropy_logits', 0)
        c = telemetry.get('lz_complexity_logits', 0)
        s = telemetry.get('symbiosis_score', 0)
        # Telemetry entries may be tensors; reduce each to a scalar mean.
        if torch.is_tensor(k): k = k.mean().item()
        if torch.is_tensor(c): c = c.mean().item()
        if torch.is_tensor(s): s = s.mean().item()
        print(f"📊 Telemetry: K={k:.3f}, C={c:.3f}, S={s:.3f}")

    return final_text

def main():
    """Load the model and run raw generation over a fixed set of prompts."""
    print("🚀 RAW BITRANSFORMERLM GENERATION")
    print("=" * 40)

    model, loss = load_model()
    print(f"✅ Model loaded! Loss: {loss:.6f}")

    for prompt in (
        "Hello",
        "Hi there",
        "What",
        "The weather",
        "AI:",
        "Q: What is your name?\nA:",
    ):
        generate_raw(model, prompt, num_bits=64)  # 8 characters worth


if __name__ == "__main__":
    main()