WCNegentropy committed on
Commit
ba1fad5
Β·
verified Β·
1 Parent(s): 72bc506

πŸš€ Refined BitTransformerLM: Organized codebase with best practices

Browse files
scripts/testing/full_attention_inference_test.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Full Attention BitTransformerLM Diffusion Inference Test
4
+ ========================================================
5
+
6
+ Test the newly trained full bi-directional attention BitTransformerLM model
7
+ using denoising diffusion generation to evaluate improvements from full attention training.
8
+
9
+ Model Configuration:
10
+ - Same full bi-directional unchunked attention as training (chunk_size=None)
11
+ - Proper eval() mode with dropout management
12
+ - Use latest checkpoint_best.pt from full attention training
13
+ - Test with same diffusion inference that worked before
14
+ """
15
+
16
+ import sys
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from datetime import datetime
20
+
21
+ sys.path.append('/data')
22
+ sys.path.append('/data/BitTransformerLM')
23
+
24
+ from bit_transformer import (
25
+ BitTransformerLM,
26
+ text_to_bits,
27
+ bits_to_text,
28
+ diffusion_inference,
29
+ set_dropout
30
+ )
31
+
32
def load_full_attention_model(checkpoint_path='/data/BitTransformerLM/checkpoints/checkpoint_best.pt'):
    """Load the trained full-attention BitTransformerLM model for inference.

    Args:
        checkpoint_path: Path to the checkpoint to restore. Defaults to the
            best checkpoint produced by the full-attention training run, so
            existing callers are unaffected.

    Returns:
        A BitTransformerLM instance in eval mode with dropout disabled.
    """
    print("🚀 Loading Full Attention BitTransformerLM for diffusion inference...")

    # Hyperparameters must mirror the training configuration exactly,
    # otherwise load_state_dict() fails on shape mismatches.
    model = BitTransformerLM(
        d_model=512,              # Same as training
        nhead=16,                 # Same as training
        num_layers=8,             # Same as training
        dim_feedforward=1024,     # Same as training
        max_seq_len=512,          # Same as training
        reversible=True,          # Same as training
        use_checkpoint=False,     # Gradient checkpointing not needed at inference
        use_autocast=False,       # Disable mixed precision for inference
        use_act=True,             # Same as training
        act_threshold=0.9,        # Same as training
        lambda_K=0.05,            # Same as training
        lambda_C=0.05,            # Same as training
        lambda_S=0.05,            # Same as training
        chunk_size=None,          # FULL ATTENTION - same as training
        overlap=0,                # Same as training
        full_attn_logging=True    # Same as training
    )

    # NOTE(review): torch.load with default settings unpickles arbitrary
    # objects; acceptable here only because the checkpoint is produced
    # locally by our own training run.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Eval mode plus an explicit dropout shutdown for deterministic output.
    model.eval()
    set_dropout(model, 0.0)

    # Checkpoint metadata is optional; fall back to 'unknown' when absent.
    epoch = checkpoint.get('epoch', 'unknown')
    loss = checkpoint.get('loss', 'unknown')

    print(f"✅ Full Attention Model loaded! Epoch: {epoch}, Loss: {loss}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"📊 Parameters: {total_params:,}")

    return model
76
+
77
def test_basic_diffusion_generation(model):
    """Test basic unconditional diffusion generation."""
    print("\n🧪 === BASIC FULL ATTENTION DIFFUSION GENERATION ===")

    # Three length/step/schedule combinations; lengths are multiples of 9
    # bits (one text character per 9 bits).
    configs = (
        {"length": 36, "steps": 8, "schedule": "linear"},
        {"length": 45, "steps": 12, "schedule": "cosine"},
        {"length": 54, "steps": 16, "schedule": "exp"},
    )

    outcomes = []
    for idx, cfg in enumerate(configs, 1):
        print(f"\n--- Test {idx}: {cfg['length']//9} chars, {cfg['schedule']} ---")

        record = {"test": f"basic_{idx}", "config": cfg}
        try:
            sample = diffusion_inference(
                model,
                length=cfg["length"],
                steps=cfg["steps"],
                batch_size=1,
                schedule=cfg["schedule"],
            )

            raw_bits = sample.squeeze().tolist()
            decoded = bits_to_text(raw_bits)

            print(f"✅ SUCCESS: '{decoded}'")
            record.update(success=True, output=decoded, bits=len(raw_bits))
        except Exception as err:
            print(f"❌ FAILED: {err}")
            record.update(success=False, error=str(err))
        outcomes.append(record)

    return outcomes
125
+
126
def test_conditioned_diffusion_generation(model):
    """Test prompt-conditioned diffusion generation.

    For each text prompt, generate a fixed-length continuation via
    denoising diffusion and display prompt + continuation combined.
    NOTE: the prompt does not actually condition the sampler here
    (init_bits=None); it is only prepended to the decoded output.

    Args:
        model: BitTransformerLM in eval mode.

    Returns:
        List of per-prompt result dicts (success flag plus decoded text
        or error message).
    """
    print("\n🎯 === CONDITIONED FULL ATTENTION DIFFUSION GENERATION ===")

    results = []

    test_prompts = [
        "Hello",
        "Hi there",
        "What is your name?",
        "The weather is",
        "I am",
        "Yes",
        "No"
    ]

    for prompt in test_prompts:
        print(f"\n--- Prompt: '{prompt}' ---")

        try:
            # Generate continuation with diffusion (no init_bits - let it
            # generate freely).  The previous version also converted the
            # prompt to bits and concatenated them, but never used the
            # result; that dead code is removed.
            continuation_length = 45  # 5 character continuation (9 bits/char)
            generated_bits = diffusion_inference(
                model,
                length=continuation_length,
                steps=12,
                batch_size=1,
                init_bits=None,
                schedule="cosine"
            )

            # Decode the continuation once (previously squeezed/listed twice).
            continuation_bits = generated_bits.squeeze().tolist()
            continuation_text = bits_to_text(continuation_bits)

            # Show combined result
            combined_text = prompt + continuation_text
            print(f"✅ SUCCESS: '{prompt}' → '{combined_text}'")
            results.append({
                "test": "conditioned",
                "prompt": prompt,
                "success": True,
                "full_output": combined_text,
                "continuation": continuation_text,
                "bits": len(continuation_bits)
            })

        except Exception as e:
            print(f"❌ FAILED: {e}")
            results.append({
                "test": "conditioned",
                "prompt": prompt,
                "success": False,
                "error": str(e)
            })

    return results
189
+
190
def test_code_diffusion_completion(model):
    """Test code/math completion with diffusion.

    Generates a short diffusion continuation for each code-like prompt and
    reports a simple character-class analysis of the completion.

    Args:
        model: BitTransformerLM in eval mode.

    Returns:
        List of per-prompt result dicts (success flag, combined output,
        completion analysis labels, or error message).
    """
    print("\n💻 === CODE COMPLETION FULL ATTENTION DIFFUSION ===")

    results = []

    test_cases = [
        # Math equations
        "2 + 2 =",
        "1 + 1 =",
        "5 * 3 =",
        "10 / 2 =",

        # Programming constructs
        "def hello():",
        "if x ==",
        "for i in",
        "print(",
        "return",

        # Patterns
        "a, b, c,",
        "1, 2, 3,",
        "function(",
        "var x =",
    ]

    # (predicate, label, console line) triples drive the completion analysis
    # so the stored label and the printed text can never drift apart (the
    # previous version duplicated each string literal).
    checks = [
        (lambda s: any(c.isalnum() for c in s),
         "Contains alphanumeric", " 📊 Analysis: Contains alphanumeric"),
        (lambda s: any(c in "0123456789" for c in s),
         "Contains numbers", " 🔢 Analysis: Contains numbers"),
        (lambda s: any(c in "=(){}[];," for c in s),
         "Contains code symbols", " 💻 Analysis: Contains code symbols"),
    ]

    for code in test_cases:
        print(f"\n--- Code: '{code}' ---")

        try:
            # Generate completion with diffusion (no init_bits).  The old
            # code also converted `code` to bits but never used the result;
            # that dead local is removed.
            completion_length = 45  # 5 character completion (9 bits/char)
            generated_bits = diffusion_inference(
                model,
                length=completion_length,
                steps=10,
                batch_size=1,
                init_bits=None,
                schedule="linear"
            )

            # Decode completion
            completion_bits = generated_bits.squeeze().tolist()
            completion = bits_to_text(completion_bits)

            # Show combined result
            combined_text = code + completion
            print(f"✅ SUCCESS: '{code}' → '{combined_text}'")

            # Analyze completion character classes.
            analysis = []
            for predicate, label, console_line in checks:
                if predicate(completion):
                    analysis.append(label)
                    print(console_line)

            results.append({
                "test": "code_completion",
                "prompt": code,
                "success": True,
                "full_output": combined_text,
                "completion": completion,
                "analysis": analysis,
                "bits": len(completion_bits)
            })

        except Exception as e:
            print(f"❌ FAILED: {e}")
            results.append({
                "test": "code_completion",
                "prompt": code,
                "success": False,
                "error": str(e)
            })

    return results
275
+
276
def compare_with_previous_results():
    """Note about comparison with previous results."""
    # Static reference numbers from the earlier chunked-attention run,
    # printed as a single loop over the report lines.
    report = (
        "\n⚖️ === COMPARISON WITH PREVIOUS RESULTS ===",
        "Previous chunked attention model achieved:",
        "- Basic generation: 3/3 success (100%)",
        "- Conditioned generation: 7/7 success (100%)",
        "- Code completion: 13/13 success (100%)",
        "- All diffusion inference succeeded vs 0% autoregressive",
        "\nTesting if full attention training improved quality...",
    )
    for line in report:
        print(line)
285
+
286
def _count_successes(results):
    """Return how many result dicts in *results* carry success=True."""
    return sum(1 for r in results if r.get('success', False))

def main():
    """Run the full-attention diffusion inference test suite.

    Returns:
        Dict with the per-suite result lists and a summary block
        (test counts, success rate, ISO timestamp) for documentation.
    """
    # Constant banner strings: the originals carried a pointless f-prefix
    # with no placeholders (ruff F541); plain literals now.
    print("🚀 FULL ATTENTION BITRANSFORMERLM DIFFUSION INFERENCE TEST")
    print("=" * 70)
    print("Testing newly trained full bi-directional attention model")
    print("with denoising diffusion generation")
    print("=" * 70)

    # Load model once; all three suites share it.
    model = load_full_attention_model()

    # Run tests
    basic_results = test_basic_diffusion_generation(model)
    conditioned_results = test_conditioned_diffusion_generation(model)
    code_results = test_code_diffusion_completion(model)

    # Show comparison
    compare_with_previous_results()

    # Calculate summary stats via one helper instead of repeating the
    # sum(...) generator expression four times.
    all_results = basic_results + conditioned_results + code_results
    total_tests = len(all_results)
    successful_tests = _count_successes(all_results)
    # Guard against division by zero should every suite return empty.
    success_rate = (successful_tests / total_tests) * 100 if total_tests > 0 else 0

    print("\n🎯 === FINAL SUMMARY ===")
    print(f"Total tests: {total_tests}")
    print(f"Successful: {successful_tests}")
    print(f"Success rate: {success_rate:.1f}%")

    print("\nBreakdown:")
    print(f"- Basic generation: {_count_successes(basic_results)}/{len(basic_results)}")
    print(f"- Conditioned generation: {_count_successes(conditioned_results)}/{len(conditioned_results)}")
    print(f"- Code completion: {_count_successes(code_results)}/{len(code_results)}")

    # Return all results for documentation
    return {
        'basic_results': basic_results,
        'conditioned_results': conditioned_results,
        'code_results': code_results,
        'summary': {
            'total_tests': total_tests,
            'successful_tests': successful_tests,
            'success_rate': success_rate,
            'timestamp': datetime.now().isoformat()
        }
    }
331
+
332
+ if __name__ == "__main__":
333
+ results = main()