WCNegentropy commited on
Commit
189c75b
Β·
verified Β·
1 Parent(s): 7fe700d

πŸš€ Refined BitTransformerLM: Organized codebase with best practices

Browse files
Files changed (1) hide show
  1. scripts/examples/debug_generation.py +120 -0
scripts/examples/debug_generation.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Debug BitTransformerLM Generation
4
+ """
5
+
6
+ import sys
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ sys.path.append('/data')
11
+ sys.path.append('/data/BitTransformerLM')
12
+
13
+ from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text
14
+
15
+ def load_model():
16
+ model = BitTransformerLM(
17
+ d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
18
+ max_seq_len=512, reversible=True, use_checkpoint=False,
19
+ use_autocast=False, use_act=True, act_threshold=0.9,
20
+ lambda_K=0.05, lambda_C=0.05, lambda_S=0.05
21
+ )
22
+
23
+ checkpoint = torch.load('/data/BitTransformerLM/checkpoints/checkpoint_best.pt', map_location='cpu')
24
+ model.load_state_dict(checkpoint['model_state_dict'])
25
+ model.eval()
26
+
27
+ return model, checkpoint['loss']
28
+
29
+ def generate_longer(model, prompt, num_chars=10):
30
+ """Generate longer sequences."""
31
+ print(f"\n🎯 Generating {num_chars} characters from: '{prompt}'")
32
+
33
+ input_bits = text_to_bits(prompt)
34
+ print(f"Input: {len(input_bits)} bits")
35
+
36
+ generated_bits = input_bits.copy()
37
+
38
+ with torch.no_grad():
39
+ # Generate num_chars * 9 bits (9 bits per character with parity)
40
+ for i in range(num_chars * 9):
41
+ # Use last 400 bits to stay within context
42
+ context_bits = generated_bits[-400:] if len(generated_bits) > 400 else generated_bits
43
+ context_tensor = torch.tensor(context_bits, dtype=torch.long).unsqueeze(0)
44
+
45
+ logits, telemetry = model(context_tensor)
46
+ next_bit_logits = logits[0, -1, :]
47
+
48
+ # Temperature sampling
49
+ temperature = 0.7
50
+ next_bit_logits = next_bit_logits / temperature
51
+ probs = F.softmax(next_bit_logits, dim=-1)
52
+ next_bit = torch.multinomial(probs, 1).item()
53
+
54
+ generated_bits.append(next_bit)
55
+
56
+ # Try to decode every 9 bits
57
+ if (i + 1) % 9 == 0:
58
+ generated_only = generated_bits[len(input_bits):]
59
+ try:
60
+ partial_text = bits_to_text(generated_only)
61
+ print(f" After {(i+1)//9} chars: '{partial_text}'")
62
+ except:
63
+ pass
64
+
65
+ # Final decode
66
+ generated_only = generated_bits[len(input_bits):]
67
+ try:
68
+ final_text = bits_to_text(generated_only)
69
+ print(f"✨ Final result: '{prompt}' + '{final_text}'")
70
+ return final_text
71
+ except Exception as e:
72
+ print(f"❌ Final decode failed: {e}")
73
+ print(f"Generated {len(generated_only)} bits: {generated_only[:50]}...")
74
+
75
+ # Try to decode in chunks
76
+ print("πŸ”§ Trying chunk decoding...")
77
+ for chunk_size in [9, 18, 27]: # 1, 2, 3 characters
78
+ if len(generated_only) >= chunk_size:
79
+ try:
80
+ chunk_text = bits_to_text(generated_only[:chunk_size])
81
+ print(f" First {chunk_size//9} chars: '{chunk_text}'")
82
+ except Exception as ce:
83
+ print(f" {chunk_size//9} chars failed: {ce}")
84
+
85
+ return None
86
+
87
+ def test_bit_encoding():
88
+ """Test the bit encoding/decoding functions."""
89
+ print("\nπŸ”§ Testing bit encoding/decoding...")
90
+
91
+ test_strings = ["A", "AB", "Hello", "Hi there!"]
92
+
93
+ for s in test_strings:
94
+ bits = text_to_bits(s)
95
+ try:
96
+ decoded = bits_to_text(bits)
97
+ status = "βœ…" if decoded == s else "❌"
98
+ print(f"{status} '{s}' -> {len(bits)} bits -> '{decoded}'")
99
+ except Exception as e:
100
+ print(f"❌ '{s}' -> {len(bits)} bits -> ERROR: {e}")
101
+
102
+ def main():
103
+ print("πŸš€ BITRANSFORMERLM GENERATION DEBUG")
104
+ print("=" * 50)
105
+
106
+ # Test encoding first
107
+ test_bit_encoding()
108
+
109
+ # Load model
110
+ model, loss = load_model()
111
+ print(f"\nβœ… Model loaded! Loss: {loss:.6f}")
112
+
113
+ # Test generation
114
+ prompts = ["Hello", "Hi", "A", "The"]
115
+
116
+ for prompt in prompts:
117
+ generate_longer(model, prompt, num_chars=3)
118
+
119
+ if __name__ == "__main__":
120
+ main()