""" VicAI Training Script Distributed training with FSDP/DDP support. """ import argparse import os import time from contextlib import nullcontext from pathlib import Path import torch import torch.distributed as dist from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from model import VicAIModel, VicAIConfig, create_vicai_5b from tokenizer import ByteLevelBPETokenizer, BPETokenizer from dataset import ( WikipediaDataset, TextFileDataset, MixedDataset, create_sample_corpus, ) from utils import ( get_logger, load_checkpoint, save_checkpoint, get_lr_scheduler, estimate_loss, configure_optimizers, ) def setup_distributed(): """Initialize distributed training.""" if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: rank = int(os.environ['RANK']) world_size = int(os.environ['WORLD_SIZE']) local_rank = int(os.environ.get('LOCAL_RANK', 0)) else: rank = 0 world_size = 1 local_rank = 0 if world_size > 1: dist.init_process_group(backend='nccl', rank=rank, world_size=world_size) torch.cuda.set_device(local_rank) return rank, world_size, local_rank def cleanup_distributed(): """Cleanup distributed training.""" if dist.is_initialized(): dist.destroy_process_group() def get_data_loader(dataset, batch_size, world_size, rank, shuffle=True): """Create distributed data loader.""" if world_size > 1: sampler = DistributedSampler( dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, ) else: sampler = None loader = DataLoader( dataset, batch_size=batch_size, sampler=sampler, num_workers=4, pin_memory=True, drop_last=True, ) return loader, sampler def train_step(model, batch, optimizer, scaler, device, use_amp): """Single training step.""" model.train() input_ids = batch['input_ids'].to(device) labels = batch['labels'].to(device) optimizer.zero_grad() with torch.cuda.amp.autocast(enabled=use_amp): outputs = model(input_ids, targets=labels) loss = outputs['loss'] if use_amp: scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() else: loss.backward() optimizer.step() return loss.item() def train( model, train_loader, val_loader, optimizer, lr_scheduler, scaler, device, args, logger, ): """Main training loop.""" best_val_loss = float('inf') step = 0 model.train() train_iterator = iter(train_loader) for epoch in range(args.max_epochs): if hasattr(train_loader.sampler, 'set_epoch'): train_loader.sampler.set_epoch(epoch) epoch_start_time = time.time() while step < args.max_steps: try: batch = next(train_iterator) except StopIteration: train_iterator = iter(train_loader) batch = next(train_iterator) # Training step loss = train_step(model, batch, optimizer, scaler, device, args.use_amp) lr_scheduler.step() step += 1 # Logging if step % args.log_interval == 0 and args.rank == 0: lr = optimizer.param_groups[0]['lr'] logger.info( f"Step {step}/{args.max_steps} | " f"Loss: {loss:.4f} | LR: {lr:.2e}" ) # Evaluation if step % args.eval_interval == 0: val_loss = evaluate(model, val_loader, device, args.use_amp) if args.rank == 0: logger.info(f"Validation loss: {val_loss:.4f}") # Save best model if val_loss < best_val_loss: best_val_loss = val_loss save_checkpoint( model, optimizer, scaler, step, val_loss, args.output_dir / 'best_model.pt', ) logger.info(f"Saved best model with loss {val_loss:.4f}") model.train() # Regular checkpointing if step % 
def get_data_loader(dataset, batch_size, world_size, rank, shuffle=True):
    """Create a (possibly distributed) data loader."""
    if world_size > 1:
        sampler = DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
        )
    else:
        sampler = None

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    return loader, sampler


def train_step(model, batch, optimizer, scaler, device, use_amp, grad_clip=0.0):
    """Single training step."""
    model.train()

    input_ids = batch['input_ids'].to(device)
    labels = batch['labels'].to(device)

    optimizer.zero_grad()

    with torch.cuda.amp.autocast(enabled=use_amp):
        outputs = model(input_ids, targets=labels)
        loss = outputs['loss']

    if use_amp:
        scaler.scale(loss).backward()
        if grad_clip > 0:
            # Gradients must be unscaled before clipping their norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

    return loss.item()


def train(
    model,
    train_loader,
    val_loader,
    optimizer,
    lr_scheduler,
    scaler,
    device,
    args,
    logger,
    start_step=0,
):
    """Main training loop."""
    best_val_loss = float('inf')
    step = start_step

    model.train()
    train_iterator = iter(train_loader)

    for epoch in range(args.max_epochs):
        if hasattr(train_loader.sampler, 'set_epoch'):
            train_loader.sampler.set_epoch(epoch)

        epoch_start_time = time.time()

        while step < args.max_steps:
            try:
                batch = next(train_iterator)
            except StopIteration:
                train_iterator = iter(train_loader)
                batch = next(train_iterator)

            # Training step
            loss = train_step(
                model, batch, optimizer, scaler, device,
                args.use_amp, args.grad_clip,
            )
            lr_scheduler.step()
            step += 1

            # Logging
            if step % args.log_interval == 0 and args.rank == 0:
                lr = optimizer.param_groups[0]['lr']
                logger.info(
                    f"Step {step}/{args.max_steps} | "
                    f"Loss: {loss:.4f} | LR: {lr:.2e}"
                )

            # Evaluation (runs on all ranks so the all_reduce in evaluate() matches up)
            if step % args.eval_interval == 0:
                val_loss = evaluate(model, val_loader, device, args.use_amp)
                if args.rank == 0:
                    logger.info(f"Validation loss: {val_loss:.4f}")

                    # Save best model
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        save_checkpoint(
                            model, optimizer, scaler, step, val_loss,
                            args.output_dir / 'best_model.pt',
                        )
                        logger.info(f"Saved best model with loss {val_loss:.4f}")
                model.train()

            # Regular checkpointing
            if step % args.save_interval == 0 and args.rank == 0:
                save_checkpoint(
                    model, optimizer, scaler, step, loss,
                    args.output_dir / f'checkpoint_step_{step}.pt',
                )
                logger.info(f"Saved checkpoint at step {step}")

        # Training is step-driven; once max_steps is reached, exit the epoch loop too.
        if step >= args.max_steps:
            break

        epoch_time = time.time() - epoch_start_time
        if args.rank == 0:
            logger.info(f"Epoch {epoch + 1} completed in {epoch_time:.2f}s")

    return step


def evaluate(model, data_loader, device, use_amp):
    """Evaluate model on validation set."""
    model.eval()
    total_loss = 0
    num_batches = 0

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(input_ids, targets=labels)
                loss = outputs['loss']

            total_loss += loss.item()
            num_batches += 1

            if num_batches >= 100:  # Limit eval batches
                break

    # Average across all processes
    avg_loss = total_loss / num_batches
    if dist.is_initialized():
        loss_tensor = torch.tensor([avg_loss], device=device)
        dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
        avg_loss = loss_tensor.item()

    return avg_loss
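
# transformer_auto_wrap_policy is imported above, but the FSDP wrapping in
# main() shards the model as one flat unit. For multi-billion-parameter
# models, wrapping each transformer block separately usually shards memory
# much more effectively. The helper below is a sketch: `block_cls` stands in
# for the model's per-layer block class (e.g. a hypothetical `VicAIBlock`
# in model.py), and the result would be passed as
#   FSDP(model, auto_wrap_policy=get_fsdp_wrap_policy(VicAIBlock), device_id=device)
import functools


def get_fsdp_wrap_policy(block_cls):
    """Auto-wrap policy that makes each transformer block its own FSDP unit."""
    return functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={block_cls},
    )
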
def main():
    parser = argparse.ArgumentParser(description='Train VicAI')

    # Model args
    parser.add_argument('--vocab-size', type=int, default=32000)
    parser.add_argument('--dim', type=int, default=4096)
    parser.add_argument('--n-layers', type=int, default=32)
    parser.add_argument('--n-heads', type=int, default=32)
    parser.add_argument('--n-kv-heads', type=int, default=8)
    parser.add_argument('--hidden-dim', type=int, default=14336)

    # Training args
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--max-seq-len', type=int, default=2048)
    parser.add_argument('--max-steps', type=int, default=100000)
    parser.add_argument('--max-epochs', type=int, default=10)
    parser.add_argument('--learning-rate', type=float, default=3e-4)
    parser.add_argument('--min-lr', type=float, default=3e-5)
    parser.add_argument('--warmup-steps', type=int, default=2000)
    parser.add_argument('--weight-decay', type=float, default=0.1)
    parser.add_argument('--grad-clip', type=float, default=1.0)
    parser.add_argument('--beta1', type=float, default=0.9)
    parser.add_argument('--beta2', type=float, default=0.95)

    # Data args
    parser.add_argument('--train-data', type=str, default='data/train.txt')
    parser.add_argument('--val-data', type=str, default='data/val.txt')
    parser.add_argument('--tokenizer-path', type=str, default='tokenizer.pkl')

    # System args
    parser.add_argument('--output-dir', type=str, default='checkpoints')
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--eval-interval', type=int, default=1000)
    parser.add_argument('--save-interval', type=int, default=5000)
    parser.add_argument('--log-interval', type=int, default=100)
    # BooleanOptionalAction provides --use-amp / --no-use-amp; a bare
    # store_true flag with default=True could never be switched off.
    parser.add_argument('--use-amp', action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument('--use-fsdp', action='store_true', default=False)
    parser.add_argument('--compile', action='store_true', default=False)

    args = parser.parse_args()

    # Setup
    args.rank, args.world_size, args.local_rank = setup_distributed()
    args.is_distributed = args.world_size > 1

    # Create output directory
    args.output_dir = Path(args.output_dir)
    if args.rank == 0:
        args.output_dir.mkdir(parents=True, exist_ok=True)

    # Logger (only rank 0 writes a log file)
    logger = get_logger('vicai_train', args.output_dir / 'train.log' if args.rank == 0 else None)

    if args.rank == 0:
        logger.info(f"Starting VicAI training with {args.world_size} GPUs")
        logger.info(f"Arguments: {args}")

    # Device
    device = torch.device(f'cuda:{args.local_rank}' if torch.cuda.is_available() else 'cpu')

    # Load tokenizer
    if os.path.exists(args.tokenizer_path):
        logger.info(f"Loading tokenizer from {args.tokenizer_path}")
        tokenizer = ByteLevelBPETokenizer()
        tokenizer.load(args.tokenizer_path)
    else:
        logger.warning(f"Tokenizer not found at {args.tokenizer_path}, creating default")
        tokenizer = ByteLevelBPETokenizer(vocab_size=args.vocab_size)

        # Train on sample data (rank 0 only), then share via the saved file
        if args.rank == 0:
            sample_file = create_sample_corpus(num_articles=100)
            with open(sample_file, 'r') as f:
                texts = f.read().split('<|endoftext|>')
            tokenizer.train([t for t in texts if t.strip()])
            tokenizer.save(args.tokenizer_path)

        if args.is_distributed:
            dist.barrier()
            if args.rank != 0:
                tokenizer.load(args.tokenizer_path)

    # Create model
    logger.info("Creating model...")
    config = VicAIConfig(
        vocab_size=len(tokenizer),
        dim=args.dim,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        n_kv_heads=args.n_kv_heads,
        hidden_dim=args.hidden_dim,
        max_seq_len=args.max_seq_len,
        dropout=0.0,
    )

    if args.rank == 0:
        logger.info(f"Model config: {config.__dict__}")
        logger.info(f"Model parameters: ~{config.num_parameters / 1e9:.2f}B")

    model = VicAIModel(config)

    if args.use_fsdp and args.is_distributed:
        model = FSDP(model, device_id=device)
    elif args.is_distributed:
        # DDP expects the module to already live on the target device.
        model = model.to(device)
        model = DDP(model, device_ids=[args.local_rank])
    else:
        model = model.to(device)

    if args.compile and hasattr(torch, 'compile'):
        logger.info("Compiling model...")
        model = torch.compile(model)

    # Create datasets
    logger.info("Creating datasets...")
    if os.path.exists(args.train_data):
        train_dataset = TextFileDataset(args.train_data, tokenizer, args.max_seq_len)
        if os.path.exists(args.val_data):
            val_dataset = TextFileDataset(args.val_data, tokenizer, args.max_seq_len)
        else:
            val_dataset = train_dataset
    else:
        logger.warning("Training data not found, using Wikipedia streaming dataset")
        train_dataset = WikipediaDataset(tokenizer, max_length=args.max_seq_len)
        val_dataset = WikipediaDataset(tokenizer, max_length=args.max_seq_len)

    train_loader, train_sampler = get_data_loader(train_dataset, args.batch_size, args.world_size, args.rank)
    val_loader, _ = get_data_loader(val_dataset, args.batch_size, args.world_size, args.rank, shuffle=False)

    # Optimizer
    optimizer = configure_optimizers(model, args)

    # Learning rate scheduler
    lr_scheduler = get_lr_scheduler(optimizer, args)

    # Gradient scaler for AMP
    scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)

    # Resume from checkpoint
    start_step = 0
    if args.resume:
        logger.info(f"Resuming from {args.resume}")
        start_step = load_checkpoint(model, optimizer, scaler, args.resume, device)

    # Training
    logger.info("Starting training...")
    final_step = train(
        model,
        train_loader,
        val_loader,
        optimizer,
        lr_scheduler,
        scaler,
        device,
        args,
        logger,
        start_step=start_step,
    )

    # Save final model
    if args.rank == 0:
        save_checkpoint(
            model, optimizer, scaler, final_step, 0.0,
            args.output_dir / 'final_model.pt',
        )
        logger.info("Training completed!")

    cleanup_distributed()


if __name__ == '__main__':
    main()
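
# A couple of hedged invocation sketches (flag names match the argparse
# definitions above; paths are illustrative):
#
#   # Tiny single-GPU smoke run:
#   python train.py --dim 256 --n-layers 4 --n-heads 4 --n-kv-heads 2 \
#       --hidden-dim 1024 --max-steps 200 --batch-size 2
#
#   # Resume from the best checkpoint of a previous run:
#   python train.py --resume checkpoints/best_model.pt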