WCNegentropy committed on
Commit
9202d01
ยท
verified ยท
1 Parent(s): 42dd387

🚀 Refined BitTransformerLM: Organized codebase with best practices

Browse files
Files changed (1) hide show
  1. scripts/testing/code_test.py +141 -0
scripts/testing/code_test.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test BitTransformerLM on Code/Math Completion
4
+ """
5
+
6
+ import sys
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ sys.path.append('/data')
11
+ sys.path.append('/data/BitTransformerLM')
12
+
13
+ from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text
14
+
15
def load_model(checkpoint_path='/data/BitTransformerLM/checkpoints/checkpoint_best.pt'):
    """Build a BitTransformerLM and restore weights from a checkpoint.

    Args:
        checkpoint_path: Path to a torch checkpoint file containing a
            'model_state_dict' entry. Defaults to the path the original
            script hard-coded, so existing callers are unaffected.

    Returns:
        The restored model, switched to eval() mode for inference.
    """
    # Architecture hyperparameters must match the checkpointed model;
    # presumably these mirror the training config — TODO confirm.
    model = BitTransformerLM(
        d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
        max_seq_len=512, reversible=True, use_checkpoint=False,
        use_autocast=False, use_act=True, act_threshold=0.9,
        lambda_K=0.05, lambda_C=0.05, lambda_S=0.05
    )

    # map_location='cpu' lets the checkpoint load on machines without a GPU.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model
28
+
29
def code_generate(model, prompt, max_chars=10):
    """Autoregressively complete *prompt* one 9-bit character at a time.

    Each character is encoded as 8 data bits (MSB first) plus one even-parity
    bit. Data bits are sampled from the model — greedily for the first three
    characters, then with temperature sampling — while the parity bit is
    computed from the data bits, never sampled. Generation stops early when a
    common code terminator character (';', braces, brackets, parens) appears.

    Args:
        model: Model returning (logits, telemetry) for a bit-sequence tensor.
        prompt: Text to complete; encoded via text_to_bits.
        max_chars: Maximum number of characters to generate.

    Returns:
        The generated completion string; '?' marks non-printable bytes.
    """
    print(f"\n๐Ÿงฎ Code completion for: '{prompt}'")

    input_bits = text_to_bits(prompt)
    generated_bits = input_bits.copy()

    results = []

    with torch.no_grad():
        for char_idx in range(max_chars):
            # Generate 9 bits for one character.
            char_bits = []
            # BUGFIX: the original printed softmax(max) of the *parity-step*
            # logits as "confidence", but the parity bit is computed, not
            # sampled, so that number was meaningless. Track the probability
            # of the last sampled data bit instead.
            last_confidence = 0.0

            for bit_idx in range(9):
                context = generated_bits + char_bits
                # Keep the context within the model's sequence budget.
                context = context[-400:] if len(context) > 400 else context
                context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)

                logits, telemetry = model(context_tensor)
                next_bit_logits = logits[0, -1, :]

                if bit_idx < 8:  # Data bits
                    # Lower temperature -> more deterministic output,
                    # which suits code/math completion.
                    temperature = 0.5
                    next_bit_logits = next_bit_logits / temperature
                    probs = F.softmax(next_bit_logits, dim=-1)

                    # Greedy for the first few characters to surface the
                    # model's single most likely continuation.
                    if char_idx < 3:
                        next_bit = torch.argmax(next_bit_logits).item()
                    else:
                        next_bit = torch.multinomial(probs, 1).item()
                    last_confidence = probs[next_bit].item()
                else:  # Parity bit: deterministic even parity over data bits.
                    data_bits = char_bits[:8]
                    expected_parity = sum(data_bits) % 2
                    next_bit = expected_parity

                char_bits.append(next_bit)

            # Commit this character's bits to the running context.
            generated_bits.extend(char_bits)

            # Decode the 8 data bits (MSB first) into a byte value.
            data_bits = char_bits[:8]
            byte_val = sum(bit * (2**(7-i)) for i, bit in enumerate(data_bits))

            if 32 <= byte_val <= 126:  # Printable ASCII
                char = chr(byte_val)
                print(f"  +'{char}' (confidence: {last_confidence:.3f})")
                results.append(char)

                # Stop on natural code endings.
                if char in ';{}()[]':
                    break
            else:
                print(f"  +[{byte_val}] (non-printable)")
                results.append('?')

    completion = ''.join(results)
    print(f"โœจ Result: '{prompt}' โ†’ '{prompt}{completion}'")

    return completion
92
+
93
def main():
    """Run the code/math completion test-suite against the trained model."""
    print("๐Ÿš€ BITRANSFORMERLM CODE/MATH COMPLETION TEST")
    print("=" * 50)

    model = load_model()
    print("โœ… Model loaded!")

    # Structured prompts that may match patterns seen in training.
    test_cases = [
        # Math equations
        "2 + 2 =",
        "1 + 1 =",
        "5 * 3 =",
        "10 / 2 =",

        # Simple code patterns
        "def hello():",
        "if x ==",
        "for i in",
        "print(",
        "return",

        # Simple patterns
        "a, b, c,",
        "1, 2, 3,",
        "red, blue,",

        # HTML/markup
        "<div>",
        "function(",
        "var x =",
    ]

    total = len(test_cases)
    print(f"\n๐Ÿงฎ Testing {total} code/math patterns:")

    for test_num, prompt in enumerate(test_cases, start=1):
        print(f"\n--- Test {test_num}/{total} ---")
        completion = code_generate(model, prompt, max_chars=6)

        # Quick qualitative analysis of the completion.
        has_alnum = any(ch.isalnum() for ch in completion)
        has_digit = any(ch in "0123456789" for ch in completion)
        has_symbol = any(ch in "=(){}[];," for ch in completion)

        if has_alnum:
            print("   ๐Ÿ“ Contains alphanumeric - GOOD!")
        if has_digit:
            print("   ๐Ÿ”ข Contains numbers - EXCELLENT!")
        if has_symbol:
            print("   ๐Ÿ’ป Contains code symbols - PROMISING!")

if __name__ == "__main__":
    main()