WCNegentropy committed on
Commit
bde6dbb
·
verified ·
1 Parent(s): 9202d01

πŸš€ Refined BitTransformerLM: Organized codebase with best practices

Browse files
Files changed (1) hide show
  1. scripts/testing/diffusion_tests.py +484 -0
scripts/testing/diffusion_tests.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ BitTransformerLM Denoising Diffusion Inference Tests
4
+ ====================================================
5
+
6
+ Test the breakthrough model using built-in denoising diffusion generation
7
+ to potentially resolve parity errors and improve text quality.
8
+ """
9
+
10
+ import sys
11
+ import torch
12
+ import math
13
+ import logging
14
+
15
+ # Add paths for imports
16
+ sys.path.append('/data')
17
+ sys.path.append('/data/BitTransformerLM')
18
+
19
+ from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text, diffusion_inference
20
+
21
+ # Setup logging
22
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
+ logger = logging.getLogger(__name__)
24
+
25
def load_breakthrough_model(checkpoint_path='/data/BitTransformerLM/checkpoints/checkpoint_best.pt'):
    """Load the trained breakthrough BitTransformerLM in eval mode.

    Args:
        checkpoint_path: Path to the training checkpoint file.  Defaults to
            the breakthrough checkpoint from the training run; exposed as a
            parameter so other checkpoints can be evaluated without editing
            this file (backward compatible: callers passing no argument get
            the original behavior).

    Returns:
        The BitTransformerLM with checkpoint weights loaded, set to eval().

    Raises:
        Whatever torch.load / load_state_dict raise on a missing or
        mismatched checkpoint.
    """
    print("πŸš€ Loading breakthrough BitTransformerLM for diffusion inference...")

    # Create model with EXACT same config as training -- a mismatch here
    # makes load_state_dict fail (or silently mis-map weights).
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=False,  # Disable for inference
        use_autocast=False,    # Disable for inference
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05
    )

    # Load on CPU; the checkpoint dict is expected to carry
    # 'model_state_dict', 'loss' and 'epoch' keys (used below).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"βœ… Model loaded! Loss: {checkpoint['loss']:.6f}, Epoch: {checkpoint['epoch']}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"πŸ“Š Parameters: {total_params:,}")

    return model
58
def test_basic_diffusion_generation(model):
    """Exercise unconditioned denoising-diffusion generation.

    Runs three fixed configs (length/steps/schedule), decodes each sample
    with bits_to_text, and falls back to a best-effort manual decode when
    strict decoding fails.  Lengths are multiples of 9 bits per character
    (8 data bits plus, presumably, 1 parity bit -- TODO confirm against
    text_to_bits/bits_to_text).

    Returns:
        List of dicts: {"config", "text", "success"} plus "error" when
        generation itself raised.
    """
    print("\nπŸ§ͺ === BASIC DIFFUSION GENERATION TESTS ===")

    test_configs = [
        {"length": 36, "steps": 8, "schedule": "linear", "name": "4 chars, linear"},
        {"length": 45, "steps": 12, "schedule": "cosine", "name": "5 chars, cosine"},
        {"length": 54, "steps": 16, "schedule": "exp", "name": "6 chars, exp"},
    ]

    results = []

    for cfg in test_configs:
        print(f"\n--- {cfg['name']} ---")
        print(f"Config: {cfg['length']} bits, {cfg['steps']} steps, {cfg['schedule']} schedule")

        try:
            # Unconditioned sample straight from the diffusion sampler.
            sample = diffusion_inference(model,
                                         length=cfg['length'],
                                         steps=cfg['steps'],
                                         schedule=cfg['schedule'])

            bits_list = sample.squeeze().tolist()
            print(f"Generated {len(bits_list)} bits: {bits_list[:18]}...")

            try:
                text = bits_to_text(bits_list)
                print(f"βœ… SUCCESS: '{text}'")
                results.append({"config": cfg, "text": text, "success": True})
            except Exception as decode_error:
                print(f"❌ Decode failed: {decode_error}")

                # Fallback: take the first 8 bits of every 9-bit chunk
                # (MSB first) and keep printable ASCII, '?' otherwise.
                pieces = []
                for start in range(0, len(bits_list), 9):
                    if start + 8 < len(bits_list):
                        value = 0
                        for bit in bits_list[start:start + 8]:
                            value = (value << 1) | bit
                        pieces.append(chr(value) if 32 <= value <= 126 else '?')
                manual_text = "".join(pieces)

                print(f"πŸ”§ Manual decode: '{manual_text}'")
                results.append({"config": cfg, "text": manual_text, "success": False})

        except Exception as e:
            print(f"πŸ’₯ Generation failed: {e}")
            results.append({"config": cfg, "text": None, "success": False, "error": str(e)})

    return results
115
def test_conditioned_diffusion_generation(model):
    """Diffusion generation seeded with a text prompt.

    The prompt is converted to bits and written into the front of the
    initial tensor; 45 extra bits (~5 characters at 9 bits each) are
    filled with random noise and denoised with a cosine schedule.
    NOTE(review): whether diffusion_inference keeps the prompt region
    fixed across denoising steps depends on its implementation -- confirm.

    Returns:
        List of dicts with "prompt", "continuation", "full_result",
        "success" (plus "error" on generation failure).
    """
    print("\n🎯 === CONDITIONED DIFFUSION GENERATION TESTS ===")

    prompts = [
        "Hello",
        "Hi there",
        "What is your name?",
        "The weather is",
        "I am",
        "Yes",
        "No"
    ]

    results = []

    for prompt in prompts:
        print(f"\n--- Prompt: '{prompt}' ---")

        prompt_bits = text_to_bits(prompt)
        n_prompt = len(prompt_bits)
        print(f"Prompt: {n_prompt} bits")

        total_length = n_prompt + 45  # prompt + 5 characters

        # Seed: prompt bits up front, uniform random bits for the tail.
        init_bits = torch.zeros(1, total_length, dtype=torch.long)
        init_bits[0, :n_prompt] = torch.tensor(prompt_bits, dtype=torch.long)
        init_bits[0, n_prompt:] = torch.randint(0, 2, (total_length - n_prompt,))

        try:
            sample = diffusion_inference(model,
                                         length=total_length,
                                         steps=12,
                                         init_bits=init_bits,
                                         schedule="cosine")

            # Keep only the generated tail; the head was the prompt.
            generated_only = sample.squeeze().tolist()[n_prompt:]

            print(f"Generated {len(generated_only)} bits for continuation")

            try:
                continuation = bits_to_text(generated_only)
                full_result = prompt + continuation
                print(f"βœ… SUCCESS: '{prompt}' β†’ '{full_result}'")
                results.append({
                    "prompt": prompt,
                    "continuation": continuation,
                    "full_result": full_result,
                    "success": True
                })
            except Exception as decode_error:
                print(f"❌ Decode failed: {decode_error}")

                # Best-effort fallback: first 8 bits of each 9-bit chunk,
                # MSB first, printable ASCII or '?'.
                pieces = []
                for start in range(0, len(generated_only), 9):
                    if start + 8 < len(generated_only):
                        value = 0
                        for bit in generated_only[start:start + 8]:
                            value = (value << 1) | bit
                        pieces.append(chr(value) if 32 <= value <= 126 else '?')
                manual_continuation = "".join(pieces)

                full_result = prompt + manual_continuation
                print(f"πŸ”§ Manual decode: '{prompt}' β†’ '{full_result}'")
                results.append({
                    "prompt": prompt,
                    "continuation": manual_continuation,
                    "full_result": full_result,
                    "success": False
                })

        except Exception as e:
            print(f"πŸ’₯ Generation failed: {e}")
            results.append({
                "prompt": prompt,
                "continuation": None,
                "full_result": None,
                "success": False,
                "error": str(e)
            })

    return results
208
def test_code_diffusion_completion(model):
    """Diffusion completion on short math/code/pattern prompts.

    Each prompt is frozen into the initial bits and a 36-bit (4-character)
    completion is denoised with an exponential schedule.  On decode failure
    a manual 8-of-9-bit fallback also tallies character classes so the raw
    output quality is still visible.

    Returns:
        List of dicts with "prompt", "completion", "full_result", "success",
        plus "analysis"/"char_types"/"error" depending on the path taken.
    """
    print("\nπŸ’» === CODE DIFFUSION COMPLETION TESTS ===")

    code_prompts = [
        # Math
        "2 + 2 =",
        "1 + 1 =",
        "5 * 3 =",
        "10 / 2 =",

        # Programming
        "def hello():",
        "if x ==",
        "for i in",
        "print(",
        "return",

        # Patterns
        "a, b, c,",
        "1, 2, 3,",
        "function(",
        "var x =",
    ]

    results = []

    for prompt in code_prompts:
        print(f"\n--- Code: '{prompt}' ---")

        prompt_bits = text_to_bits(prompt)
        n_prompt = len(prompt_bits)
        print(f"Prompt: {n_prompt} bits")

        # Generate shorter completions for code
        completion_length = 36  # 4 characters
        total_length = n_prompt + completion_length

        # Initialize with prompt + noise
        init_bits = torch.zeros(1, total_length, dtype=torch.long)
        init_bits[0, :n_prompt] = torch.tensor(prompt_bits, dtype=torch.long)
        init_bits[0, n_prompt:] = torch.randint(0, 2, (completion_length,))

        try:
            # Exponential schedule: sharper denoising for code completions.
            sample = diffusion_inference(model,
                                         length=total_length,
                                         steps=16,  # more steps for better quality
                                         init_bits=init_bits,
                                         schedule="exp")

            completion_bits = sample.squeeze().tolist()[n_prompt:]

            try:
                completion = bits_to_text(completion_bits)
                full_result = prompt + completion
                print(f"βœ… SUCCESS: '{prompt}' β†’ '{full_result}'")

                # Quality heuristics over the decoded completion.
                checks = (
                    ("Contains alphanumeric", lambda ch: ch.isalnum()),
                    ("Contains numbers", lambda ch: ch in "0123456789"),
                    ("Contains code symbols", lambda ch: ch in "=(){}[];,"),
                    ("Contains whitespace", lambda ch: ch in " \n\t"),
                )
                analysis = [label for label, hit in checks
                            if any(hit(ch) for ch in completion)]

                if analysis:
                    print(f" πŸ“Š Analysis: {', '.join(analysis)}")

                results.append({
                    "prompt": prompt,
                    "completion": completion,
                    "full_result": full_result,
                    "analysis": analysis,
                    "success": True
                })

            except Exception as decode_error:
                print(f"❌ Decode failed: {decode_error}")

                # Manual fallback decode with a per-class character tally.
                # NOTE(review): non-printable bytes are counted in no class,
                # matching the other fallback decoders in this file.
                char_types = {"letters": 0, "numbers": 0, "symbols": 0, "printable": 0}
                pieces = []
                for start in range(0, len(completion_bits), 9):
                    if start + 8 >= len(completion_bits):
                        continue  # incomplete trailing chunk: skip
                    value = 0
                    for bit in completion_bits[start:start + 8]:
                        value = (value << 1) | bit
                    if 32 <= value <= 126:
                        ch = chr(value)
                        pieces.append(ch)
                        char_types["printable"] += 1
                        if ch.isalpha():
                            char_types["letters"] += 1
                        elif ch.isdigit():
                            char_types["numbers"] += 1
                        elif ch in "=(){}[];,+-*/<>!@#$%^&":
                            char_types["symbols"] += 1
                    else:
                        pieces.append('?')
                manual_completion = "".join(pieces)

                full_result = prompt + manual_completion
                print(f"πŸ”§ Manual decode: '{prompt}' β†’ '{full_result}'")
                print(f" πŸ“Š Character types: {char_types}")

                results.append({
                    "prompt": prompt,
                    "completion": manual_completion,
                    "full_result": full_result,
                    "char_types": char_types,
                    "success": False
                })

        except Exception as e:
            print(f"πŸ’₯ Generation failed: {e}")
            results.append({
                "prompt": prompt,
                "completion": None,
                "full_result": None,
                "success": False,
                "error": str(e)
            })

    return results
340
def compare_diffusion_vs_autoregressive(model):
    """Compare diffusion vs autoregressive generation quality.

    For each prompt, generates a 27-bit (3-character) continuation twice:
    once bit-by-bit autoregressively with temperature sampling, once with
    denoising diffusion, then reports which decodings succeeded.

    Fix: the two decode fallbacks used bare ``except:`` clauses, which also
    swallow SystemExit/KeyboardInterrupt; narrowed to ``except Exception``.

    Returns:
        List of dicts: {"prompt", "autoregressive": {"completion",
        "success"}, "diffusion": {"completion", "success"}}.
    """
    print("\nβš–οΈ === DIFFUSION vs AUTOREGRESSIVE COMPARISON ===")

    test_prompts = ["Hello", "Hi", "The cat", "Yes"]
    comparison_results = []

    for prompt in test_prompts:
        print(f"\n--- Comparing generation for: '{prompt}' ---")

        prompt_bits = text_to_bits(prompt)
        generation_length = 27  # 3 characters

        # AUTOREGRESSIVE GENERATION (previous method)
        print("πŸ”„ Autoregressive generation:")
        try:
            generated_bits_ar = prompt_bits.copy()

            with torch.no_grad():
                for i in range(generation_length):
                    # Cap the model context at the most recent 300 bits.
                    context = generated_bits_ar[-300:] if len(generated_bits_ar) > 300 else generated_bits_ar
                    context_tensor = torch.tensor(context, dtype=torch.long).unsqueeze(0)

                    logits, _ = model(context_tensor)  # causal=True by default
                    next_bit_logits = logits[0, -1, :]

                    # Temperature sampling (T=0.8 sharpens the distribution)
                    next_bit_logits = next_bit_logits / 0.8
                    probs = torch.softmax(next_bit_logits, dim=-1)
                    next_bit = torch.multinomial(probs, 1).item()

                    generated_bits_ar.append(next_bit)

            ar_completion_bits = generated_bits_ar[len(prompt_bits):]
            try:
                ar_completion = bits_to_text(ar_completion_bits)
                ar_success = True
            except Exception:  # was bare except: decode failure is expected/best-effort
                ar_completion = "DECODE_FAILED"
                ar_success = False

            print(f" Result: '{prompt}' β†’ '{prompt + ar_completion}' (Success: {ar_success})")

        except Exception as e:
            ar_completion = f"ERROR: {e}"
            ar_success = False
            print(f" Result: ERROR - {e}")

        # DIFFUSION GENERATION
        print("🌊 Diffusion generation:")
        try:
            total_length = len(prompt_bits) + generation_length
            init_bits = torch.zeros(1, total_length, dtype=torch.long)
            init_bits[0, :len(prompt_bits)] = torch.tensor(prompt_bits, dtype=torch.long)
            init_bits[0, len(prompt_bits):] = torch.randint(0, 2, (generation_length,))

            generated_bits_diff = diffusion_inference(
                model,
                length=total_length,
                steps=12,
                init_bits=init_bits,
                schedule="cosine"
            )

            diff_completion_bits = generated_bits_diff.squeeze().tolist()[len(prompt_bits):]
            try:
                diff_completion = bits_to_text(diff_completion_bits)
                diff_success = True
            except Exception:  # was bare except: decode failure is expected/best-effort
                diff_completion = "DECODE_FAILED"
                diff_success = False

            print(f" Result: '{prompt}' β†’ '{prompt + diff_completion}' (Success: {diff_success})")

        except Exception as e:
            diff_completion = f"ERROR: {e}"
            diff_success = False
            print(f" Result: ERROR - {e}")

        # Store comparison
        comparison_results.append({
            "prompt": prompt,
            "autoregressive": {"completion": ar_completion, "success": ar_success},
            "diffusion": {"completion": diff_completion, "success": diff_success}
        })

        # Quick quality assessment
        if diff_success and ar_success:
            print(f" πŸ† Both methods succeeded!")
        elif diff_success and not ar_success:
            print(f" 🌊 Diffusion wins - only it succeeded!")
        elif ar_success and not diff_success:
            print(f" πŸ”„ Autoregressive wins - only it succeeded!")
        else:
            print(f" 😞 Both methods failed")

    return comparison_results
438
def main():
    """Run the full diffusion inference test suite and print a summary.

    Loads the breakthrough checkpoint, runs the four test suites in a
    fixed order, prints per-suite success rates plus the head-to-head
    method comparison, and returns the raw results dict for inspection.
    """
    print("πŸš€ BITRANSFORMERLM DENOISING DIFFUSION INFERENCE TESTS")
    print("=" * 70)
    print("Testing hypothesis: Denoising diffusion should reduce parity errors")
    print("by treating parity bits as noise and filtering them out.")
    print("=" * 70)

    # Load model
    model = load_breakthrough_model()

    # Suites run in dict-literal order (basic, conditioned, code, comparison).
    test_results = {
        "basic_diffusion": test_basic_diffusion_generation(model),
        "conditioned_diffusion": test_conditioned_diffusion_generation(model),
        "code_diffusion": test_code_diffusion_completion(model),
        "comparison": compare_diffusion_vs_autoregressive(model),
    }

    print("\n🎯 === FINAL SUMMARY ===")

    # Per-suite success rates, printed in the same order as above.
    for key, label in (("basic_diffusion", "Basic"),
                       ("conditioned_diffusion", "Conditioned"),
                       ("code_diffusion", "Code")):
        suite = test_results[key]
        wins = sum(1 for r in suite if r.get("success", False))
        print(f"{label} diffusion success rate: {wins}/{len(suite)}")

    # Head-to-head tally: which method(s) produced a decodable result.
    diff_wins = ar_wins = both_win = 0
    for r in test_results["comparison"]:
        d_ok = r["diffusion"]["success"]
        a_ok = r["autoregressive"]["success"]
        if d_ok and a_ok:
            both_win += 1
        elif d_ok:
            diff_wins += 1
        elif a_ok:
            ar_wins += 1

    print(f"Method comparison - Diffusion only: {diff_wins}, Autoregressive only: {ar_wins}, Both: {both_win}")

    return test_results
483
+ if __name__ == "__main__":
484
+ main()