WCNegentropy committed on
Commit
4fb71c6
·
verified ·
1 Parent(s): 2f70b79

🚀 Refined BitTransformerLM: Organized codebase with best practices

Browse files
Files changed (1) hide show
  1. bit_transformer/config.py +323 -0
bit_transformer/config.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration management for BitTransformerLM."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from dataclasses import dataclass, field
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional
9
+
10
+ import torch
11
+
12
+ from .types import (
13
+ AttentionMask,
14
+ ChunkSize,
15
+ DeviceType,
16
+ DiffusionConfig,
17
+ GenerationConfig,
18
+ HiddenSize,
19
+ NumHeads,
20
+ NumLayers,
21
+ QuantizationConfig,
22
+ SafetyThresholds,
23
+ SequenceLength,
24
+ )
25
+
26
+
27
@dataclass
class ModelConfig:
    """Architecture hyperparameters for a BitTransformerLM model.

    Attributes:
        d_model: Model dimension for embeddings and attention.
        nhead: Number of attention heads.
        num_layers: Number of transformer layers.
        dim_feedforward: Dimension of feedforward networks.
        max_seq_len: Maximum sequence length for positional encoding.
        lambda_K: Weight for negentropy metric in telemetry.
        lambda_C: Weight for complexity metric in telemetry.
        lambda_S: Weight for symbiosis metric in telemetry.
        reversible: Enable reversible layers for memory efficiency.
        use_checkpoint: Use gradient checkpointing.
        use_autocast: Use automatic mixed precision.
        use_act: Enable Adaptive Computation Time.
        act_threshold: ACT halting threshold.
        chunk_size: Chunk size for chunked attention (None for full attention).
        overlap: Overlap size for chunked attention.
        full_attn_logging: Log full attention matrices for telemetry.
    """

    d_model: HiddenSize = 128
    nhead: NumHeads = 8
    num_layers: NumLayers = 4
    dim_feedforward: int = 512
    max_seq_len: SequenceLength = 1024
    lambda_K: float = 1.0
    lambda_C: float = 1.0
    lambda_S: float = 1.0
    reversible: bool = False
    use_checkpoint: bool = True
    use_autocast: bool = False
    use_act: bool = False
    act_threshold: float = 0.9
    chunk_size: ChunkSize = None
    overlap: int = 0
    full_attn_logging: Optional[bool] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain ``{field name: value}`` dict.

        Keys appear in field-declaration order.
        """
        exported = (
            "d_model",
            "nhead",
            "num_layers",
            "dim_feedforward",
            "max_seq_len",
            "lambda_K",
            "lambda_C",
            "lambda_S",
            "reversible",
            "use_checkpoint",
            "use_autocast",
            "use_act",
            "act_threshold",
            "chunk_size",
            "overlap",
            "full_attn_logging",
        )
        return {name: getattr(self, name) for name in exported}

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> ModelConfig:
        """Build a config from a mapping whose keys are field names.

        Entries are forwarded as keyword arguments to the constructor, so
        missing keys fall back to the field defaults and unknown keys raise
        ``TypeError``.
        """
        instance = cls(**config_dict)
        return instance
92
+
93
+
94
@dataclass
class TrainingConfig:
    """Optimisation and training-loop settings for BitTransformerLM.

    Attributes:
        epochs: Number of training epochs.
        batch_size: Training batch size.
        learning_rate: Initial learning rate.
        weight_decay: Weight decay for regularization.
        gradient_clip_val: Gradient clipping value.
        warmup_steps: Number of warmup steps for learning rate.
        accumulate_grad_batches: Number of gradient accumulation steps.
        amp: Enable automatic mixed precision.
        compile_model: Enable PyTorch 2.0 compilation.
        log_every_n_steps: Logging frequency.
        val_check_interval: Validation check frequency.
        save_top_k: Number of best checkpoints to save.
    """

    # Schedule and batch sizing.
    epochs: int = 10
    batch_size: int = 8

    # Optimiser hyperparameters.
    learning_rate: float = 1e-3
    weight_decay: float = 0.01
    gradient_clip_val: float = 1.0
    warmup_steps: int = 100
    accumulate_grad_batches: int = 1

    # Runtime switches.
    amp: bool = False
    compile_model: bool = False

    # Logging, validation and checkpointing cadence.
    log_every_n_steps: int = 50
    val_check_interval: float = 1.0
    save_top_k: int = 3
125
+
126
+
127
@dataclass
class SafetyConfig:
    """Settings for safety monitoring and its gate thresholds.

    Attributes:
        enable_safety: Enable safety monitoring.
        k_threshold: Negentropy threshold for safety gate.
        c_threshold: Complexity threshold for safety gate.
        s_threshold: Symbiosis threshold for safety gate.
        strict_mode: Enable strict safety enforcement.
        retry_attempts: Number of retry attempts for failed safety checks.
    """

    enable_safety: bool = True
    k_threshold: float = 0.1
    c_threshold: float = 0.3
    s_threshold: float = 0.5
    strict_mode: bool = False
    retry_attempts: int = 3

    def to_thresholds(self) -> SafetyThresholds:
        """Return only the three gate thresholds as a mapping.

        The result has the keys ``k_threshold``, ``c_threshold`` and
        ``s_threshold``; the other settings are not included.
        """
        thresholds = dict(
            k_threshold=self.k_threshold,
            c_threshold=self.c_threshold,
            s_threshold=self.s_threshold,
        )
        return thresholds
154
+
155
+
156
@dataclass
class DataConfig:
    """Settings for dataset locations and data loading.

    Attributes:
        dataset_path: Path to training dataset.
        val_dataset_path: Path to validation dataset.
        num_workers: Number of data loader workers.
        pin_memory: Pin memory for data loading.
        prefetch_factor: Prefetch factor for data loading.
        max_sequence_length: Maximum sequence length to process.
        compression_prob: Probability of using compressed data.
        use_parity: Enable parity bit protection.
    """

    # Dataset locations; both default to None.
    dataset_path: Optional[Path] = None
    val_dataset_path: Optional[Path] = None

    # Data-loader settings.
    num_workers: int = 0
    pin_memory: bool = True
    prefetch_factor: int = 2

    # Sequence processing options.
    max_sequence_length: int = 1024
    compression_prob: float = 0.5
    use_parity: bool = True
179
+
180
+
181
@dataclass
class ExperimentConfig:
    """Complete configuration for BitTransformerLM experiments.

    Constructing an instance has side effects: a ``device`` of ``"auto"`` is
    replaced by a concrete device name, and ``output_dir`` is created on disk
    if it does not exist yet.

    Attributes:
        model: Model configuration.
        training: Training configuration.
        safety: Safety configuration.
        data: Data configuration.
        device: Target device for training; ``"auto"`` picks CUDA, then MPS,
            then CPU.
        seed: Random seed for reproducibility.
        experiment_name: Name of the experiment.
        output_dir: Directory for saving outputs.
        resume_from_checkpoint: Path to checkpoint to resume from.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    safety: SafetyConfig = field(default_factory=SafetyConfig)
    data: DataConfig = field(default_factory=DataConfig)
    device: DeviceType = "auto"
    seed: int = 42
    experiment_name: str = "bit_transformer_experiment"
    output_dir: Path = Path("./outputs")
    resume_from_checkpoint: Optional[Path] = None

    def __post_init__(self) -> None:
        """Resolve the device, normalise path fields and create ``output_dir``."""
        if self.device == "auto":
            self.device = self._detect_device()

        # Callers often pass plain strings (CLI arguments, environment
        # variables); normalise them so the Path-only calls below work.
        self.output_dir = Path(self.output_dir)
        if self.resume_from_checkpoint:
            self.resume_from_checkpoint = Path(self.resume_from_checkpoint)

        # Ensure output directory exists
        self.output_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _detect_device() -> str:
        """Return ``"cuda"``, ``"mps"`` or ``"cpu"``, in that order of preference."""
        if torch.cuda.is_available():
            return "cuda"
        # ``torch.backends.mps`` is absent on older PyTorch builds.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    @staticmethod
    def _section_to_dict(section: Any) -> Dict[str, Any]:
        """Return a detached copy of a sub-config's attributes.

        ``Path`` values are converted to strings, matching how ``to_dict``
        renders ``output_dir``. Returning a new dict keeps callers from
        mutating the live config through the result.
        """
        return {
            key: str(value) if isinstance(value, Path) else value
            for key, value in vars(section).items()
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert the complete config to a dictionary.

        Nested sections become new dictionaries and paths are rendered as
        strings; a falsy ``resume_from_checkpoint`` is reported as ``None``.
        """
        checkpoint = self.resume_from_checkpoint
        return {
            "model": self.model.to_dict(),
            "training": self._section_to_dict(self.training),
            "safety": self._section_to_dict(self.safety),
            "data": self._section_to_dict(self.data),
            "device": str(self.device),
            "seed": self.seed,
            "experiment_name": self.experiment_name,
            "output_dir": str(self.output_dir),
            "resume_from_checkpoint": str(checkpoint) if checkpoint else None,
        }
234
+
235
+
236
+ # Preset configurations for common use cases
237
def get_small_config() -> ExperimentConfig:
    """Build the preset for small-scale experiments.

    Only the model and training sections are customised; every other
    setting keeps its ``ExperimentConfig`` default.
    """
    model = ModelConfig(
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=256,
    )
    training = TrainingConfig(batch_size=4, learning_rate=1e-3, epochs=5)
    return ExperimentConfig(model=model, training=training)
253
+
254
+
255
def get_medium_config() -> ExperimentConfig:
    """Build the preset for medium-scale experiments.

    Only the model and training sections are set explicitly; every other
    setting keeps its ``ExperimentConfig`` default.
    """
    model = ModelConfig(
        d_model=128,
        nhead=8,
        num_layers=4,
        dim_feedforward=512,
        max_seq_len=1024,
    )
    training = TrainingConfig(batch_size=8, learning_rate=1e-3, epochs=10)
    return ExperimentConfig(model=model, training=training)
271
+
272
+
273
def get_large_config() -> ExperimentConfig:
    """Build the preset for large-scale experiments.

    Compared with the smaller presets this one also turns on reversible
    layers, chunked attention, mixed precision and model compilation.
    """
    model = ModelConfig(
        d_model=256,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=2048,
        reversible=True,
        chunk_size=512,
    )
    training = TrainingConfig(
        batch_size=16,
        learning_rate=5e-4,
        epochs=20,
        amp=True,
        compile_model=True,
    )
    return ExperimentConfig(model=model, training=training)
293
+
294
+
295
def get_config_from_env() -> ExperimentConfig:
    """Load configuration from environment variables.

    Recognised variables (unset or empty ones are ignored):

    * ``BT_D_MODEL``, ``BT_NUM_LAYERS``, ``BT_NHEAD`` (int): model settings.
    * ``BT_BATCH_SIZE``, ``BT_EPOCHS`` (int) and ``BT_LEARNING_RATE``
      (float): training settings.
    * ``BT_DEVICE``: target device.
    * ``BT_OUTPUT_DIR``: output directory.

    Returns:
        An ``ExperimentConfig`` with defaults overridden by the environment.

    Raises:
        ValueError: If a numeric variable cannot be parsed; the message names
            the offending variable.
    """
    # (environment variable, config section, attribute, parser)
    numeric_overrides = (
        ("BT_D_MODEL", "model", "d_model", int),
        ("BT_NUM_LAYERS", "model", "num_layers", int),
        ("BT_NHEAD", "model", "nhead", int),
        ("BT_BATCH_SIZE", "training", "batch_size", int),
        ("BT_LEARNING_RATE", "training", "learning_rate", float),
        ("BT_EPOCHS", "training", "epochs", int),
    )

    # Parse everything before constructing the config, so a malformed value
    # fails before any output directory is created.
    parsed = []
    for env_name, section, attr, parser in numeric_overrides:
        raw = os.getenv(env_name)
        if not raw:
            continue
        try:
            value = parser(raw)
        except ValueError as err:
            raise ValueError(
                f"Invalid value for environment variable {env_name}: {raw!r}"
            ) from err
        parsed.append((section, attr, value))

    # Device and output directory go through the constructor so its
    # post-init logic applies to them: the requested directory (not the
    # default one) gets created, and a device of "auto" is resolved.
    top_level = {}
    device = os.getenv("BT_DEVICE")
    if device:
        top_level["device"] = device
    output_dir = os.getenv("BT_OUTPUT_DIR")
    if output_dir:
        top_level["output_dir"] = Path(output_dir)

    config = ExperimentConfig(**top_level)

    for section, attr, value in parsed:
        setattr(getattr(config, section), attr, value)

    return config