WCNegentropy commited on
Commit
2ca4b28
Β·
verified Β·
1 Parent(s): d1e4760

πŸš€ Refined BitTransformerLM: Organized codebase with best practices

Browse files
bit_transformer/BTLM_Extensions/test_extensions.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script for BTLM_Extensions
4
+ ===============================
5
+
6
+ Quick test to verify all extensions are working properly.
7
+ """
8
+
9
+ import sys
10
+ import os
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ # Add paths for imports
15
+ sys.path.append('/data')
16
+ sys.path.append('/data/BitTransformerLM')
17
+
18
+ def test_imports():
19
+ """Test that all modules can be imported."""
20
+ print("Testing imports...")
21
+
22
+ try:
23
+ from BTLM_Extensions import (
24
+ Muon, Lion, Adafactor,
25
+ configure_muon_optimizer,
26
+ configure_lion_optimizer,
27
+ configure_adafactor_optimizer,
28
+ RLEEncoder,
29
+ extension_manager,
30
+ get_package_info
31
+ )
32
+ print("βœ… All imports successful")
33
+ return True
34
+ except Exception as e:
35
+ print(f"❌ Import failed: {e}")
36
+ return False
37
+
38
+ def test_optimizers():
39
+ """Test optimizer functionality."""
40
+ print("\nTesting optimizers...")
41
+
42
+ # Create a simple model
43
+ model = nn.Sequential(
44
+ nn.Linear(10, 20),
45
+ nn.ReLU(),
46
+ nn.Linear(20, 2)
47
+ )
48
+
49
+ try:
50
+ from BTLM_Extensions import (
51
+ configure_muon_optimizer,
52
+ configure_lion_optimizer,
53
+ configure_adafactor_optimizer
54
+ )
55
+
56
+ # Test each optimizer
57
+ optimizers_to_test = [
58
+ ("muon", configure_muon_optimizer, {"lr": 1e-3}),
59
+ ("lion", configure_lion_optimizer, {"lr": 1e-4}),
60
+ ("adafactor", configure_adafactor_optimizer, {"lr": 1e-3}),
61
+ ]
62
+
63
+ for name, config_fn, kwargs in optimizers_to_test:
64
+ try:
65
+ optimizer, scheduler = config_fn(model, total_steps=100, **kwargs)
66
+
67
+ # Test a training step
68
+ x = torch.randn(4, 10)
69
+ y = torch.randint(0, 2, (4,))
70
+
71
+ pred = model(x)
72
+ loss = nn.functional.cross_entropy(pred, y)
73
+ loss.backward()
74
+
75
+ optimizer.step()
76
+ if scheduler:
77
+ scheduler.step()
78
+ optimizer.zero_grad()
79
+
80
+ print(f"βœ… {name.capitalize()} optimizer working")
81
+
82
+ except Exception as e:
83
+ print(f"❌ {name.capitalize()} optimizer failed: {e}")
84
+
85
+ return True
86
+
87
+ except Exception as e:
88
+ print(f"❌ Optimizer test failed: {e}")
89
+ return False
90
+
91
+ def test_rle_compression():
92
+ """Test RLE compression."""
93
+ print("\nTesting RLE compression...")
94
+
95
+ try:
96
+ from BTLM_Extensions import RLEEncoder, benchmark_compression_schemes
97
+
98
+ # Create test data with patterns
99
+ test_data = torch.randint(0, 2, (50,))
100
+ # Add some runs for better compression
101
+ test_data[10:20] = 1
102
+ test_data[30:40] = 0
103
+
104
+ # Test different schemes
105
+ schemes = ["basic", "delta", "adaptive"]
106
+
107
+ for scheme in schemes:
108
+ try:
109
+ encoder = RLEEncoder(scheme=scheme)
110
+ compressed, metadata = encoder.encode(test_data)
111
+ reconstructed = encoder.decode(compressed, metadata)
112
+
113
+ # Check reconstruction
114
+ error = torch.mean((test_data.float() - reconstructed.float()) ** 2)
115
+
116
+ if error.item() < 1e-6:
117
+ print(f"βœ… RLE {scheme} scheme working (ratio: {metadata['compression_ratio']:.3f})")
118
+ else:
119
+ print(f"❌ RLE {scheme} scheme reconstruction error: {error.item()}")
120
+
121
+ except Exception as e:
122
+ print(f"❌ RLE {scheme} scheme failed: {e}")
123
+
124
+ # Test benchmark function
125
+ try:
126
+ results = benchmark_compression_schemes(test_data)
127
+ print(f"βœ… RLE benchmark completed ({len(results)} schemes tested)")
128
+ except Exception as e:
129
+ print(f"❌ RLE benchmark failed: {e}")
130
+
131
+ return True
132
+
133
+ except Exception as e:
134
+ print(f"❌ RLE compression test failed: {e}")
135
+ return False
136
+
137
+ def test_integration():
138
+ """Test integration features."""
139
+ print("\nTesting integration features...")
140
+
141
+ try:
142
+ from BTLM_Extensions import extension_manager, get_package_info
143
+
144
+ # Test package info
145
+ info = get_package_info()
146
+ print(f"βœ… Package info: {info['name']} v{info['version']}")
147
+
148
+ # Test extension manager
149
+ optimizers = extension_manager.SUPPORTED_OPTIMIZERS
150
+ compression = extension_manager.SUPPORTED_COMPRESSION
151
+ print(f"βœ… Extension manager: {len(optimizers)} optimizers, {len(compression)} compression schemes")
152
+
153
+ return True
154
+
155
+ except Exception as e:
156
+ print(f"❌ Integration test failed: {e}")
157
+ return False
158
+
159
+ def test_bittransformerlm_integration():
160
+ """Test integration with BitTransformerLM if available."""
161
+ print("\nTesting BitTransformerLM integration...")
162
+
163
+ try:
164
+ from bit_transformer import BitTransformerLM
165
+ from BTLM_Extensions import configure_optimizer
166
+
167
+ # Create a small BitTransformerLM model
168
+ model = BitTransformerLM(
169
+ d_model=64,
170
+ nhead=4,
171
+ num_layers=2,
172
+ dim_feedforward=128,
173
+ max_seq_len=32
174
+ )
175
+
176
+ # Test optimizer integration
177
+ optimizer, scheduler = configure_optimizer("muon", model, lr=1e-3, total_steps=10)
178
+
179
+ # Simple forward pass
180
+ test_bits = torch.randint(0, 2, (2, 16))
181
+ logits, telemetry = model(test_bits)
182
+
183
+ # Simple training step
184
+ pred = logits[:, :-1, :].reshape(-1, 2)
185
+ target = test_bits[:, 1:].reshape(-1)
186
+ loss = nn.functional.cross_entropy(pred, target)
187
+
188
+ loss.backward()
189
+ optimizer.step()
190
+ if scheduler:
191
+ scheduler.step()
192
+
193
+ print(f"βœ… BitTransformerLM integration working (loss: {loss.item():.4f})")
194
+ return True
195
+
196
+ except ImportError:
197
+ print("⚠️ BitTransformerLM not available, skipping integration test")
198
+ return True
199
+ except Exception as e:
200
+ print(f"❌ BitTransformerLM integration failed: {e}")
201
+ return False
202
+
203
+ def main():
204
+ """Run all tests."""
205
+ print("BTLM_Extensions Test Suite")
206
+ print("=" * 40)
207
+
208
+ tests = [
209
+ test_imports,
210
+ test_optimizers,
211
+ test_rle_compression,
212
+ test_integration,
213
+ test_bittransformerlm_integration,
214
+ ]
215
+
216
+ passed = 0
217
+ total = len(tests)
218
+
219
+ for test in tests:
220
+ try:
221
+ if test():
222
+ passed += 1
223
+ except Exception as e:
224
+ print(f"❌ Test {test.__name__} crashed: {e}")
225
+
226
+ print("\n" + "=" * 40)
227
+ print(f"Test Results: {passed}/{total} passed")
228
+
229
+ if passed == total:
230
+ print("πŸŽ‰ All tests passed! Extensions are working correctly.")
231
+ return 0
232
+ else:
233
+ print("⚠️ Some tests failed. Check the output above.")
234
+ return 1
235
+
236
+ if __name__ == "__main__":
237
+ exit_code = main()
238
+ sys.exit(exit_code)