WCNegentropy committed · verified
Commit d580d32 · 1 Parent(s): 782a141

🚀 Refined BitTransformerLM: Organized codebase with best practices

bit_transformer/BTLM_Extensions/__init__.py ADDED
@@ -0,0 +1,328 @@
+ """
+ BTLM_Extensions: Extensions Package for BitTransformerLM
+ ========================================================
+
+ This package provides advanced optimizers and compression techniques
+ as extensions for BitTransformerLM, allowing easy experimentation with
+ different training configurations.
+
+ Available Extensions:
+
+ Optimizers:
+     - Muon: Orthogonal momentum optimizer with Newton-Schulz iterations
+     - Lion: EvoLved Sign Momentum optimizer for memory efficiency
+     - Adafactor: Memory-efficient factorized optimizer
+
+ Compression:
+     - RLE: Advanced Run-Length Encoding with multiple schemes
+
+ Usage:
+     from BTLM_Extensions import configure_muon_optimizer, RLEEncoder
+
+     # Use Muon optimizer
+     optimizer, scheduler = configure_muon_optimizer(model, lr=1e-3)
+
+     # Use RLE compression
+     encoder = RLEEncoder(scheme="adaptive")
+     compressed, metadata = encoder.encode(data)
+ """
+
+ __version__ = "1.0.0"
+ __author__ = "BitTransformerLM Extensions"
+ __email__ = "extensions@bittransformerlm.ai"
+
+ # Import all optimizers
+ from .muon_optimizer import (
+     Muon,
+     configure_muon_optimizer,
+     create_muon_training_config,
+ )
+
+ from .lion_optimizer import (
+     Lion,
+     AdaptiveLion,
+     configure_lion_optimizer,
+     configure_adaptive_lion_optimizer,
+     create_lion_training_config,
+ )
+
+ from .adafactor_optimizer import (
+     Adafactor,
+     AdafactorScheduler,
+     configure_adafactor_optimizer,
+     configure_adafactor_with_scheduler,
+     create_adafactor_training_config,
+     analyze_memory_usage,
+ )
+
+ # Import compression utilities
+ from .rle_compression import (
+     RLEEncoder,
+     CompressedBitDataset,
+     create_compression_aware_loss,
+     integrate_rle_with_training,
+     benchmark_compression_schemes,
+     create_rle_training_config,
+ )
+
+ # Convenience functions for easy optimizer swapping
+ def get_optimizer_config(optimizer_type: str, **kwargs):
+     """
+     Get the configuration for the specified optimizer type.
+
+     Args:
+         optimizer_type: Type of optimizer ('muon', 'lion', 'adafactor')
+         **kwargs: Optimizer-specific parameters
+
+     Returns:
+         Dictionary with the optimizer configuration
+     """
+     opt = optimizer_type.lower()  # normalize once instead of per branch
+     if opt == "muon":
+         return create_muon_training_config(**kwargs)
+     elif opt == "lion":
+         return create_lion_training_config(**kwargs)
+     elif opt == "adafactor":
+         return create_adafactor_training_config(**kwargs)
+     else:
+         raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+
+
+ def configure_optimizer(optimizer_type: str, model, **kwargs):
+     """
+     Configure an optimizer based on a type string.
+
+     Args:
+         optimizer_type: Type of optimizer ('muon', 'lion', 'adafactor')
+         model: PyTorch model to optimize
+         **kwargs: Optimizer-specific parameters
+
+     Returns:
+         Tuple of (optimizer, scheduler)
+     """
+     opt = optimizer_type.lower()
+     if opt == "muon":
+         return configure_muon_optimizer(model, **kwargs)
+     elif opt == "lion":
+         return configure_lion_optimizer(model, **kwargs)
+     elif opt == "adafactor":
+         return configure_adafactor_optimizer(model, **kwargs)
+     else:
+         raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+
+
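+ # A minimal usage sketch (hypothetical `model`; per the docstrings above, any
+ # PyTorch model works here):
+ #
+ #     opt, sched = configure_optimizer("lion", model, lr=1e-4)
+ #     cfg = get_optimizer_config("lion", lr=1e-4)  # config dict only, nothing instantiated
+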
+ # Integration helpers for BitTransformerLM
+ class ExtensionManager:
+     """
+     Manager class for easy integration with BitTransformerLM.
+
+     Provides a unified interface for switching between optimizers
+     and compression schemes.
+     """
+
+     SUPPORTED_OPTIMIZERS = ["muon", "lion", "adafactor"]
+     SUPPORTED_COMPRESSION = ["rle"]
+
+     def __init__(self):
+         self.current_optimizer = None
+         self.current_compression = None
+
+     def setup_optimizer(self, optimizer_type: str, model, **kwargs):
+         """Set up an optimizer for training."""
+         optimizer_type = optimizer_type.lower()  # normalize before the membership check
+         if optimizer_type not in self.SUPPORTED_OPTIMIZERS:
+             raise ValueError(f"Unsupported optimizer: {optimizer_type}")
+
+         optimizer, scheduler = configure_optimizer(optimizer_type, model, **kwargs)
+         self.current_optimizer = optimizer_type
+         return optimizer, scheduler
+
+     def setup_compression(self, compression_type: str, **kwargs):
+         """Set up a compression scheme."""
+         if compression_type not in self.SUPPORTED_COMPRESSION:
+             raise ValueError(f"Unsupported compression: {compression_type}")
+
+         if compression_type == "rle":
+             encoder = RLEEncoder(**kwargs)
+             self.current_compression = compression_type
+             return encoder
+
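+     # Example (sketch; `scheme="adaptive"` mirrors the module docstring):
+     #
+     #     manager = ExtensionManager()
+     #     opt, sched = manager.setup_optimizer("muon", model, lr=1e-3)
+     #     encoder = manager.setup_compression("rle", scheme="adaptive")
+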
+     def create_training_config(self, optimizer_type: str = "muon", compression_type: str = "rle", **kwargs):
+         """Create comprehensive training configuration."""
+         config = {
+             "optimizer": get_optimizer_config(optimizer_type, **kwargs),
+             "compression": create_rle_training_config(**kwargs) if compression_type == "rle" else None,
+             "extensions": {
+                 "optimizer_type": optimizer_type,
+                 "compression_type": compression_type,
+                 "version": __version__,
+             },
+         }
+         return config
+
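+     # Shape of the dict returned above (a sketch; the inner keys come from the
+     # create_*_training_config helpers):
+     #
+     #     {
+     #         "optimizer": {...},
+     #         "compression": {...} or None,
+     #         "extensions": {"optimizer_type": "muon", "compression_type": "rle",
+     #                        "version": "1.0.0"},
+     #     }
+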
+     def benchmark_optimizers(self, model, test_data, epochs: int = 5):
+         """Benchmark all available optimizers on test data."""
+         import copy
+         import time
+
+         import torch.nn.functional as F
+
+         results = {}
+
+         for opt_type in self.SUPPORTED_OPTIMIZERS:
+             print(f"Benchmarking {opt_type} optimizer...")
+
+             # Create a fresh copy so every optimizer starts from identical
+             # weights; deepcopy avoids relying on the model's constructor signature
+             model_copy = copy.deepcopy(model)
+
+             try:
+                 # Setup optimizer
+                 optimizer, scheduler = self.setup_optimizer(opt_type, model_copy, lr=1e-3)
+
+                 # Training loop
+                 start_time = time.time()
+                 losses = []
+
+                 for epoch in range(epochs):
+                     optimizer.zero_grad()
+
+                     # Simple forward pass; the model returns (logits, telemetry)
+                     logits, _ = model_copy(test_data)
+                     # Next-bit prediction: two logits per position, shifted by one step
+                     pred = logits[:, :-1, :].reshape(-1, 2)
+                     target = test_data[:, 1:].reshape(-1)
+                     loss = F.cross_entropy(pred, target)
+
+                     loss.backward()
+                     optimizer.step()
+                     if scheduler is not None:
+                         scheduler.step()
+
+                     losses.append(loss.item())
+
+                 end_time = time.time()
+
+                 results[opt_type] = {
+                     "final_loss": losses[-1],
+                     "avg_loss": sum(losses) / len(losses),
+                     "training_time": end_time - start_time,
+                     "convergence": losses[0] - losses[-1],
+                     "success": True,
+                 }
+
+             except Exception as e:
+                 results[opt_type] = {
+                     "final_loss": float("inf"),
+                     "avg_loss": float("inf"),
+                     "training_time": 0,
+                     "convergence": 0,
+                     "success": False,
+                     "error": str(e),
+                 }
+
+         return results
+
+
+ # Create global extension manager instance
+ extension_manager = ExtensionManager()
+
+ # Export all important symbols
+ __all__ = [
+     # Optimizers
+     "Muon",
+     "Lion",
+     "AdaptiveLion",
+     "Adafactor",
+     "AdafactorScheduler",
+
+     # Optimizer configuration functions
+     "configure_muon_optimizer",
+     "configure_lion_optimizer",
+     "configure_adaptive_lion_optimizer",
+     "configure_adafactor_optimizer",
+     "configure_adafactor_with_scheduler",
+
+     # Training configuration creators
+     "create_muon_training_config",
+     "create_lion_training_config",
+     "create_adafactor_training_config",
+
+     # Compression
+     "RLEEncoder",
+     "CompressedBitDataset",
+     "create_compression_aware_loss",
+     "integrate_rle_with_training",
+     "benchmark_compression_schemes",
+     "create_rle_training_config",
+
+     # Convenience functions
+     "get_optimizer_config",
+     "configure_optimizer",
+     "ExtensionManager",
+     "extension_manager",
+     "analyze_memory_usage",
+ ]
+
+ # Package information
+ def get_version():
+     """Get package version."""
+     return __version__
+
+ def list_optimizers():
+     """List all available optimizers."""
+     return ExtensionManager.SUPPORTED_OPTIMIZERS.copy()
+
+ def list_compression_schemes():
+     """List all available compression schemes."""
+     return ExtensionManager.SUPPORTED_COMPRESSION.copy()
+
+ def get_package_info():
+     """Get package information."""
+     return {
+         "name": "BTLM_Extensions",
+         "version": __version__,
+         "author": __author__,
+         "email": __email__,
+         "optimizers": list_optimizers(),
+         "compression": list_compression_schemes(),
+         "description": "Advanced optimizers and compression for BitTransformerLM",
+     }
+
+ # Print welcome message when imported
+ def _welcome_message():
+     """Print welcome message with available extensions."""
+     print(f"🚀 BTLM_Extensions v{__version__} loaded!")
+     print(f"📊 Available optimizers: {', '.join(list_optimizers())}")
+     print(f"🗜️ Available compression: {', '.join(list_compression_schemes())}")
+     print("📖 Use help(BTLM_Extensions) for detailed documentation")
+
+ # Uncomment the line below if you want the welcome message on import
+ # _welcome_message()
+
+ # Demonstrate usage example in docstring
+ def demo_usage():
+     """
+     Demonstration of BTLM_Extensions usage:
+
+         # Quick optimizer swap
+         from BTLM_Extensions import configure_optimizer
+
+         # Try different optimizers
+         muon_opt, muon_sched = configure_optimizer("muon", model, lr=1e-3)
+         lion_opt, lion_sched = configure_optimizer("lion", model, lr=1e-4)
+         adafactor_opt, adafactor_sched = configure_optimizer("adafactor", model)
+
+         # Use with BitTransformerLM training
+         from bit_transformer.training import train_loop
+
+         train_loop(model, data, optimizer=muon_opt, scheduler=muon_sched)
+
+         # Advanced compression
+         from BTLM_Extensions import RLEEncoder, integrate_rle_with_training
+
+         # Setup compression-aware training
+         dataset, loss_fn = integrate_rle_with_training(model, data)
+
+         # Benchmark optimizers
+         from BTLM_Extensions import extension_manager
+
+         results = extension_manager.benchmark_optimizers(model, test_data)
+         print("Benchmark results:", results)
+     """
+     pass
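+
+ # Quick-start sketch (assumes a constructed BitTransformerLM `model` and a bit
+ # tensor `test_data` of shape [batch, seq_len]; both names are illustrative):
+ #
+ #     from BTLM_Extensions import extension_manager
+ #
+ #     optimizer, scheduler = extension_manager.setup_optimizer("muon", model, lr=1e-3)
+ #     results = extension_manager.benchmark_optimizers(model, test_data, epochs=3)
+ #     ok = [r for r in results.items() if r[1]["success"]]
+ #     best = min(ok, key=lambda r: r[1]["final_loss"])
+ #     print("Best optimizer:", best[0])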