| |
| """ |
| BitTransformerLM Massive Scale Training - SIMPLIFIED & OPTIMIZED |
| ================================================================= |
| |
| Fixed version that properly initializes 680M parameter model with all optimizations! |
| Uses DataParallel for multi-GPU instead of FSDP to avoid initialization issues. |
| """ |
|
|
| import os |
| import sys |
| import time |
| import json |
| import logging |
| from datetime import datetime |
| from typing import Dict, Any, Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| import datasets |
| from datasets import load_dataset |
| import numpy as np |
|
|
| |
| from bit_transformer.model import BitTransformerLM |
| from bit_transformer.bit_io import text_to_bits, bits_to_text |
| from bit_transformer.utils import set_dropout |
|
|
| |
# Root logging: INFO level, timestamped "[LEVEL] message" lines.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
|
|
|
|
class OptimizedConfig:
    """Hyper-parameters for the 680M-parameter BitTransformerLM training run.

    Every setting is a plain class attribute, so the class works both directly
    (``OptimizedConfig.D_MODEL``) and through an instance.
    """

    # --- Transformer architecture -------------------------------------------
    D_MODEL = 1536
    NUM_LAYERS = 24
    NUM_HEADS = 24
    DIM_FEEDFORWARD = 6144
    MAX_SEQ_LEN = 2048

    # --- Batching -------------------------------------------------------------
    BATCH_SIZE_PER_GPU = 1
    NUM_GPUS = 4
    # Derived once, when the class body executes.
    TOTAL_BATCH_SIZE = BATCH_SIZE_PER_GPU * NUM_GPUS
    GRADIENT_ACCUMULATION_STEPS = 32

    # --- Optimisation ---------------------------------------------------------
    LEARNING_RATE = 3e-4
    WEIGHT_DECAY = 0.01
    MAX_STEPS = 10000
    WARMUP_STEPS = 500

    # --- Memory / speed features forwarded to the model -----------------------
    USE_REVERSIBLE = True
    USE_GRADIENT_CHECKPOINTING = True
    USE_MIXED_PRECISION = True
    USE_AUTOCAST = True
    CHUNK_SIZE = None
    FULL_ATTN_LOGGING = False

    # --- Telemetry weights (forwarded) and safety thresholds ------------------
    LAMBDA_K = 1.0
    LAMBDA_C = 1.0
    LAMBDA_S = 1.0
    NEGENTROPY_THRESHOLD = 0.2
    LZ_COMPLEXITY_THRESHOLD = 0.3
    SYMBIOSIS_THRESHOLD = 0.5

    @classmethod
    def get_model_config(cls) -> Dict[str, Any]:
        """Return the keyword arguments used to construct ``BitTransformerLM``."""
        return dict(
            d_model=cls.D_MODEL,
            nhead=cls.NUM_HEADS,
            num_layers=cls.NUM_LAYERS,
            dim_feedforward=cls.DIM_FEEDFORWARD,
            max_seq_len=cls.MAX_SEQ_LEN,
            lambda_K=cls.LAMBDA_K,
            lambda_C=cls.LAMBDA_C,
            lambda_S=cls.LAMBDA_S,
            reversible=cls.USE_REVERSIBLE,
            use_checkpoint=cls.USE_GRADIENT_CHECKPOINTING,
            use_autocast=cls.USE_AUTOCAST,
            chunk_size=cls.CHUNK_SIZE,
            full_attn_logging=cls.FULL_ATTN_LOGGING,
        )
|
|
|
|
class SimpleWikiTextDataset(torch.utils.data.Dataset):
    """WikiText-103 (raw) lines encoded as fixed-length bit sequences.

    Only lines longer than 100 characters after ``strip()`` are kept, up to
    ``max_samples`` of them. Each item encodes one line with ``text_to_bits``,
    truncates or zero-pads it to ``max_length`` bits and returns a next-bit
    prediction pair of length ``max_length - 1``.
    """

    # Label value skipped by F.cross_entropy's default ``ignore_index``; used
    # for targets that fall inside the zero padding so padding is not trained on.
    IGNORE_INDEX = -100

    def __init__(self, split: str = "train", max_samples: Optional[int] = 1000, max_length: int = 2048):
        """Load ``split`` of WikiText-103 and keep up to ``max_samples`` long lines.

        Args:
            split: Dataset split name passed to ``load_dataset``.
            max_samples: Maximum number of lines to keep (``None`` keeps all).
            max_length: Number of bits each encoded line is truncated/padded to.
        """
        self.max_length = max_length

        logger.info(f"Loading WikiText-103 {split} split (max {max_samples} samples)...")
        dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)

        # Stop scanning once enough lines are collected instead of filtering the
        # whole split first (the split can be far larger than max_samples).
        self.texts = []
        for item in dataset:
            if max_samples is not None and len(self.texts) >= max_samples:
                break
            text = item['text']
            if len(text.strip()) > 100:
                self.texts.append(text)

        logger.info(f"Loaded {len(self.texts)} text samples from {split}")

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return ``input_ids``, ``labels`` and ``attention_mask`` for line ``idx``.

        If encoding fails, a warning is logged and an alternating 0/1 pattern of
        exactly ``max_length`` bits is used instead, so the item has the same
        length as every other item and the batch can still be collated.
        """
        text = self.texts[idx]

        try:
            return self._build_example(text_to_bits(text), self.max_length)
        except Exception as e:
            # Deliberate best-effort: one bad sample should not abort the epoch.
            logger.warning(f"Error processing text at index {idx}: {e}")
            # Exactly max_length bits, also when max_length is odd.
            fallback_bits = [i % 2 for i in range(self.max_length)]
            return self._build_example(fallback_bits, self.max_length)

    @classmethod
    def _build_example(cls, bits, max_length: int) -> Dict[str, torch.Tensor]:
        """Truncate/zero-pad ``bits`` to ``max_length`` and build a next-bit pair.

        Returns:
            ``input_ids`` (bits[:-1]) and ``labels`` (bits[1:]), each of length
            ``max_length - 1``, plus ``attention_mask``. Labels whose target
            is padding are set to ``IGNORE_INDEX``, and mask entries for padded
            inputs are 0.
        """
        real_bits = list(bits[:max_length])
        num_real = len(real_bits)
        padded = real_bits + [0] * (max_length - num_real)

        input_ids = torch.tensor(padded[:-1], dtype=torch.long)
        labels = torch.tensor(padded[1:], dtype=torch.long)

        # labels[j] is padded[j + 1], which is padding once j + 1 >= num_real.
        labels[max(num_real - 1, 0):] = cls.IGNORE_INDEX
        # input_ids[j] is a real bit only for j < num_real.
        attention_mask = torch.zeros_like(input_ids)
        attention_mask[:num_real] = 1

        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask,
        }
|
|
|
|
def create_optimized_model(config: OptimizedConfig) -> nn.Module:
    """Build BitTransformerLM from ``config`` and move it onto the available GPU(s).

    The model is wrapped in ``nn.DataParallel`` over
    ``min(config.NUM_GPUS, torch.cuda.device_count())`` GPUs when that is more
    than one, placed on a single GPU when only one is usable, and left on the
    CPU when CUDA is unavailable.

    Args:
        config: Supplies the model kwargs (``get_model_config``) and ``NUM_GPUS``.

    Returns:
        The model, possibly wrapped in ``nn.DataParallel`` (unwrap via ``.module``).
    """
    logger.info("Creating optimized BitTransformerLM model...")
    model_config = config.get_model_config()

    logger.info("Model configuration:")
    for key, value in model_config.items():
        logger.info(f" {key}: {value}")

    model = BitTransformerLM(**model_config)

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Model created: {params:,} parameters ({params/1e6:.1f}M)")

    available = torch.cuda.device_count() if torch.cuda.is_available() else 0
    if available == 0:
        logger.warning("Limited GPU availability - CUDA unavailable, using CPU")
        return model

    # Use as many of the requested GPUs as actually exist (at least one) rather
    # than falling back to a single GPU whenever fewer than NUM_GPUS are present.
    num_devices = min(max(config.NUM_GPUS, 1), available)
    if num_devices < config.NUM_GPUS:
        logger.warning(
            f"Limited GPU availability - requested {config.NUM_GPUS}, using {num_devices}"
        )

    model = model.cuda()

    if num_devices > 1:
        device_ids = list(range(num_devices))
        logger.info(f"Setting up multi-GPU training on {num_devices} GPUs...")
        model = nn.DataParallel(model, device_ids=device_ids)
        logger.info(f"DataParallel setup complete across GPUs: {device_ids}")

    return model
|
|
|
|
def _telemetry_scalar(telemetry: Dict[str, Any], key: str, default: float = 1.0) -> float:
    """Read ``telemetry[key]`` as a Python float, averaging multi-element tensors."""
    value = telemetry.get(key, default)
    if isinstance(value, torch.Tensor):
        # A tensor with more than one element (e.g. one value per sample) cannot
        # be compared to a float with ``<``; reduce it to its mean first.
        return value.detach().float().mean().item()
    return float(value)


def train_step(model: nn.Module, batch: Dict[str, torch.Tensor],
               optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler,
               config: OptimizedConfig) -> tuple:
    """Run forward + backward for one gradient-accumulation micro-batch.

    The loss is divided by ``GRADIENT_ACCUMULATION_STEPS`` before
    ``scaler.scale(loss).backward()``. The optimizer is neither stepped nor
    zeroed here; the caller does that once per accumulation window.

    Args:
        model: Model returning logits, or a ``(logits, telemetry)`` tuple.
        batch: Dict with ``input_ids`` and ``labels`` tensors.
        optimizer: Unused here; kept for interface compatibility.
        scaler: AMP grad scaler used for the backward pass.
        config: Supplies precision, accumulation and safety-threshold settings.

    Returns:
        ``(loss, telemetry, safety_penalty)``, where ``loss`` is the undivided
        micro-batch loss (including any penalty) as a float, ``telemetry`` is
        the model's dict (or ``{}``), and ``safety_penalty`` is 0.0 or 0.1.
    """
    model.train()
    set_dropout(model, 0.1)

    input_ids = batch['input_ids'].cuda(non_blocking=True)
    labels = batch['labels'].cuda(non_blocking=True)

    with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
        outputs = model(input_ids)

        if isinstance(outputs, tuple):
            logits, telemetry = outputs
        else:
            logits, telemetry = outputs, {}

        # Flatten to one 2-way (bit 0 / bit 1) distribution per position; reshape
        # (unlike view) also works when the logits are not contiguous.
        loss = F.cross_entropy(logits.reshape(-1, 2), labels.reshape(-1), reduction='mean')

        safety_penalty = 0.0
        if telemetry:
            # NOTE(review): keys must match what BitTransformerLM reports; a
            # missing key defaults to 1.0 and therefore never triggers the penalty.
            negentropy = _telemetry_scalar(telemetry, 'negentropy')
            lz_complexity = _telemetry_scalar(telemetry, 'lz_complexity')
            symbiosis = _telemetry_scalar(telemetry, 'symbiosis')

            if (negentropy < config.NEGENTROPY_THRESHOLD
                    or lz_complexity < config.LZ_COMPLEXITY_THRESHOLD
                    or symbiosis < config.SYMBIOSIS_THRESHOLD):
                # NOTE(review): a constant offset has zero gradient, so this only
                # raises the reported loss; it does not change the parameter update.
                safety_penalty = 0.1
                loss = loss + safety_penalty

        # Average gradients over the accumulation window.
        loss = loss / config.GRADIENT_ACCUMULATION_STEPS

    # Backward outside autocast, as recommended for AMP.
    scaler.scale(loss).backward()

    return loss.item() * config.GRADIENT_ACCUMULATION_STEPS, telemetry, safety_penalty
|
|
|
|
def _unwrap_model(model: nn.Module) -> nn.Module:
    """Return the module inside an ``nn.DataParallel`` wrapper, or ``model`` itself."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _metric_for_log(telemetry: Dict[str, Any], key: str) -> float:
    """Return ``telemetry[key]`` (default 0) as a float usable with ``:.3f``."""
    value = telemetry.get(key, 0)
    if isinstance(value, torch.Tensor):
        # Multi-element tensors cannot be formatted with ``:.3f``; log the mean.
        return value.detach().float().mean().item()
    return float(value)


def _evaluate(model: nn.Module, val_loader: DataLoader, config: OptimizedConfig) -> float:
    """Return the mean per-batch validation cross-entropy (NaN for an empty loader)."""
    model.eval()
    set_dropout(model, 0.0)
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for val_batch in val_loader:
            val_input_ids = val_batch['input_ids'].cuda()
            val_labels = val_batch['labels'].cuda()

            with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
                val_outputs = model(val_input_ids)
                val_logits = val_outputs[0] if isinstance(val_outputs, tuple) else val_outputs
                total_loss += F.cross_entropy(
                    val_logits.reshape(-1, 2),
                    val_labels.reshape(-1),
                ).item()
            num_batches += 1

    return total_loss / num_batches if num_batches else float('nan')


def _save_checkpoint(checkpoint_dir: str, step: int, model: nn.Module, optimizer,
                     scheduler, scaler, config: OptimizedConfig) -> str:
    """Write the training state at ``step`` into ``checkpoint_dir``; return the file path."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"checkpoint_step_{step:06d}.pt")
    torch.save({
        'step': step,
        # Unwrapped weights: keys carry no DataParallel "module." prefix.
        'model_state_dict': _unwrap_model(model).state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        # Needed to resume mixed-precision training with the same loss scale.
        'scaler_state_dict': scaler.state_dict(),
        'config': config.get_model_config(),
    }, path)
    return path


def main():
    """Train BitTransformerLM on WikiText-103 bits for ``MAX_STEPS`` optimizer steps.

    Each optimizer step accumulates gradients over ``GRADIENT_ACCUMULATION_STEPS``
    micro-batches. Training loss and throughput are logged every 10 steps,
    validation runs every 100 steps, and a checkpoint is written every 500 steps
    into a single per-run directory. Epochs repeat until ``MAX_STEPS`` is reached.
    """
    logger.info("OPTIMIZED MASSIVE SCALE BITTRANSFORMERLM TRAINING!")
    logger.info("=" * 60)

    config = OptimizedConfig()

    if not torch.cuda.is_available():
        logger.error("CUDA not available!")
        return

    logger.info(f"Hardware: {torch.cuda.device_count()}x GPUs detected")
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        logger.info(f" GPU {i}: {props.name} ({props.total_memory / 1024**3:.1f}GB)")

    model = create_optimized_model(config)

    # nn.DataParallel splits each input batch along dim 0 across its devices. A
    # batch of only BATCH_SIZE_PER_GPU samples yields fewer chunks than devices
    # (with the default of 1, all work lands on the first GPU), so the loader
    # supplies BATCH_SIZE_PER_GPU samples for every device in use.
    num_devices = len(model.device_ids) if isinstance(model, nn.DataParallel) else 1
    batch_size = config.BATCH_SIZE_PER_GPU * num_devices

    logger.info("Loading datasets...")
    train_dataset = SimpleWikiTextDataset("train", max_samples=2000, max_length=config.MAX_SEQ_LEN)
    val_dataset = SimpleWikiTextDataset("validation", max_samples=100, max_length=config.MAX_SEQ_LEN)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=1,
        pin_memory=True,
    )

    # Without this guard the epoch loop below would spin forever on an empty loader.
    if len(train_loader) == 0:
        logger.error("Training dataset is empty - nothing to train on!")
        return

    logger.info("Setting up optimizer...")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
        betas=(0.9, 0.95),
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.LEARNING_RATE,
        total_steps=config.MAX_STEPS,
        pct_start=config.WARMUP_STEPS / config.MAX_STEPS,
    )

    scaler = torch.cuda.amp.GradScaler(enabled=config.USE_MIXED_PRECISION)

    # One directory per run, so all checkpoints of this run end up together.
    checkpoint_dir = f"/data/checkpoints/massive_simple_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    log_every = 10

    logger.info("Starting training...")
    logger.info(f"Target steps: {config.MAX_STEPS}")
    logger.info(f"Effective batch size: {batch_size * config.GRADIENT_ACCUMULATION_STEPS}")

    step = 0
    micro_step = 0          # micro-batches seen over the whole run (not per epoch)
    window_loss = 0.0       # sum of micro-batch losses since the last log line
    window_batches = 0      # micro-batches since the last log line
    window_samples = 0      # samples since the last log line
    start_time = time.time()

    while step < config.MAX_STEPS:
        for batch in train_loader:
            loss, telemetry, safety_penalty = train_step(
                model, batch, optimizer, scaler, config
            )
            window_loss += loss
            window_batches += 1
            window_samples += batch['input_ids'].size(0)
            micro_step += 1

            # A global counter guarantees every optimizer step accumulates exactly
            # GRADIENT_ACCUMULATION_STEPS micro-batches, even across epoch borders.
            if micro_step % config.GRADIENT_ACCUMULATION_STEPS != 0:
                continue

            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

            step += 1

            if step % log_every == 0:
                avg_loss = window_loss / window_batches
                elapsed = time.time() - start_time
                samples_per_sec = window_samples / elapsed
                memory_used = torch.cuda.max_memory_allocated() / (1024**3)

                logger.info(
                    f"Step {step:4d} | "
                    f"Loss: {avg_loss:.4f} | "
                    f"K: {_metric_for_log(telemetry, 'negentropy'):.3f} | "
                    f"C: {_metric_for_log(telemetry, 'lz_complexity'):.3f} | "
                    f"S: {_metric_for_log(telemetry, 'symbiosis'):.3f} | "
                    f"LR: {scheduler.get_last_lr()[0]:.2e} | "
                    f"Speed: {samples_per_sec:.1f} samp/s | "
                    f"Mem: {memory_used:.1f}GB"
                    + (f" | Safety: {safety_penalty:.3f}" if safety_penalty > 0 else "")
                )

            if step % 100 == 0:
                val_loss = _evaluate(model, val_loader, config)
                logger.info(f"Validation Loss: {val_loss:.4f}")

            if step % 500 == 0:
                _save_checkpoint(checkpoint_dir, step, model, optimizer, scheduler, scaler, config)
                logger.info(f"Checkpoint saved: step {step}")

            if step >= config.MAX_STEPS:
                logger.info("Training completed!")
                return

            if step % log_every == 0:
                # Restart the loss/throughput window after any validation or
                # checkpoint work, so that time is not counted as training time.
                window_loss = 0.0
                window_batches = 0
                window_samples = 0
                start_time = time.time()
|
|
|
if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()