"""
BitTransformerLM Massive Scale Training - SIMPLIFIED & OPTIMIZED
================================================================

Fixed version that properly initializes the 680M-parameter model with all
optimizations enabled. Uses DataParallel for multi-GPU instead of FSDP to
avoid initialization issues.
"""
| |
|
| | import os |
| | import sys |
| | import time |
| | import json |
| | import logging |
| | from datetime import datetime |
| | from typing import Dict, Any, Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader |
| | import datasets |
| | from datasets import load_dataset |
| | import numpy as np |
| |
|
| | |
| | from bit_transformer.model import BitTransformerLM |
| | from bit_transformer.bit_io import text_to_bits, bits_to_text |
| | from bit_transformer.utils import set_dropout |
| |
|
| | |
| | logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') |
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class OptimizedConfig:
    """Optimized 680M parameter configuration with ALL BitTransformerLM features enabled."""

    # --- Architecture ---
    D_MODEL = 1536
    NUM_LAYERS = 24
    NUM_HEADS = 24
    DIM_FEEDFORWARD = 6144
    MAX_SEQ_LEN = 2048

    # --- Batching / multi-GPU layout ---
    BATCH_SIZE_PER_GPU = 1
    NUM_GPUS = 4
    TOTAL_BATCH_SIZE = BATCH_SIZE_PER_GPU * NUM_GPUS
    GRADIENT_ACCUMULATION_STEPS = 32

    # --- Optimization schedule ---
    LEARNING_RATE = 3e-4
    WEIGHT_DECAY = 0.01
    MAX_STEPS = 10000
    WARMUP_STEPS = 500

    # --- Memory / throughput features ---
    USE_REVERSIBLE = True
    USE_GRADIENT_CHECKPOINTING = True
    USE_MIXED_PRECISION = True
    USE_AUTOCAST = True
    CHUNK_SIZE = None
    FULL_ATTN_LOGGING = False

    # --- Telemetry weights and safety thresholds ---
    LAMBDA_K = 1.0
    LAMBDA_C = 1.0
    LAMBDA_S = 1.0
    NEGENTROPY_THRESHOLD = 0.2
    LZ_COMPLEXITY_THRESHOLD = 0.3
    SYMBIOSIS_THRESHOLD = 0.5

    # Maps BitTransformerLM constructor kwargs -> class attribute names.
    _MODEL_KWARG_MAP = {
        "d_model": "D_MODEL",
        "nhead": "NUM_HEADS",
        "num_layers": "NUM_LAYERS",
        "dim_feedforward": "DIM_FEEDFORWARD",
        "max_seq_len": "MAX_SEQ_LEN",
        "lambda_K": "LAMBDA_K",
        "lambda_C": "LAMBDA_C",
        "lambda_S": "LAMBDA_S",
        "reversible": "USE_REVERSIBLE",
        "use_checkpoint": "USE_GRADIENT_CHECKPOINTING",
        "use_autocast": "USE_AUTOCAST",
        "chunk_size": "CHUNK_SIZE",
        "full_attn_logging": "FULL_ATTN_LOGGING",
    }

    @classmethod
    def get_model_config(cls) -> Dict[str, Any]:
        """Get optimized model configuration."""
        return {kwarg: getattr(cls, attr) for kwarg, attr in cls._MODEL_KWARG_MAP.items()}
| |
|
| |
|
class SimpleWikiTextDataset(torch.utils.data.Dataset):
    """Simplified WikiText dataset for bit-level training.

    Each sample is a fixed-length bit sequence framed for next-bit
    prediction: inputs are ``bits[:-1]`` and labels are ``bits[1:]``.
    """

    def __init__(self, split: str = "train", max_samples: int = 1000, max_length: int = 2048):
        self.max_length = max_length

        logger.info(f"Loading WikiText-103 {split} split (max {max_samples} samples)...")
        raw = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)

        # Keep only substantive passages (>100 chars after stripping), capped at max_samples.
        self.texts = [row['text'] for row in raw if len(row['text'].strip()) > 100][:max_samples]

        logger.info(f"Loaded {len(self.texts)} text samples from {split}")

    def __len__(self) -> int:
        return len(self.texts)

    def _to_sample(self, bits) -> Dict[str, torch.Tensor]:
        # Frame a bit list for next-bit prediction (shifted by one position).
        inputs = torch.tensor(bits[:-1], dtype=torch.long)
        targets = torch.tensor(bits[1:], dtype=torch.long)
        return {
            'input_ids': inputs,
            'labels': targets,
            'attention_mask': torch.ones_like(inputs)
        }

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        try:
            bits = text_to_bits(self.texts[idx])
            # Truncate long sequences; zero-pad short ones to exactly max_length.
            if len(bits) >= self.max_length:
                bits = bits[:self.max_length]
            else:
                bits = bits + [0] * (self.max_length - len(bits))
            return self._to_sample(bits)
        except Exception as e:
            # Best-effort fallback: a deterministic alternating pattern keeps
            # training running even when a single text fails to convert.
            logger.warning(f"Error processing text at index {idx}: {e}")
            return self._to_sample([0, 1] * (self.max_length // 2))
| |
|
| |
|
def create_optimized_model(config: OptimizedConfig) -> nn.Module:
    """Create properly optimized BitTransformerLM model.

    Builds the model from ``config.get_model_config()``, moves it to CUDA
    when available, and wraps it in ``nn.DataParallel`` when the configured
    number of GPUs is present.

    Args:
        config: Hyperparameter container (``OptimizedConfig`` class/instance).

    Returns:
        The model on its target device, possibly wrapped in ``nn.DataParallel``.
    """
    # NOTE: the original log strings were corrupted by a broken emoji encoding
    # (two f-strings were split mid-literal); restored as single-line messages.
    logger.info("Creating optimized BitTransformerLM model...")
    model_config = config.get_model_config()

    logger.info("Model configuration:")
    for k, v in model_config.items():
        logger.info(f"  {k}: {v}")

    model = BitTransformerLM(**model_config)

    # Report trainable-parameter count only.
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Model created: {params:,} parameters ({params/1e6:.1f}M)")

    if torch.cuda.is_available() and torch.cuda.device_count() >= config.NUM_GPUS:
        logger.info(f"Setting up multi-GPU training on {config.NUM_GPUS} GPUs...")

        # Move to the default CUDA device first; DataParallel replicates from it.
        model = model.cuda()

        if config.NUM_GPUS > 1:
            # DataParallel avoids the FSDP initialization issues mentioned in
            # the module docstring, at the cost of per-forward replication.
            model = nn.DataParallel(model, device_ids=list(range(config.NUM_GPUS)))
            logger.info(f"DataParallel setup complete across GPUs: {list(range(config.NUM_GPUS))}")

    else:
        logger.warning("Limited GPU availability - using single GPU or CPU")
        if torch.cuda.is_available():
            model = model.cuda()

    return model
| |
|
| |
|
def train_step(model: nn.Module, batch: Dict[str, torch.Tensor],
               optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler,
               config: OptimizedConfig) -> tuple:
    """Optimized training step with all BitTransformerLM features.

    Runs one forward/backward micro-batch under autocast, adds a flat safety
    penalty when telemetry falls below the configured thresholds, and divides
    the loss for gradient accumulation.  The optimizer is NOT stepped here;
    the caller steps it once every ``config.GRADIENT_ACCUMULATION_STEPS``
    micro-batches.

    Args:
        model: The (possibly DataParallel-wrapped) model.
        batch: Dict with 'input_ids' and 'labels' long tensors.
        optimizer: Unused in this function; kept for caller-side symmetry.
        scaler: AMP gradient scaler used for the backward pass.
        config: Hyperparameter/threshold container.

    Returns:
        Tuple of (un-divided loss value, telemetry dict, safety penalty).
    """

    model.train()
    set_dropout(model, 0.1)  # enable dropout for the training pass

    # non_blocking=True overlaps host-to-device copies with compute
    # (effective because the DataLoader uses pin_memory=True).
    input_ids = batch['input_ids'].cuda(non_blocking=True)
    labels = batch['labels'].cuda(non_blocking=True)

    with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
        outputs = model(input_ids)

        # The model may return (logits, telemetry) or bare logits.
        if isinstance(outputs, tuple):
            logits, telemetry = outputs
        else:
            logits, telemetry = outputs, {}

        # Bit-level next-token loss: a 2-way classification per position.
        loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction='mean')

        # Flat penalty when any safety metric dips below its threshold.
        # NOTE(review): key names ('negentropy', 'lz_complexity', 'symbiosis')
        # and defaults of 1.0 assume the model's telemetry schema - confirm
        # against BitTransformerLM's forward() output.
        safety_penalty = 0.0
        if telemetry:
            negentropy = telemetry.get('negentropy', 1.0)
            lz_complexity = telemetry.get('lz_complexity', 1.0)
            symbiosis = telemetry.get('symbiosis', 1.0)

            if (negentropy < config.NEGENTROPY_THRESHOLD or
                lz_complexity < config.LZ_COMPLEXITY_THRESHOLD or
                symbiosis < config.SYMBIOSIS_THRESHOLD):
                safety_penalty = 0.1
                loss = loss + safety_penalty

    # Divide so accumulated gradients average over the micro-batches.
    loss = loss / config.GRADIENT_ACCUMULATION_STEPS

    scaler.scale(loss).backward()

    # Report the un-divided loss so logs show the true per-batch value.
    return loss.item() * config.GRADIENT_ACCUMULATION_STEPS, telemetry, safety_penalty
| |
|
| |
|
def main():
    """Main training function.

    Builds the 680M-parameter model, WikiText-103 dataloaders, the AdamW +
    OneCycleLR + GradScaler stack, then runs the gradient-accumulated
    training loop with periodic logging, validation, and checkpointing.
    """

    logger.info("๐ OPTIMIZED MASSIVE SCALE BITTRANSFORMERLM TRAINING!")
    logger.info("=" * 60)

    config = OptimizedConfig()

    # Hard requirement: this script only trains on CUDA hardware.
    if not torch.cuda.is_available():
        logger.error("โ CUDA not available!")
        return

    logger.info(f"๐ฅ Hardware: {torch.cuda.device_count()}x GPUs detected")
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        logger.info(f"  GPU {i}: {props.name} ({props.total_memory / 1024**3:.1f}GB)")

    model = create_optimized_model(config)

    # Small dataset slices: 2000 train / 100 validation samples.
    logger.info("๐ Loading datasets...")
    train_dataset = SimpleWikiTextDataset("train", max_samples=2000, max_length=config.MAX_SEQ_LEN)
    val_dataset = SimpleWikiTextDataset("validation", max_samples=100, max_length=config.MAX_SEQ_LEN)

    # pin_memory=True enables the non_blocking device copies in train_step.
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE_PER_GPU,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=1,
        pin_memory=True
    )

    logger.info("โ๏ธ Setting up optimizer...")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
        betas=(0.9, 0.95)
    )

    # OneCycleLR: warm up for WARMUP_STEPS (as a fraction of total), then anneal.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.LEARNING_RATE,
        total_steps=config.MAX_STEPS,
        pct_start=config.WARMUP_STEPS / config.MAX_STEPS,
    )

    scaler = torch.cuda.amp.GradScaler(enabled=config.USE_MIXED_PRECISION)

    logger.info("๐ฏ Starting training...")
    logger.info(f"Target steps: {config.MAX_STEPS}")
    logger.info(f"Effective batch size: {config.TOTAL_BATCH_SIZE * config.GRADIENT_ACCUMULATION_STEPS}")

    step = 0              # counts optimizer steps, not micro-batches
    running_loss = 0.0
    start_time = time.time()

    # 100 epochs is an upper bound; the loop exits via the MAX_STEPS return below.
    for epoch in range(100):
        for batch_idx, batch in enumerate(train_loader):

            # One forward/backward micro-batch; gradients accumulate in-place.
            loss, telemetry, safety_penalty = train_step(
                model, batch, optimizer, scaler, config
            )
            running_loss += loss

            # Optimizer step once every GRADIENT_ACCUMULATION_STEPS micro-batches.
            if (batch_idx + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0:
                # Unscale first so the clip threshold applies to true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()

                step += 1

                # Console metrics every 10 optimizer steps.
                if step % 10 == 0:
                    avg_loss = running_loss / 10
                    elapsed = time.time() - start_time
                    samples_per_sec = (config.TOTAL_BATCH_SIZE * 10) / elapsed
                    memory_used = torch.cuda.max_memory_allocated() / (1024**3)

                    logger.info(
                        f"Step {step:4d} | "
                        f"Loss: {avg_loss:.4f} | "
                        f"K: {telemetry.get('negentropy', 0):.3f} | "
                        f"C: {telemetry.get('lz_complexity', 0):.3f} | "
                        f"S: {telemetry.get('symbiosis', 0):.3f} | "
                        f"LR: {scheduler.get_last_lr()[0]:.2e} | "
                        f"Speed: {samples_per_sec:.1f} samp/s | "
                        f"Mem: {memory_used:.1f}GB"
                        + (f" | Safety: {safety_penalty:.3f}" if safety_penalty > 0 else "")
                    )

                    # Reset the 10-step window for loss and throughput.
                    running_loss = 0.0
                    start_time = time.time()

                # Full validation pass every 100 optimizer steps.
                if step % 100 == 0:
                    model.eval()
                    set_dropout(model, 0.0)  # disable dropout for evaluation
                    val_loss = 0

                    with torch.no_grad():
                        for val_batch in val_loader:
                            val_input_ids = val_batch['input_ids'].cuda()
                            val_labels = val_batch['labels'].cuda()

                            with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
                                val_outputs = model(val_input_ids)
                                if isinstance(val_outputs, tuple):
                                    val_logits, _ = val_outputs
                                else:
                                    val_logits = val_outputs

                                val_loss += F.cross_entropy(
                                    val_logits.view(-1, 2),
                                    val_labels.view(-1)
                                ).item()

                    val_loss /= len(val_loader)
                    logger.info(f"๐ Validation Loss: {val_loss:.4f}")

                # Checkpoint every 500 optimizer steps.
                # NOTE(review): a fresh timestamped directory is created per save,
                # so each checkpoint lands in its own folder - confirm intended.
                if step % 500 == 0:
                    checkpoint_dir = f"/data/checkpoints/massive_simple_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                    os.makedirs(checkpoint_dir, exist_ok=True)

                    torch.save({
                        'step': step,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'config': config.get_model_config(),
                    }, f"{checkpoint_dir}/checkpoint_step_{step:06d}.pt")

                    logger.info(f"๐พ Checkpoint saved: step {step}")

                # Normal termination path: return as soon as MAX_STEPS is reached.
                if step >= config.MAX_STEPS:
                    logger.info("๐ Training completed!")
                    return

        # Safety net: exits the batch loop only; the return above is the
        # intended exit, so this is rarely (if ever) the terminating path.
        if step >= config.MAX_STEPS:
            break
| |
|
| |
|
# Script entry point: run training when executed directly.
if __name__ == "__main__":
    main()