WCNegentropy commited on
Commit
72bc506
Β·
verified Β·
1 Parent(s): bde6dbb

πŸš€ Refined BitTransformerLM: Organized codebase with best practices

Browse files
scripts/testing/enhanced_generation_test.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Enhanced BitTransformerLM Generation Testing
4
+ =============================================
5
+
6
+ Test the promising generation improvements:
7
+ 1. Autoregressive generation with automatic parity correction
8
+ 2. Longer sequence generation (50, 100, 200+ characters)
9
+ 3. Optimized diffusion parameters (50+ steps)
10
+ 4. Direct comparison between generation methods
11
+
12
+ Goal: See if we can get from "barely-contextual gibberish" to actual language!
13
+ """
14
+
15
+ import sys
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from datetime import datetime
19
+
20
+ sys.path.append('/data')
21
+ sys.path.append('/data/BitTransformerLM')
22
+
23
+ from bit_transformer import (
24
+ BitTransformerLM,
25
+ text_to_bits,
26
+ bits_to_text,
27
+ diffusion_inference,
28
+ set_dropout,
29
+ enforce_parity
30
+ )
31
+
32
+ def load_full_attention_model():
33
+ """Load the full attention BitTransformerLM model."""
34
+ print("πŸš€ Loading Full Attention BitTransformerLM for enhanced generation testing...")
35
+
36
+ model = BitTransformerLM(
37
+ d_model=512, nhead=16, num_layers=8, dim_feedforward=1024,
38
+ max_seq_len=512, reversible=True, use_checkpoint=False,
39
+ use_autocast=False, use_act=True, act_threshold=0.9,
40
+ lambda_K=0.05, lambda_C=0.05, lambda_S=0.05,
41
+ chunk_size=None, overlap=0, full_attn_logging=True
42
+ )
43
+
44
+ checkpoint_path = '/data/BitTransformerLM/checkpoints/checkpoint_best.pt'
45
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
46
+ model.load_state_dict(checkpoint['model_state_dict'])
47
+ model.eval()
48
+ set_dropout(model, 0.0)
49
+
50
+ epoch = checkpoint.get('epoch', 'unknown')
51
+ loss = checkpoint.get('loss', 'unknown')
52
+ print(f"βœ… Model loaded! Epoch: {epoch}, Loss: {loss}")
53
+
54
+ return model
55
+
56
+ def autoregressive_generate_with_parity_correction(model, prompt, max_new_chars=20, temperature=0.7):
57
+ """
58
+ Autoregressive generation with automatic parity correction.
59
+ This should solve the parity check failure issue that blocked autoregressive evaluation.
60
+ """
61
+ print(f"\nπŸ”„ Autoregressive generation with parity correction:")
62
+ print(f" Prompt: '{prompt}' β†’ generating {max_new_chars} characters...")
63
+
64
+ # Convert prompt to bits
65
+ input_bits = text_to_bits(prompt)
66
+ generated_bits = input_bits.copy()
67
+
68
+ with torch.no_grad():
69
+ for char_idx in range(max_new_chars):
70
+ char_bits = []
71
+
72
+ # Generate 8 data bits + 1 parity bit per character
73
+ for bit_idx in range(9):
74
+ # Use last 400 bits as context
75
+ context = generated_bits + char_bits
76
+ context = context[-400:] if len(context) > 400 else context
77
+ context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)
78
+
79
+ # Get next bit prediction
80
+ logits, telemetry = model(context_tensor, causal=True)
81
+ next_bit_logits = logits[0, -1, :]
82
+
83
+ if bit_idx < 8: # Data bits
84
+ # Apply temperature for controlled randomness
85
+ if temperature > 0:
86
+ next_bit_logits = next_bit_logits / temperature
87
+ probs = F.softmax(next_bit_logits, dim=-1)
88
+ next_bit = torch.multinomial(probs, 1).item()
89
+ else:
90
+ next_bit = torch.argmax(next_bit_logits).item()
91
+ else: # Parity bit - calculate correct parity
92
+ data_bits = char_bits[:8]
93
+ expected_parity = sum(data_bits) % 2
94
+ next_bit = expected_parity
95
+
96
+ char_bits.append(next_bit)
97
+
98
+ # Add character to generated sequence
99
+ generated_bits.extend(char_bits)
100
+
101
+ # Extract only the new bits (excluding prompt)
102
+ new_bits = generated_bits[len(input_bits):]
103
+
104
+ # Apply additional parity correction if needed
105
+ new_bits_tensor = torch.tensor(new_bits, dtype=torch.long)
106
+ corrected_bits_tensor, parity_corrections = enforce_parity(new_bits_tensor)
107
+ corrected_bits = corrected_bits_tensor.tolist()
108
+
109
+ try:
110
+ # Decode new text
111
+ decoded_text = bits_to_text(corrected_bits)
112
+ full_result = prompt + decoded_text
113
+ print(f" βœ… SUCCESS: '{full_result}'")
114
+ return {
115
+ 'success': True,
116
+ 'full_text': full_result,
117
+ 'new_text': decoded_text,
118
+ 'bits_generated': len(new_bits),
119
+ 'parity_corrections': parity_corrections
120
+ }
121
+ except Exception as e:
122
+ print(f" ❌ DECODE FAILED: {e}")
123
+ return {
124
+ 'success': False,
125
+ 'error': str(e),
126
+ 'bits_generated': len(new_bits)
127
+ }
128
+
129
+ def long_diffusion_generation(model, prompt, target_chars, steps=50):
130
+ """
131
+ Generate longer sequences with optimized diffusion parameters.
132
+ """
133
+ print(f"\n🌊 Long diffusion generation:")
134
+ print(f" Prompt: '{prompt}' β†’ generating {target_chars} characters with {steps} steps...")
135
+
136
+ try:
137
+ # Generate longer continuation
138
+ continuation_bits = target_chars * 9 # 9 bits per character
139
+ generated_bits = diffusion_inference(
140
+ model,
141
+ length=continuation_bits,
142
+ steps=steps,
143
+ batch_size=1,
144
+ init_bits=None,
145
+ schedule="cosine"
146
+ )
147
+
148
+ # Decode result
149
+ continuation_bits_list = generated_bits.squeeze().tolist()
150
+ continuation_text = bits_to_text(continuation_bits_list)
151
+
152
+ full_result = prompt + continuation_text
153
+ print(f" βœ… SUCCESS: '{full_result}'")
154
+
155
+ return {
156
+ 'success': True,
157
+ 'full_text': full_result,
158
+ 'new_text': continuation_text,
159
+ 'bits_generated': len(continuation_bits_list),
160
+ 'diffusion_steps': steps
161
+ }
162
+
163
+ except Exception as e:
164
+ print(f" ❌ FAILED: {e}")
165
+ return {
166
+ 'success': False,
167
+ 'error': str(e),
168
+ 'diffusion_steps': steps
169
+ }
170
+
171
+ def test_length_scaling():
172
+ """Test if longer generations produce more coherent results."""
173
+ print("\nπŸ“ === LENGTH SCALING TESTS ===")
174
+ print("Testing if longer generations show improved coherence...")
175
+
176
+ model = load_full_attention_model()
177
+ test_prompts = ["Hello", "The weather today", "I think that"]
178
+ target_lengths = [10, 25, 50]
179
+
180
+ results = []
181
+
182
+ for prompt in test_prompts:
183
+ for length in target_lengths:
184
+ print(f"\n--- Testing '{prompt}' β†’ {length} chars ---")
185
+
186
+ # Test autoregressive
187
+ auto_result = autoregressive_generate_with_parity_correction(
188
+ model, prompt, max_new_chars=length, temperature=0.6
189
+ )
190
+
191
+ # Test diffusion with high steps
192
+ diff_result = long_diffusion_generation(
193
+ model, prompt, target_chars=length, steps=50
194
+ )
195
+
196
+ results.append({
197
+ 'prompt': prompt,
198
+ 'target_length': length,
199
+ 'autoregressive': auto_result,
200
+ 'diffusion': diff_result
201
+ })
202
+
203
+ return results
204
+
205
+ def test_parameter_optimization():
206
+ """Test different generation parameters for quality."""
207
+ print("\nβš™οΈ === PARAMETER OPTIMIZATION TESTS ===")
208
+ print("Testing different temperatures and diffusion steps...")
209
+
210
+ model = load_full_attention_model()
211
+ prompt = "Hello world"
212
+
213
+ results = []
214
+
215
+ # Test different temperatures for autoregressive
216
+ print("\n🌑️ Testing autoregressive temperatures:")
217
+ for temp in [0.1, 0.5, 0.8, 1.0, 1.2]:
218
+ print(f"\n--- Temperature {temp} ---")
219
+ result = autoregressive_generate_with_parity_correction(
220
+ model, prompt, max_new_chars=20, temperature=temp
221
+ )
222
+ results.append({
223
+ 'method': 'autoregressive',
224
+ 'temperature': temp,
225
+ 'result': result
226
+ })
227
+
228
+ # Test different diffusion steps
229
+ print("\n🌊 Testing diffusion steps:")
230
+ for steps in [10, 25, 50, 100]:
231
+ print(f"\n--- {steps} steps ---")
232
+ result = long_diffusion_generation(
233
+ model, prompt, target_chars=20, steps=steps
234
+ )
235
+ results.append({
236
+ 'method': 'diffusion',
237
+ 'steps': steps,
238
+ 'result': result
239
+ })
240
+
241
+ return results
242
+
243
+ def test_coherence_prompts():
244
+ """Test with prompts that should elicit more coherent responses."""
245
+ print("\n🎯 === COHERENCE PROMPTS TESTS ===")
246
+ print("Testing prompts designed to elicit coherent language patterns...")
247
+
248
+ model = load_full_attention_model()
249
+
250
+ # Prompts that might elicit more structured responses
251
+ coherence_prompts = [
252
+ "Once upon a time",
253
+ "The quick brown fox",
254
+ "In the beginning",
255
+ "Python code to print hello:",
256
+ "def main():",
257
+ "SELECT * FROM",
258
+ "Today is a beautiful",
259
+ "My name is",
260
+ "The answer is",
261
+ "import torch"
262
+ ]
263
+
264
+ results = []
265
+
266
+ for prompt in coherence_prompts:
267
+ print(f"\n--- Testing coherence with: '{prompt}' ---")
268
+
269
+ # Test both methods with longer generation
270
+ auto_result = autoregressive_generate_with_parity_correction(
271
+ model, prompt, max_new_chars=30, temperature=0.7
272
+ )
273
+
274
+ diff_result = long_diffusion_generation(
275
+ model, prompt, target_chars=30, steps=75
276
+ )
277
+
278
+ results.append({
279
+ 'prompt': prompt,
280
+ 'autoregressive': auto_result,
281
+ 'diffusion': diff_result
282
+ })
283
+
284
+ # Quick analysis
285
+ if auto_result.get('success'):
286
+ auto_text = auto_result.get('new_text', '')
287
+ if any(word in auto_text.lower() for word in ['the', 'and', 'is', 'in', 'to', 'a']):
288
+ print(f" πŸŽ‰ Autoregressive contains common words!")
289
+
290
+ if diff_result.get('success'):
291
+ diff_text = diff_result.get('new_text', '')
292
+ if any(word in diff_text.lower() for word in ['the', 'and', 'is', 'in', 'to', 'a']):
293
+ print(f" πŸŽ‰ Diffusion contains common words!")
294
+
295
+ return results
296
+
297
+ def main():
298
+ """Run all enhanced generation tests."""
299
+ print("πŸš€ ENHANCED BITRANSFORMERLM GENERATION TESTING")
300
+ print("=" * 60)
301
+ print("Testing potential fixes:")
302
+ print("1. Autoregressive with parity correction")
303
+ print("2. Longer sequence generation")
304
+ print("3. Optimized generation parameters")
305
+ print("4. Coherence-focused prompts")
306
+ print("=" * 60)
307
+
308
+ # Run all tests
309
+ length_results = test_length_scaling()
310
+ param_results = test_parameter_optimization()
311
+ coherence_results = test_coherence_prompts()
312
+
313
+ # Summary analysis
314
+ print("\n🎯 === OVERALL ANALYSIS ===")
315
+
316
+ # Count successes
317
+ total_auto = len([r for results in [length_results, coherence_results]
318
+ for r in results if r.get('autoregressive', {}).get('success')])
319
+ total_diff = len([r for results in [length_results, coherence_results]
320
+ for r in results if r.get('diffusion', {}).get('success')])
321
+
322
+ print(f"Autoregressive success rate: {total_auto}/24")
323
+ print(f"Diffusion success rate: {total_diff}/24")
324
+
325
+ # Look for promising outputs
326
+ print("\nπŸ” Looking for signs of linguistic improvement...")
327
+
328
+ all_results = length_results + coherence_results
329
+ promising_outputs = []
330
+
331
+ for result in all_results:
332
+ for method in ['autoregressive', 'diffusion']:
333
+ if result.get(method, {}).get('success'):
334
+ text = result[method].get('new_text', '')
335
+ # Check for word-like patterns
336
+ if len(text) > 10 and any(c.isalpha() for c in text):
337
+ words = text.split()
338
+ if any(len(word) > 2 and word.isalpha() for word in words):
339
+ promising_outputs.append({
340
+ 'prompt': result['prompt'],
341
+ 'method': method,
342
+ 'text': text
343
+ })
344
+
345
+ if promising_outputs:
346
+ print(f"\nπŸŽ‰ Found {len(promising_outputs)} promising outputs with word-like patterns!")
347
+ for output in promising_outputs[:5]: # Show first 5
348
+ print(f" {output['method']}: '{output['prompt']}' β†’ '{output['text']}'")
349
+ else:
350
+ print("\nπŸ’­ No clear word patterns found yet - model may need more training or different approach")
351
+
352
+ return {
353
+ 'length_results': length_results,
354
+ 'param_results': param_results,
355
+ 'coherence_results': coherence_results,
356
+ 'summary': {
357
+ 'autoregressive_successes': total_auto,
358
+ 'diffusion_successes': total_diff,
359
+ 'promising_outputs': len(promising_outputs)
360
+ }
361
+ }
362
+
363
+ if __name__ == "__main__":
364
+ results = main()