|
|
"""Configuration management for BitTransformerLM.""" |
|
|
|
|
|
from __future__ import annotations

import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional

import torch

from .types import (
    AttentionMask,
    ChunkSize,
    DeviceType,
    DiffusionConfig,
    GenerationConfig,
    HiddenSize,
    NumHeads,
    NumLayers,
    QuantizationConfig,
    SafetyThresholds,
    SequenceLength,
)
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Configuration for BitTransformerLM model architecture.

    Attributes:
        d_model: Model dimension for embeddings and attention.
        nhead: Number of attention heads.
        num_layers: Number of transformer layers.
        dim_feedforward: Dimension of feedforward networks.
        max_seq_len: Maximum sequence length for positional encoding.
        lambda_K: Weight for negentropy metric in telemetry.
        lambda_C: Weight for complexity metric in telemetry.
        lambda_S: Weight for symbiosis metric in telemetry.
        reversible: Enable reversible layers for memory efficiency.
        use_checkpoint: Use gradient checkpointing.
        use_autocast: Use automatic mixed precision.
        use_act: Enable Adaptive Computation Time.
        act_threshold: ACT halting threshold.
        chunk_size: Chunk size for chunked attention (None for full attention).
        overlap: Overlap size for chunked attention.
        full_attn_logging: Log full attention matrices for telemetry.
    """

    d_model: HiddenSize = 128
    nhead: NumHeads = 8
    num_layers: NumLayers = 4
    dim_feedforward: int = 512
    max_seq_len: SequenceLength = 1024
    lambda_K: float = 1.0
    lambda_C: float = 1.0
    lambda_S: float = 1.0
    reversible: bool = False
    use_checkpoint: bool = True
    use_autocast: bool = False
    use_act: bool = False
    act_threshold: float = 0.9
    chunk_size: ChunkSize = None
    overlap: int = 0
    full_attn_logging: Optional[bool] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to a dictionary keyed by field name.

        The dictionary is derived from the dataclass field list, so every
        field (including ones added later) is present, in declaration order,
        and the result round-trips through ``from_dict``. The returned dict
        is a copy; mutating it does not affect this config.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> ModelConfig:
        """Create config from a dictionary, e.g. one produced by ``to_dict``.

        Keys missing from ``config_dict`` keep their dataclass defaults.

        Raises:
            TypeError: If ``config_dict`` contains a key that is not a field.
        """
        return cls(**config_dict)
|
|
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Configuration for training BitTransformerLM.

    Exposes the same ``to_dict`` / ``from_dict`` pair as ``ModelConfig`` so
    every sub-config of an experiment serializes the same way.

    Attributes:
        epochs: Number of training epochs.
        batch_size: Training batch size.
        learning_rate: Initial learning rate.
        weight_decay: Weight decay for regularization.
        gradient_clip_val: Gradient clipping value.
        warmup_steps: Number of warmup steps for learning rate.
        accumulate_grad_batches: Number of gradient accumulation steps.
        amp: Enable automatic mixed precision.
        compile_model: Enable PyTorch 2.0 compilation.
        log_every_n_steps: Logging frequency.
        val_check_interval: Validation check frequency.
        save_top_k: Number of best checkpoints to save.
    """

    epochs: int = 10
    batch_size: int = 8
    learning_rate: float = 1e-3
    weight_decay: float = 0.01
    gradient_clip_val: float = 1.0
    warmup_steps: int = 100
    accumulate_grad_batches: int = 1
    amp: bool = False
    compile_model: bool = False
    log_every_n_steps: int = 50
    val_check_interval: float = 1.0
    save_top_k: int = 3

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to a dictionary keyed by field name.

        Returns a copy (not the instance ``__dict__``), so mutating the result
        does not affect this config.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> TrainingConfig:
        """Create config from a dictionary, e.g. one produced by ``to_dict``.

        Keys missing from ``config_dict`` keep their dataclass defaults.

        Raises:
            TypeError: If ``config_dict`` contains a key that is not a field.
        """
        return cls(**config_dict)
|
|
|
|
|
|
|
|
@dataclass
class SafetyConfig:
    """Configuration for safety monitoring and thresholds.

    Exposes the same ``to_dict`` / ``from_dict`` pair as ``ModelConfig`` so
    every sub-config of an experiment serializes the same way.

    Attributes:
        enable_safety: Enable safety monitoring.
        k_threshold: Negentropy threshold for safety gate.
        c_threshold: Complexity threshold for safety gate.
        s_threshold: Symbiosis threshold for safety gate.
        strict_mode: Enable strict safety enforcement.
        retry_attempts: Number of retry attempts for failed safety checks.
    """

    enable_safety: bool = True
    k_threshold: float = 0.1
    c_threshold: float = 0.3
    s_threshold: float = 0.5
    strict_mode: bool = False
    retry_attempts: int = 3

    def to_thresholds(self) -> SafetyThresholds:
        """Convert to SafetyThresholds type (only the three threshold values)."""
        return {
            "k_threshold": self.k_threshold,
            "c_threshold": self.c_threshold,
            "s_threshold": self.s_threshold,
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert the full config to a dictionary keyed by field name.

        Returns a copy (not the instance ``__dict__``), so mutating the result
        does not affect this config.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> SafetyConfig:
        """Create config from a dictionary, e.g. one produced by ``to_dict``.

        Keys missing from ``config_dict`` keep their dataclass defaults.

        Raises:
            TypeError: If ``config_dict`` contains a key that is not a field.
        """
        return cls(**config_dict)
|
|
|
|
|
|
|
|
@dataclass
class DataConfig:
    """Configuration for data processing and loading.

    Exposes ``to_dict`` / ``from_dict`` like ``ModelConfig``. Path fields are
    serialized as strings and restored to ``Path`` on load.

    Attributes:
        dataset_path: Path to training dataset.
        val_dataset_path: Path to validation dataset.
        num_workers: Number of data loader workers.
        pin_memory: Pin memory for data loading.
        prefetch_factor: Prefetch factor for data loading.
        max_sequence_length: Maximum sequence length to process.
        compression_prob: Probability of using compressed data.
        use_parity: Enable parity bit protection.
    """

    dataset_path: Optional[Path] = None
    val_dataset_path: Optional[Path] = None
    num_workers: int = 0
    pin_memory: bool = True
    prefetch_factor: int = 2
    max_sequence_length: int = 1024
    compression_prob: float = 0.5
    use_parity: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to a dictionary keyed by field name.

        ``Path`` values are rendered as strings so the result is
        JSON-serializable; unset paths stay ``None``. The result is a copy.
        """
        return {
            key: str(value) if isinstance(value, Path) else value
            for key, value in asdict(self).items()
        }

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> DataConfig:
        """Create config from a dictionary, e.g. one produced by ``to_dict``.

        Non-empty ``dataset_path`` / ``val_dataset_path`` values are converted
        to ``Path``; other keys are passed through unchanged.

        Raises:
            TypeError: If ``config_dict`` contains a key that is not a field.
        """
        kwargs = dict(config_dict)  # never mutate the caller's dict
        for key in ("dataset_path", "val_dataset_path"):
            value = kwargs.get(key)
            if value:
                kwargs[key] = Path(value)
        return cls(**kwargs)
|
|
|
|
|
|
|
|
@dataclass
class ExperimentConfig:
    """Complete configuration for BitTransformerLM experiments.

    Attributes:
        model: Model configuration.
        training: Training configuration.
        safety: Safety configuration.
        data: Data configuration.
        device: Target device for training.
        seed: Random seed for reproducibility.
        experiment_name: Name of the experiment.
        output_dir: Directory for saving outputs.
        resume_from_checkpoint: Path to checkpoint to resume from.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    safety: SafetyConfig = field(default_factory=SafetyConfig)
    data: DataConfig = field(default_factory=DataConfig)
    device: DeviceType = "auto"
    seed: int = 42
    experiment_name: str = "bit_transformer_experiment"
    output_dir: Path = Path("./outputs")
    resume_from_checkpoint: Optional[Path] = None

    def __post_init__(self):
        """Resolve ``device="auto"`` and create ``output_dir``.

        ``"auto"`` becomes ``"cuda"`` when CUDA is available, else ``"mps"``
        when the MPS backend exists and is available, else ``"cpu"``.
        ``output_dir`` may be passed as a string (e.g. a value read back from
        ``to_dict``); it is normalized to ``Path`` before the directory is
        created (including parents; an existing directory is not an error).
        """
        if self.device == "auto":
            if torch.cuda.is_available():
                self.device = "cuda"
            # hasattr guard: presumably for PyTorch builds without torch.backends.mps.
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"

        # A str output_dir would otherwise fail on .mkdir(); Path() of a Path
        # is an equivalent copy.
        self.output_dir = Path(self.output_dir)
        # Side effect: constructing a config creates the output directory.
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def to_dict(self) -> Dict[str, Any]:
        """Convert complete config to a nested dictionary.

        Sub-configs are copied via ``dataclasses.asdict`` instead of exposing
        their live ``__dict__``, so mutating the result does not mutate this
        config. Path values (``output_dir``, ``resume_from_checkpoint`` and
        any ``Path`` in the data config) are rendered as strings.
        """
        data = {
            key: str(value) if isinstance(value, Path) else value
            for key, value in asdict(self.data).items()
        }
        return {
            "model": self.model.to_dict(),
            "training": asdict(self.training),
            "safety": asdict(self.safety),
            "data": data,
            "device": str(self.device),
            "seed": self.seed,
            "experiment_name": self.experiment_name,
            "output_dir": str(self.output_dir),
            "resume_from_checkpoint": str(self.resume_from_checkpoint) if self.resume_from_checkpoint else None,
        }
|
|
|
|
|
|
|
|
|
|
|
def get_small_config() -> ExperimentConfig:
    """Return the small-scale preset for quick experiments.

    Model: d_model=64, 4 heads, 2 layers, feedforward width 256, 256-token
    context. Training: batch size 4, learning rate 1e-3, 5 epochs. All other
    settings use the dataclass defaults; building the ExperimentConfig runs
    its post-init hook (device resolution and output directory creation).
    """
    model_cfg = ModelConfig(
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=256,
    )
    training_cfg = TrainingConfig(batch_size=4, learning_rate=1e-3, epochs=5)
    return ExperimentConfig(model=model_cfg, training=training_cfg)
|
|
|
|
|
|
|
|
def get_medium_config() -> ExperimentConfig:
    """Return the medium-scale preset.

    Model: d_model=128, 8 heads, 4 layers, feedforward width 512, 1024-token
    context. Training: batch size 8, learning rate 1e-3, 10 epochs. All other
    settings use the dataclass defaults; building the ExperimentConfig runs
    its post-init hook (device resolution and output directory creation).
    """
    model_cfg = ModelConfig(
        d_model=128,
        nhead=8,
        num_layers=4,
        dim_feedforward=512,
        max_seq_len=1024,
    )
    training_cfg = TrainingConfig(batch_size=8, learning_rate=1e-3, epochs=10)
    return ExperimentConfig(model=model_cfg, training=training_cfg)
|
|
|
|
|
|
|
|
def get_large_config() -> ExperimentConfig:
    """Return the large-scale preset.

    Model: d_model=256, 16 heads, 8 layers, feedforward width 1024, 2048-token
    context, reversible layers and chunked attention (chunk_size=512).
    Training: batch size 16, learning rate 5e-4, 20 epochs, with AMP and model
    compilation enabled. Building the ExperimentConfig runs its post-init hook
    (device resolution and output directory creation).
    """
    model_cfg = ModelConfig(
        d_model=256,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=2048,
        reversible=True,
        chunk_size=512,
    )
    training_cfg = TrainingConfig(
        batch_size=16,
        learning_rate=5e-4,
        epochs=20,
        amp=True,
        compile_model=True,
    )
    return ExperimentConfig(model=model_cfg, training=training_cfg)
|
|
|
|
|
|
|
|
def get_config_from_env() -> ExperimentConfig:
    """Load configuration from ``BT_*`` environment variables.

    Recognized variables (unset or empty values are ignored):

    - ``BT_D_MODEL``, ``BT_NUM_LAYERS``, ``BT_NHEAD`` (int) -> ``model``
    - ``BT_BATCH_SIZE``, ``BT_EPOCHS`` (int), ``BT_LEARNING_RATE`` (float)
      -> ``training``
    - ``BT_DEVICE`` -> ``device``
    - ``BT_OUTPUT_DIR`` -> ``output_dir``

    Overrides are collected *before* the ExperimentConfig is constructed so
    that its post-init hook sees them: ``BT_DEVICE=auto`` is resolved, and the
    ``BT_OUTPUT_DIR`` directory (not the default ``./outputs``) is created.

    Raises:
        ValueError: If a numeric variable cannot be parsed.
    """

    def _read(name, convert):
        """Return ``convert(os.environ[name])``, or None if unset/empty."""
        raw = os.getenv(name)
        if not raw:
            return None
        try:
            return convert(raw)
        except ValueError as err:
            raise ValueError(f"Invalid value for environment variable {name}: {raw!r}") from err

    def _collect(spec):
        """Build a kwargs dict from (env_name, attr, convert) triples that are set."""
        overrides = {}
        for env_name, attr, convert in spec:
            value = _read(env_name, convert)
            if value is not None:
                overrides[attr] = value
        return overrides

    model_overrides = _collect(
        (
            ("BT_D_MODEL", "d_model", int),
            ("BT_NUM_LAYERS", "num_layers", int),
            ("BT_NHEAD", "nhead", int),
        )
    )
    training_overrides = _collect(
        (
            ("BT_BATCH_SIZE", "batch_size", int),
            ("BT_LEARNING_RATE", "learning_rate", float),
            ("BT_EPOCHS", "epochs", int),
        )
    )
    experiment_overrides = _collect(
        (
            ("BT_DEVICE", "device", str),
            ("BT_OUTPUT_DIR", "output_dir", Path),
        )
    )

    return ExperimentConfig(
        model=ModelConfig(**model_overrides),
        training=TrainingConfig(**training_overrides),
        **experiment_overrides,
    )