"""Configuration management for BitTransformerLM."""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from .types import (
    ChunkSize,
    DeviceType,
    HiddenSize,
    NumHeads,
    NumLayers,
    SafetyThresholds,
    SequenceLength,
)
@dataclass
class ModelConfig:
"""Configuration for BitTransformerLM model architecture.
Attributes:
d_model: Model dimension for embeddings and attention.
nhead: Number of attention heads.
num_layers: Number of transformer layers.
dim_feedforward: Dimension of feedforward networks.
max_seq_len: Maximum sequence length for positional encoding.
lambda_K: Weight for negentropy metric in telemetry.
lambda_C: Weight for complexity metric in telemetry.
lambda_S: Weight for symbiosis metric in telemetry.
reversible: Enable reversible layers for memory efficiency.
use_checkpoint: Use gradient checkpointing.
use_autocast: Use automatic mixed precision.
use_act: Enable Adaptive Computation Time.
act_threshold: ACT halting threshold.
chunk_size: Chunk size for chunked attention (None for full attention).
overlap: Overlap size for chunked attention.
        full_attn_logging: Log full attention matrices for telemetry (None to use the model's default).
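
    Example:
        Illustrative round-trip sketch using only this class's own methods:

        >>> cfg = ModelConfig(d_model=64, nhead=4)
        >>> ModelConfig.from_dict(cfg.to_dict()) == cfg
        True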
"""
d_model: HiddenSize = 128
nhead: NumHeads = 8
num_layers: NumLayers = 4
dim_feedforward: int = 512
max_seq_len: SequenceLength = 1024
lambda_K: float = 1.0
lambda_C: float = 1.0
lambda_S: float = 1.0
reversible: bool = False
use_checkpoint: bool = True
use_autocast: bool = False
use_act: bool = False
act_threshold: float = 0.9
chunk_size: ChunkSize = None
overlap: int = 0
full_attn_logging: Optional[bool] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert config to dictionary."""
return {
"d_model": self.d_model,
"nhead": self.nhead,
"num_layers": self.num_layers,
"dim_feedforward": self.dim_feedforward,
"max_seq_len": self.max_seq_len,
"lambda_K": self.lambda_K,
"lambda_C": self.lambda_C,
"lambda_S": self.lambda_S,
"reversible": self.reversible,
"use_checkpoint": self.use_checkpoint,
"use_autocast": self.use_autocast,
"use_act": self.use_act,
"act_threshold": self.act_threshold,
"chunk_size": self.chunk_size,
"overlap": self.overlap,
"full_attn_logging": self.full_attn_logging,
}
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> ModelConfig:
"""Create config from dictionary."""
return cls(**config_dict)
@dataclass
class TrainingConfig:
"""Configuration for training BitTransformerLM.
Attributes:
epochs: Number of training epochs.
batch_size: Training batch size.
learning_rate: Initial learning rate.
weight_decay: Weight decay for regularization.
gradient_clip_val: Gradient clipping value.
warmup_steps: Number of warmup steps for learning rate.
accumulate_grad_batches: Number of gradient accumulation steps.
amp: Enable automatic mixed precision.
compile_model: Enable PyTorch 2.0 compilation.
log_every_n_steps: Logging frequency.
val_check_interval: Validation check frequency.
save_top_k: Number of best checkpoints to save.
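
    Example:
        Illustrative sketch: with gradient accumulation, the effective batch
        size is batch_size * accumulate_grad_batches.

        >>> tc = TrainingConfig(batch_size=4, accumulate_grad_batches=8)
        >>> tc.batch_size * tc.accumulate_grad_batches
        32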
"""
epochs: int = 10
batch_size: int = 8
learning_rate: float = 1e-3
weight_decay: float = 0.01
gradient_clip_val: float = 1.0
warmup_steps: int = 100
accumulate_grad_batches: int = 1
amp: bool = False
compile_model: bool = False
log_every_n_steps: int = 50
val_check_interval: float = 1.0
save_top_k: int = 3
@dataclass
class SafetyConfig:
"""Configuration for safety monitoring and thresholds.
Attributes:
enable_safety: Enable safety monitoring.
k_threshold: Negentropy threshold for safety gate.
c_threshold: Complexity threshold for safety gate.
s_threshold: Symbiosis threshold for safety gate.
strict_mode: Enable strict safety enforcement.
retry_attempts: Number of retry attempts for failed safety checks.
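
    Example:
        Illustrative sketch of exporting the gate thresholds:

        >>> sc = SafetyConfig(k_threshold=0.2)
        >>> sc.to_thresholds()["k_threshold"]
        0.2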
"""
enable_safety: bool = True
k_threshold: float = 0.1
c_threshold: float = 0.3
s_threshold: float = 0.5
strict_mode: bool = False
retry_attempts: int = 3
def to_thresholds(self) -> SafetyThresholds:
"""Convert to SafetyThresholds type."""
return {
"k_threshold": self.k_threshold,
"c_threshold": self.c_threshold,
"s_threshold": self.s_threshold,
}
@dataclass
class DataConfig:
"""Configuration for data processing and loading.
Attributes:
dataset_path: Path to training dataset.
val_dataset_path: Path to validation dataset.
num_workers: Number of data loader workers.
pin_memory: Pin memory for data loading.
prefetch_factor: Prefetch factor for data loading.
max_sequence_length: Maximum sequence length to process.
compression_prob: Probability of using compressed data.
use_parity: Enable parity bit protection.
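
    Example:
        Illustrative sketch; the path below is a placeholder, not a shipped
        file:

        >>> dc = DataConfig(dataset_path=Path("data/train.bits"))
        >>> dc.dataset_path.suffix
        '.bits'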
"""
dataset_path: Optional[Path] = None
val_dataset_path: Optional[Path] = None
num_workers: int = 0
pin_memory: bool = True
prefetch_factor: int = 2
max_sequence_length: int = 1024
compression_prob: float = 0.5
use_parity: bool = True
@dataclass
class ExperimentConfig:
"""Complete configuration for BitTransformerLM experiments.
Attributes:
model: Model configuration.
training: Training configuration.
safety: Safety configuration.
data: Data configuration.
device: Target device for training.
seed: Random seed for reproducibility.
experiment_name: Name of the experiment.
output_dir: Directory for saving outputs.
resume_from_checkpoint: Path to checkpoint to resume from.
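
    Example:
        Illustrative sketch; an explicit device skips auto-detection, and a
        temporary directory keeps the __post_init__ side effect out of the
        working tree:

        >>> import tempfile
        >>> cfg = ExperimentConfig(device="cpu", output_dir=Path(tempfile.mkdtemp()))
        >>> cfg.device
        'cpu'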
"""
model: ModelConfig = field(default_factory=ModelConfig)
training: TrainingConfig = field(default_factory=TrainingConfig)
safety: SafetyConfig = field(default_factory=SafetyConfig)
data: DataConfig = field(default_factory=DataConfig)
device: DeviceType = "auto"
seed: int = 42
experiment_name: str = "bit_transformer_experiment"
output_dir: Path = Path("./outputs")
resume_from_checkpoint: Optional[Path] = None
def __post_init__(self):
"""Post-initialization to handle device selection and path creation."""
# Auto-detect device
if self.device == "auto":
if torch.cuda.is_available():
self.device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
# Ensure output directory exists
self.output_dir.mkdir(parents=True, exist_ok=True)
    def to_dict(self) -> Dict[str, Any]:
        """Convert complete config to a JSON-serializable dictionary."""
        data = dict(self.data.__dict__)
        # Stringify Path fields so the dict survives JSON round-trips,
        # matching the treatment of output_dir and resume_from_checkpoint.
        for key in ("dataset_path", "val_dataset_path"):
            data[key] = str(data[key]) if data[key] is not None else None
        return {
            "model": self.model.to_dict(),
            "training": dict(self.training.__dict__),
            "safety": dict(self.safety.__dict__),
            "data": data,
            "device": str(self.device),
            "seed": self.seed,
            "experiment_name": self.experiment_name,
            "output_dir": str(self.output_dir),
            "resume_from_checkpoint": str(self.resume_from_checkpoint) if self.resume_from_checkpoint else None,
        }
# Preset configurations for common use cases
def get_small_config() -> ExperimentConfig:
"""Get configuration for small-scale experiments."""
return ExperimentConfig(
model=ModelConfig(
d_model=64,
nhead=4,
num_layers=2,
dim_feedforward=256,
max_seq_len=256,
),
training=TrainingConfig(
batch_size=4,
learning_rate=1e-3,
epochs=5,
),
)
def get_medium_config() -> ExperimentConfig:
"""Get configuration for medium-scale experiments."""
return ExperimentConfig(
model=ModelConfig(
d_model=128,
nhead=8,
num_layers=4,
dim_feedforward=512,
max_seq_len=1024,
),
training=TrainingConfig(
batch_size=8,
learning_rate=1e-3,
epochs=10,
),
)
def get_large_config() -> ExperimentConfig:
"""Get configuration for large-scale experiments."""
return ExperimentConfig(
model=ModelConfig(
d_model=256,
nhead=16,
num_layers=8,
dim_feedforward=1024,
max_seq_len=2048,
reversible=True,
chunk_size=512,
),
training=TrainingConfig(
batch_size=16,
learning_rate=5e-4,
epochs=20,
amp=True,
compile_model=True,
),
)
def get_config_from_env() -> ExperimentConfig:
    """Load configuration from environment variables.

    Recognized variables: BT_D_MODEL, BT_NUM_LAYERS, BT_NHEAD, BT_BATCH_SIZE,
    BT_LEARNING_RATE, BT_EPOCHS, BT_DEVICE, and BT_OUTPUT_DIR. Unset or empty
    variables leave the corresponding defaults untouched.
    """
    config = ExperimentConfig()
    # Model config from environment
    if d_model := os.getenv("BT_D_MODEL"):
        config.model.d_model = int(d_model)
    if num_layers := os.getenv("BT_NUM_LAYERS"):
        config.model.num_layers = int(num_layers)
    if nhead := os.getenv("BT_NHEAD"):
        config.model.nhead = int(nhead)
    # Training config from environment
    if batch_size := os.getenv("BT_BATCH_SIZE"):
        config.training.batch_size = int(batch_size)
    if learning_rate := os.getenv("BT_LEARNING_RATE"):
        config.training.learning_rate = float(learning_rate)
    if epochs := os.getenv("BT_EPOCHS"):
        config.training.epochs = int(epochs)
    # Device from environment
    if device := os.getenv("BT_DEVICE"):
        config.device = device
    # Output directory from environment
    if output_dir := os.getenv("BT_OUTPUT_DIR"):
        config.output_dir = Path(output_dir)
        # __post_init__ already ran with the default path, so make sure the
        # overridden directory exists as well.
        config.output_dir.mkdir(parents=True, exist_ok=True)
    return config
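

# Illustrative usage sketch, not part of the library API: running this module
# directly resolves a config from the environment and prints the small preset,
# exercising env overrides, preset construction, and dict serialization.
if __name__ == "__main__":
    import json

    env_cfg = get_config_from_env()
    print(f"Resolved device: {env_cfg.device}")
    print(json.dumps(get_small_config().to_dict(), indent=2))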