"""
BitTransformerLM Massive Scale Training Script
==============================================

Train a large BitTransformerLM on extensive real corpus data. This script
configures distributed training across 4x NVIDIA L4 GPUs with FSDP.

Target configuration (as set in MassiveScaleConfig below):
- Parameters: ~680M (d_model=1536, layers=24, heads=24, ff=6144), downscaled
  from the original 1.21B target (1,208,164,352 parameters; d_model=2048,
  heads=32, ff=8192) to fit 4x L4 memory
- Dataset: WikiText-103 + additional real corpus data
- Hardware: 4x NVIDIA L4 (23GB each), 181GB RAM, 48 CPU cores
"""

import os
import sys
import time
import logging
import argparse
from datetime import datetime
from functools import partial
from typing import Dict, Any, Optional, Tuple
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    MixedPrecision,
    BackwardPrefetch,
    StateDictType,
    FullStateDictConfig,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset

from bit_transformer.model import BitTransformerLM
from bit_transformer.bit_io import text_to_bits
from bit_transformer.utils import set_dropout

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler('/data/massive_scale_training.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

warnings.filterwarnings('ignore', category=UserWarning)


class MassiveScaleConfig:
    """Configuration for ~680M-parameter BitTransformerLM training, GPU-optimized for 4x L4."""

    # Model architecture
    D_MODEL = 1536
    NUM_LAYERS = 24
    NUM_HEADS = 24
    DIM_FEEDFORWARD = 6144
    MAX_SEQ_LEN = 2048

    # Batching: per-GPU batch * 4 GPUs * accumulation steps
    BATCH_SIZE_PER_GPU = 4
    GRADIENT_ACCUMULATION_STEPS = 32
    EFFECTIVE_BATCH_SIZE = BATCH_SIZE_PER_GPU * 4 * GRADIENT_ACCUMULATION_STEPS

    # Optimization
    LEARNING_RATE = 6e-5
    WEIGHT_DECAY = 0.1
    MAX_STEPS = 50000
    WARMUP_STEPS = 2000

    # Telemetry weights and safety-gate thresholds
    LAMBDA_K = 1.0
    LAMBDA_C = 1.0
    LAMBDA_S = 1.0
    NEGENTROPY_THRESHOLD = 0.15
    LZ_COMPLEXITY_THRESHOLD = 0.25
    SYMBIOSIS_THRESHOLD = 0.4

    # Memory / precision features
    USE_REVERSIBLE = True
    USE_GRADIENT_CHECKPOINTING = True
    USE_MIXED_PRECISION = True
    USE_SAFETY_GATES = True

    # Data
    DATASET_NAME = "wikitext"
    DATASET_CONFIG = "wikitext-103-raw-v1"
    MAX_SAMPLES = None
    STREAMING = True

    # Logging / evaluation cadence (in steps)
    LOG_INTERVAL = 50
    EVAL_INTERVAL = 1000
    CHECKPOINT_INTERVAL = 2000

    @classmethod
    def get_model_config(cls) -> Dict[str, Any]:
        """Get model configuration dictionary."""
        return {
            "d_model": cls.D_MODEL,
            "nhead": cls.NUM_HEADS,
            "num_layers": cls.NUM_LAYERS,
            "dim_feedforward": cls.DIM_FEEDFORWARD,
            "max_seq_len": cls.MAX_SEQ_LEN,
            "lambda_K": cls.LAMBDA_K,
            "lambda_C": cls.LAMBDA_C,
            "lambda_S": cls.LAMBDA_S,
            "reversible": cls.USE_REVERSIBLE,
            "use_checkpoint": cls.USE_GRADIENT_CHECKPOINTING,
            "use_autocast": False,
            "chunk_size": None,
            "full_attn_logging": False,
        }
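

# Rough parameter estimate for the configuration above (a back-of-envelope
# sanity check, assuming a standard transformer block layout of ~4*d_model^2
# attention weights plus 2*d_model*dim_feedforward feed-forward weights per
# layer; the exact BitTransformerLM count also includes embeddings and norms):
#   per layer: 4 * 1536^2 + 2 * 1536 * 6144 ≈ 28.3M
#   24 layers: ≈ 680M total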


class WikiTextDataset(torch.utils.data.Dataset):
    """WikiText dataset preprocessed for bit-level training."""

    def __init__(self, split: str = "train", max_samples: Optional[int] = None,
                 max_length: int = 2048, streaming: bool = True):
        self.max_length = max_length
        self.streaming = streaming

        logger.info(f"Loading WikiText-103 {split} split...")
        if streaming:
            self.dataset = load_dataset(
                MassiveScaleConfig.DATASET_NAME,
                MassiveScaleConfig.DATASET_CONFIG,
                split=split,
                streaming=True
            )
            if max_samples:
                self.dataset = self.dataset.take(max_samples)
        else:
            self.dataset = load_dataset(
                MassiveScaleConfig.DATASET_NAME,
                MassiveScaleConfig.DATASET_CONFIG,
                split=split
            )
            if max_samples:
                self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))

        if not streaming:
            # Keep only substantive paragraphs; WikiText has many short headers.
            self.texts = [item['text'] for item in self.dataset if len(item['text'].strip()) > 50]
            logger.info(f"Loaded {len(self.texts)} text samples from {split}")
        else:
            self.texts = None
            logger.info(f"Streaming dataset configured for {split}")

    def __len__(self) -> int:
        if self.texts is not None:
            return len(self.texts)
        # Streaming datasets have no cheap length; report a nominal size so
        # DataLoader/DistributedSampler can still operate.
        return 100000 if "train" in str(self.dataset) else 1000

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        if self.texts is not None:
            text = self.texts[idx]
        else:
            # Streaming datasets are not indexable, so scan to the requested
            # position (O(idx) per item; tolerable only as a fallback path).
            for i, item in enumerate(self.dataset):
                if i == idx:
                    text = item['text']
                    break
            else:
                # Index past the end of the stream: use a fixed fallback text.
                text = "The quick brown fox jumps over the lazy dog."

        try:
            bits = text_to_bits(text)

            # Truncate or zero-pad to a fixed length.
            if len(bits) > self.max_length:
                bits = bits[:self.max_length]
            elif len(bits) < self.max_length:
                bits = bits + [0] * (self.max_length - len(bits))

            # Next-bit prediction: inputs are bits[:-1], targets are bits[1:].
            input_bits = torch.tensor(bits[:-1], dtype=torch.long)
            target_bits = torch.tensor(bits[1:], dtype=torch.long)

            return {
                'input_ids': input_bits,
                'labels': target_bits,
                'attention_mask': torch.ones_like(input_bits)
            }

        except Exception as e:
            logger.warning(f"Error processing text at index {idx}: {e}")
            # Fall back to an alternating bit pattern of the right length.
            fallback_bits = [0, 1] * (self.max_length // 2)
            if len(fallback_bits) < self.max_length:
                fallback_bits.extend([0] * (self.max_length - len(fallback_bits)))

            input_bits = torch.tensor(fallback_bits[:-1], dtype=torch.long)
            target_bits = torch.tensor(fallback_bits[1:], dtype=torch.long)

            return {
                'input_ids': input_bits,
                'labels': target_bits,
                'attention_mask': torch.ones_like(input_bits)
            }
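

# Example of the preprocessing above (illustrative values only; the exact bit
# encoding depends on bit_transformer.bit_io.text_to_bits, which maps text to
# a flat list of 0/1 ints):
#   bits = text_to_bits("Hi")               # e.g. [0, 1, 0, 0, 1, 0, 0, 0, ...]
#   inputs, targets = bits[:-1], bits[1:]   # next-bit prediction pairs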


def setup_distributed(rank: int, world_size: int, port: str = "29500") -> None:
    """Initialize the distributed process group for this rank."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup_distributed() -> None:
    """Clean up distributed training."""
    dist.destroy_process_group()


def count_parameters(model: nn.Module) -> int:
    """Count total trainable parameters.

    Note: on an FSDP-wrapped model this sees only the local shard's flat
    parameters, so call it on an unwrapped model for a global count.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def create_fsdp_model(model_config: Dict[str, Any], rank: int) -> FSDP:
    """Create an FSDP-wrapped BitTransformerLM model."""
    model = BitTransformerLM(**model_config)
    model = model.to(rank)

    # Keep parameters, gradient reduction, and buffers in fp16.
    mixed_precision_policy = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )

    # Shard any submodule whose parameter count exceeds the threshold.
    auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=int(1e8))

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mixed_precision_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=rank,
        limit_all_gathers=True,
    )

    return model
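

# Rough per-GPU memory budget under the policy above (an estimate, assuming
# ~680M parameters sharded across 4 GPUs): fp32 parameter shard ≈ 0.68 GB/GPU,
# AdamW moment buffers ≈ 1.36 GB/GPU, transient fp16 all-gathered weights per
# wrapped block, plus activations, which gradient checkpointing keeps bounded.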


def log_training_stats(step: int, loss: float, telemetry: Dict[str, float],
                       learning_rate: float, samples_per_sec: float,
                       memory_allocated: float, rank: int) -> None:
    """Log training statistics."""
    if rank == 0:
        logger.info(
            f"Step {step:6d} | "
            f"Loss: {loss:.4f} | "
            f"K: {telemetry.get('negentropy', 0):.3f} | "
            f"C: {telemetry.get('lz_complexity', 0):.3f} | "
            f"S: {telemetry.get('symbiosis', 0):.3f} | "
            f"LR: {learning_rate:.2e} | "
            f"Speed: {samples_per_sec:.1f} samples/s | "
            f"Memory: {memory_allocated:.1f}GB"
        )


def save_checkpoint(model: FSDP, optimizer, scheduler, step: int, loss: float,
                    config: MassiveScaleConfig, rank: int) -> None:
    """Save model checkpoint.

    The FULL_STATE_DICT gather is a collective operation, so every rank must
    enter the context manager; only rank 0 materializes and writes the file.
    """
    state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_dict_config):
        model_state = model.state_dict()

    if rank == 0:
        checkpoint_dir = f"/data/checkpoints/massive_scale_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            'step': step,
            'model_state_dict': model_state,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': loss,
            'config': config.get_model_config(),
            'timestamp': datetime.now().isoformat(),
            # Count from the gathered full state dict; counting on the FSDP
            # wrapper would only see this rank's shard.
            'parameters': sum(v.numel() for v in model_state.values()),
        }

        checkpoint_path = f"{checkpoint_dir}/checkpoint_step_{step:06d}.pt"
        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Checkpoint saved: {checkpoint_path}")
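

def load_checkpoint_for_inspection(path: str) -> Dict[str, Any]:
    """Load a saved checkpoint on CPU for offline inspection.

    A minimal convenience sketch, not used by the training loop itself; it
    assumes the checkpoint layout written by save_checkpoint() above.
    """
    checkpoint = torch.load(path, map_location="cpu")
    logger.info(f"Loaded step {checkpoint['step']} "
                f"({checkpoint['parameters']:,} parameters, "
                f"loss {checkpoint['loss']:.4f})")
    return checkpoint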


def train_one_epoch(model: FSDP, train_loader: DataLoader, optimizer, scheduler,
                    config: MassiveScaleConfig, epoch: int, rank: int, world_size: int) -> Tuple[float, Dict[str, float]]:
    """Train for one epoch."""
    model.train()
    set_dropout(model, 0.1)

    total_loss = 0.0
    step = 0
    telemetry: Dict[str, float] = {}
    start_time = time.time()
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(train_loader):
        if step >= config.MAX_STEPS:
            break

        input_ids = batch['input_ids'].to(rank)
        labels = batch['labels'].to(rank)

        with torch.amp.autocast("cuda", enabled=config.USE_MIXED_PRECISION):
            logits, telemetry = model(input_ids)

            # Next-bit cross-entropy over the two-way (0/1) output.
            loss = F.cross_entropy(
                logits.view(-1, 2),
                labels.view(-1),
                reduction='mean'
            )

            # Safety gates: penalize the loss when telemetry falls below the
            # configured K/C/S thresholds.
            if config.USE_SAFETY_GATES:
                negentropy = telemetry.get('negentropy', 0)
                lz_complexity = telemetry.get('lz_complexity', 0)
                symbiosis = telemetry.get('symbiosis', 0)

                if (negentropy < config.NEGENTROPY_THRESHOLD or
                        lz_complexity < config.LZ_COMPLEXITY_THRESHOLD or
                        symbiosis < config.SYMBIOSIS_THRESHOLD):
                    safety_penalty = 10.0
                    loss = loss + safety_penalty

                    if rank == 0:
                        logger.warning(f"Safety gate triggered at step {step}!")

        # Scale for gradient accumulation, then accumulate gradients.
        loss = loss / config.GRADIENT_ACCUMULATION_STEPS
        loss.backward()

        # Update weights once every GRADIENT_ACCUMULATION_STEPS micro-batches.
        # zero_grad() runs only after the update, so gradients actually
        # accumulate across micro-batches.
        if (batch_idx + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0:
            # FSDP shards gradients, so use its clip_grad_norm_ rather than
            # torch.nn.utils.clip_grad_norm_ (which sees only local shards).
            model.clip_grad_norm_(1.0)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % config.LOG_INTERVAL == 0:
            samples_per_sec = (config.BATCH_SIZE_PER_GPU * world_size *
                               config.LOG_INTERVAL) / (time.time() - start_time + 1e-7)
            memory_allocated = torch.cuda.memory_allocated(rank) / (1024**3)

            log_training_stats(
                step, loss.item() * config.GRADIENT_ACCUMULATION_STEPS,
                telemetry, scheduler.get_last_lr()[0], samples_per_sec,
                memory_allocated, rank
            )

            start_time = time.time()

        if step % config.CHECKPOINT_INTERVAL == 0 and step > 0:
            save_checkpoint(
                model, optimizer, scheduler, step,
                loss.item() * config.GRADIENT_ACCUMULATION_STEPS,
                config, rank
            )

        step += 1
        total_loss += loss.item() * config.GRADIENT_ACCUMULATION_STEPS

    avg_loss = total_loss / max(step, 1)
    return avg_loss, telemetry
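

# Effective batch size per optimizer step under the settings above:
#   4 (per-GPU batch) * 4 (GPUs) * 32 (accumulation) = 512 sequences
#   512 * 2047 input bits per sequence ≈ 1.05M bits per weight update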


def validate_model(model: FSDP, val_loader: DataLoader, config: MassiveScaleConfig,
                   rank: int) -> Tuple[float, Dict[str, float]]:
    """Validate model performance on up to 1000 samples."""
    model.eval()
    set_dropout(model, 0.0)

    total_loss = 0.0
    total_samples = 0
    num_batches = 0
    accumulated_telemetry: Dict[str, float] = {}

    with torch.no_grad():
        for batch in val_loader:
            if total_samples >= 1000:
                break

            input_ids = batch['input_ids'].to(rank)
            labels = batch['labels'].to(rank)

            with torch.amp.autocast("cuda", enabled=config.USE_MIXED_PRECISION):
                logits, telemetry = model(input_ids)
                loss = F.cross_entropy(
                    logits.view(-1, 2),
                    labels.view(-1),
                    reduction='mean'
                )

            total_loss += loss.item() * input_ids.size(0)
            total_samples += input_ids.size(0)
            num_batches += 1

            for key, value in telemetry.items():
                accumulated_telemetry[key] = accumulated_telemetry.get(key, 0) + value

    avg_loss = total_loss / max(total_samples, 1)

    # Telemetry values are per-batch statistics, so average over batches
    # rather than samples.
    for key in accumulated_telemetry:
        accumulated_telemetry[key] /= max(num_batches, 1)

    return avg_loss, accumulated_telemetry
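

def bit_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of next-bit predictions that match the labels.

    A small optional metric sketch (not wired into validate_model above);
    it assumes logits of shape (batch, seq, 2) and labels of shape (batch, seq).
    """
    preds = logits.argmax(dim=-1)
    return (preds == labels).float().mean().item()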


def main_worker(rank: int, world_size: int, config: MassiveScaleConfig,
                port: str = "29500") -> None:
    """Main training worker process."""
    setup_distributed(rank, world_size, port)

    param_count = 0
    if rank == 0:
        # Build a throwaway CPU copy once to get the global parameter count;
        # after FSDP wrapping, each rank sees only its own shard.
        reference_model = BitTransformerLM(**config.get_model_config())
        param_count = count_parameters(reference_model)
        del reference_model

        logger.info("MASSIVE SCALE BITTRANSFORMERLM TRAINING INITIATED!")
        logger.info(f"Target: {param_count:,} parameters ({param_count/1e9:.2f}B)")
        logger.info(f"Hardware: {world_size}x NVIDIA L4 GPUs")
        logger.info(f"Configuration: {config.get_model_config()}")

    # Datasets and loaders.
    train_dataset = WikiTextDataset("train", max_samples=config.MAX_SAMPLES,
                                    max_length=config.MAX_SEQ_LEN, streaming=config.STREAMING)
    val_dataset = WikiTextDataset("validation", max_samples=1000,
                                  max_length=config.MAX_SEQ_LEN, streaming=False)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE_PER_GPU,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    model = create_fsdp_model(config.get_model_config(), rank)

    if rank == 0:
        logger.info(f"Model created with {param_count:,} parameters ({param_count/1e9:.2f}B)")

        benchmark_update = f"""

### LIVE RUN: {param_count/1e9:.2f}B Parameter Training
**Status:** ACTIVE
**Started:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
**Parameters:** {param_count:,} ({param_count/1e9:.2f}B)
**Architecture:** d_model={config.D_MODEL}, layers={config.NUM_LAYERS}, heads={config.NUM_HEADS}
**Effective Batch Size:** {config.EFFECTIVE_BATCH_SIZE}
**Dataset:** WikiText-103 (streaming)
**Hardware:** {world_size}x NVIDIA L4 GPUs with FSDP

"""
        with open('/data/Benchmarks.md', 'a') as f:
            f.write(benchmark_update)

    # Optimizer and one-cycle LR schedule with warmup.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
        betas=(0.9, 0.95),
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.LEARNING_RATE,
        total_steps=config.MAX_STEPS,
        pct_start=config.WARMUP_STEPS / config.MAX_STEPS,
        anneal_strategy='cos',
    )

    if rank == 0:
        logger.info("Starting training loop...")

    try:
        for epoch in range(100):
            train_sampler.set_epoch(epoch)

            train_loss, train_telemetry = train_one_epoch(
                model, train_loader, optimizer, scheduler,
                config, epoch, rank, world_size
            )

            # Validation must run on every rank: FSDP forward passes are
            # collectives, so restricting them to rank 0 would deadlock.
            val_loss, val_telemetry = validate_model(model, val_loader, config, rank)

            if rank == 0:
                logger.info(f"Epoch {epoch} completed - Average Loss: {train_loss:.4f}")
                logger.info(f"Validation Loss: {val_loss:.4f}")

    except KeyboardInterrupt:
        if rank == 0:
            logger.info("Training interrupted by user")
    except Exception as e:
        if rank == 0:
            logger.error(f"Training failed with error: {e}")
        raise
    finally:
        cleanup_distributed()


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description='BitTransformerLM Massive Scale Training')
    parser.add_argument('--world-size', type=int, default=4, help='Number of GPUs')
    parser.add_argument('--port', type=str, default='29500', help='Master port')

    args = parser.parse_args()

    config = MassiveScaleConfig()

    # Sanity checks: this script requires CUDA and enough visible GPUs.
    if not torch.cuda.is_available():
        print("CUDA not available! This script requires GPU training.")
        sys.exit(1)

    if torch.cuda.device_count() < args.world_size:
        print(f"Only {torch.cuda.device_count()} GPUs available, but {args.world_size} requested")
        sys.exit(1)

    print(f"Launching massive scale training on {args.world_size} GPUs...")
    print(f"Architecture: d_model={config.D_MODEL}, layers={config.NUM_LAYERS}, heads={config.NUM_HEADS}")
    print("Dataset: WikiText-103 (full corpus)")
    print("This is going to be EPIC!")

    # Spawn one worker per GPU; mp.spawn passes the rank as the first
    # argument, followed by the args tuple below (note the port is threaded
    # through so --port actually takes effect).
    mp.spawn(
        main_worker,
        args=(args.world_size, config, args.port),
        nprocs=args.world_size,
        join=True
    )


if __name__ == "__main__":
    main()
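

# Example launch (assuming 4 visible GPUs, that /data exists for logs and
# checkpoints, and a hypothetical filename for this script):
#   python massive_scale_training.py --world-size 4 --port 29500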