| """ |
| VicAI Training Script |
| Distributed training with FSDP/DDP support. |
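
Launch with torchrun, which sets the RANK, WORLD_SIZE and LOCAL_RANK
environment variables read by setup_distributed(), e.g.:

    torchrun --standalone --nproc_per_node=8 train.py --use-fsdp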
| """ |
|
|
import argparse
import functools
import os
import time
from pathlib import Path
|
|
| import torch |
| import torch.distributed as dist |
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.data import DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
|
|
from model import VicAIModel, VicAIConfig
from tokenizer import ByteLevelBPETokenizer
from dataset import (
    WikipediaDataset,
    TextFileDataset,
    create_sample_corpus,
)
from utils import (
    get_logger,
    load_checkpoint,
    save_checkpoint,
    get_lr_scheduler,
    configure_optimizers,
)
|
|
| def setup_distributed(): |
| """Initialize distributed training.""" |
| if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: |
| rank = int(os.environ['RANK']) |
| world_size = int(os.environ['WORLD_SIZE']) |
| local_rank = int(os.environ.get('LOCAL_RANK', 0)) |
| else: |
| rank = 0 |
| world_size = 1 |
| local_rank = 0 |
| |
| if world_size > 1: |
| dist.init_process_group(backend='nccl', rank=rank, world_size=world_size) |
| torch.cuda.set_device(local_rank) |
| |
| return rank, world_size, local_rank |
|
|
| def cleanup_distributed(): |
| """Cleanup distributed training.""" |
| if dist.is_initialized(): |
| dist.destroy_process_group() |
|
|
| def get_data_loader(dataset, batch_size, world_size, rank, shuffle=True): |
| """Create distributed data loader.""" |
| if world_size > 1: |
| sampler = DistributedSampler( |
| dataset, |
| num_replicas=world_size, |
| rank=rank, |
| shuffle=shuffle, |
| ) |
| else: |
| sampler = None |
| |
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        # DataLoader forbids shuffle together with a sampler; shuffling is
        # handled by DistributedSampler in the multi-GPU case.
        shuffle=(shuffle if sampler is None else False),
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
| |
| return loader, sampler |
|
|
def train_step(model, batch, optimizer, scaler, device, use_amp, grad_clip=0.0):
| """Single training step.""" |
| model.train() |
| |
| input_ids = batch['input_ids'].to(device) |
| labels = batch['labels'].to(device) |
| |
| optimizer.zero_grad() |
| |
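    # Forward pass runs in mixed precision when AMP is enabled.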
| with torch.cuda.amp.autocast(enabled=use_amp): |
| outputs = model(input_ids, targets=labels) |
| loss = outputs['loss'] |
| |
    if use_amp:
        scaler.scale(loss).backward()
        if grad_clip > 0:
            # Unscale before clipping; with FSDP prefer model.clip_grad_norm_.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
| |
| return loss.item() |
|
|
def train(
    model,
    train_loader,
    val_loader,
    optimizer,
    lr_scheduler,
    scaler,
    device,
    args,
    logger,
    start_step=0,
):
| """Main training loop.""" |
    best_val_loss = float('inf')
    step = start_step
| |
| model.train() |
| train_iterator = iter(train_loader) |
| |
| for epoch in range(args.max_epochs): |
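        # DistributedSampler reshuffles only if it is told the epoch number.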
| if hasattr(train_loader.sampler, 'set_epoch'): |
| train_loader.sampler.set_epoch(epoch) |
| |
| epoch_start_time = time.time() |
| |
| while step < args.max_steps: |
| try: |
| batch = next(train_iterator) |
| except StopIteration: |
| train_iterator = iter(train_loader) |
| batch = next(train_iterator) |
| |
            loss = train_step(
                model, batch, optimizer, scaler, device, args.use_amp, args.grad_clip
            )
| lr_scheduler.step() |
| |
| step += 1 |
| |
| if step % args.log_interval == 0 and args.rank == 0: |
| lr = optimizer.param_groups[0]['lr'] |
| logger.info( |
| f"Step {step}/{args.max_steps} | " |
| f"Loss: {loss:.4f} | LR: {lr:.2e}" |
| ) |
| |
            if step % args.eval_interval == 0:
                # evaluate() all-reduces, so every rank sees the same loss;
                # only rank 0 writes the best-model checkpoint.
                val_loss = evaluate(model, val_loader, device, args.use_amp)

                if args.rank == 0:
                    logger.info(f"Validation loss: {val_loss:.4f}")

                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        save_checkpoint(
                            model,
                            optimizer,
                            scaler,
                            step,
                            val_loss,
                            args.output_dir / 'best_model.pt',
                        )
                        logger.info(f"Saved best model with loss {val_loss:.4f}")

                model.train()
| |
| if step % args.save_interval == 0 and args.rank == 0: |
| save_checkpoint( |
| model, |
| optimizer, |
| scaler, |
| step, |
| loss, |
| args.output_dir / f'checkpoint_step_{step}.pt', |
| ) |
| logger.info(f"Saved checkpoint at step {step}") |
| |
        epoch_time = time.time() - epoch_start_time
        if args.rank == 0:
            logger.info(f"Epoch {epoch + 1} completed in {epoch_time:.2f}s")

        if step >= args.max_steps:
            break
| |
| return step |
|
|
| def evaluate(model, data_loader, device, use_amp): |
| """Evaluate model on validation set.""" |
| model.eval() |
| total_loss = 0 |
| num_batches = 0 |
| |
| with torch.no_grad(): |
| for batch in data_loader: |
| input_ids = batch['input_ids'].to(device) |
| labels = batch['labels'].to(device) |
| |
| with torch.cuda.amp.autocast(enabled=use_amp): |
| outputs = model(input_ids, targets=labels) |
| loss = outputs['loss'] |
| |
| total_loss += loss.item() |
| num_batches += 1 |
| |
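            # Cap evaluation at 100 batches to keep it cheap.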
| if num_batches >= 100: |
| break |
| |
    avg_loss = total_loss / max(num_batches, 1)
| |
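    # Average per-rank losses so every rank returns the same value.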
| if dist.is_initialized(): |
| loss_tensor = torch.tensor([avg_loss], device=device) |
| dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG) |
| avg_loss = loss_tensor.item() |
| |
| return avg_loss |
|
|
| def main(): |
| parser = argparse.ArgumentParser(description='Train VicAI') |
| |
| |
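    # Model architecture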
| parser.add_argument('--vocab-size', type=int, default=32000) |
| parser.add_argument('--dim', type=int, default=4096) |
| parser.add_argument('--n-layers', type=int, default=32) |
| parser.add_argument('--n-heads', type=int, default=32) |
| parser.add_argument('--n-kv-heads', type=int, default=8) |
| parser.add_argument('--hidden-dim', type=int, default=14336) |
| |
| |
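    # Optimization and training schedule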
| parser.add_argument('--batch-size', type=int, default=4) |
| parser.add_argument('--max-seq-len', type=int, default=2048) |
| parser.add_argument('--max-steps', type=int, default=100000) |
| parser.add_argument('--max-epochs', type=int, default=10) |
| parser.add_argument('--learning-rate', type=float, default=3e-4) |
| parser.add_argument('--min-lr', type=float, default=3e-5) |
| parser.add_argument('--warmup-steps', type=int, default=2000) |
| parser.add_argument('--weight-decay', type=float, default=0.1) |
| parser.add_argument('--grad-clip', type=float, default=1.0) |
| parser.add_argument('--beta1', type=float, default=0.9) |
| parser.add_argument('--beta2', type=float, default=0.95) |
| |
| |
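    # Data and tokenizer paths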
| parser.add_argument('--train-data', type=str, default='data/train.txt') |
| parser.add_argument('--val-data', type=str, default='data/val.txt') |
| parser.add_argument('--tokenizer-path', type=str, default='tokenizer.pkl') |
| |
| |
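    # Checkpointing, logging and runtime toggles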
| parser.add_argument('--output-dir', type=str, default='checkpoints') |
| parser.add_argument('--resume', type=str, default=None) |
| parser.add_argument('--eval-interval', type=int, default=1000) |
| parser.add_argument('--save-interval', type=int, default=5000) |
| parser.add_argument('--log-interval', type=int, default=100) |
    parser.add_argument('--use-amp', action=argparse.BooleanOptionalAction, default=True)
| parser.add_argument('--use-fsdp', action='store_true', default=False) |
| parser.add_argument('--compile', action='store_true', default=False) |
| |
| args = parser.parse_args() |
| |
| |
| args.rank, args.world_size, args.local_rank = setup_distributed() |
| args.is_distributed = args.world_size > 1 |
| |
| |
| args.output_dir = Path(args.output_dir) |
| if args.rank == 0: |
| args.output_dir.mkdir(parents=True, exist_ok=True) |
| |
| |
| logger = get_logger('vicai_train', args.output_dir / 'train.log' if args.rank == 0 else None) |
| |
| if args.rank == 0: |
| logger.info(f"Starting VicAI training with {args.world_size} GPUs") |
| logger.info(f"Arguments: {args}") |
| |
| |
| device = torch.device(f'cuda:{args.local_rank}' if torch.cuda.is_available() else 'cpu') |
| |
| |
| if os.path.exists(args.tokenizer_path): |
| logger.info(f"Loading tokenizer from {args.tokenizer_path}") |
| tokenizer = ByteLevelBPETokenizer() |
| tokenizer.load(args.tokenizer_path) |
| else: |
| logger.warning(f"Tokenizer not found at {args.tokenizer_path}, creating default") |
| tokenizer = ByteLevelBPETokenizer(vocab_size=args.vocab_size) |
| |
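        # Rank 0 trains a tokenizer on a small sample corpus; other ranks
        # wait at the barrier below, then load the saved file.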
| if args.rank == 0: |
| sample_file = create_sample_corpus(num_articles=100) |
| with open(sample_file, 'r') as f: |
| texts = f.read().split('<|endoftext|>') |
| tokenizer.train([t for t in texts if t.strip()]) |
| tokenizer.save(args.tokenizer_path) |
| |
| if args.is_distributed: |
| dist.barrier() |
| |
| if args.rank != 0: |
| tokenizer.load(args.tokenizer_path) |
| |
| |
| logger.info("Creating model...") |
| config = VicAIConfig( |
| vocab_size=len(tokenizer), |
| dim=args.dim, |
| n_layers=args.n_layers, |
| n_heads=args.n_heads, |
| n_kv_heads=args.n_kv_heads, |
| hidden_dim=args.hidden_dim, |
| max_seq_len=args.max_seq_len, |
| dropout=0.0, |
| ) |
| |
| if args.rank == 0: |
| logger.info(f"Model config: {config.__dict__}") |
| logger.info(f"Model parameters: ~{config.num_parameters / 1e9:.2f}B") |
| |
    model = VicAIModel(config)

    if args.use_fsdp and args.is_distributed:
        # Shard per transformer block; FSDP moves shards to the GPU via
        # device_id. NOTE: assumes the blocks live in `model.layers`.
        wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={type(model.layers[0])},
        )
        model = FSDP(model, auto_wrap_policy=wrap_policy, device_id=device)
    else:
        # DDP does not move the module; put it on the device first.
        model = model.to(device)
        if args.is_distributed:
            model = DDP(model, device_ids=[args.local_rank])
| |
| if args.compile and hasattr(torch, 'compile'): |
| logger.info("Compiling model...") |
| model = torch.compile(model) |
| |
| |
| logger.info("Creating datasets...") |
| |
    if os.path.exists(args.train_data):
        train_dataset = TextFileDataset(args.train_data, tokenizer, args.max_seq_len)
        if os.path.exists(args.val_data):
            val_dataset = TextFileDataset(args.val_data, tokenizer, args.max_seq_len)
        else:
            # No held-out split; validation loss will track the training data.
            val_dataset = train_dataset
    else:
        logger.warning("Training data not found, using Wikipedia streaming dataset")
        train_dataset = WikipediaDataset(tokenizer, max_length=args.max_seq_len)
        val_dataset = WikipediaDataset(tokenizer, max_length=args.max_seq_len)
| |
| train_loader, train_sampler = get_data_loader(train_dataset, args.batch_size, args.world_size, args.rank) |
| val_loader, _ = get_data_loader(val_dataset, args.batch_size, args.world_size, args.rank, shuffle=False) |
| |
| |
| optimizer = configure_optimizers(model, args) |
| |
| |
| lr_scheduler = get_lr_scheduler(optimizer, args) |
| |
| |
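    # GradScaler is a no-op when enabled=False.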
| scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp) |
| |
| |
| start_step = 0 |
| if args.resume: |
| logger.info(f"Resuming from {args.resume}") |
| start_step = load_checkpoint(model, optimizer, scaler, args.resume, device) |
| |
| |
| logger.info("Starting training...") |
| final_step = train( |
| model, |
| train_loader, |
| val_loader, |
| optimizer, |
| lr_scheduler, |
| scaler, |
| device, |
| args, |
        logger,
        start_step=start_step,
    )
| |
| |
| if args.rank == 0: |
| save_checkpoint( |
| model, |
| optimizer, |
| scaler, |
| final_step, |
| 0.0, |
| args.output_dir / 'final_model.pt', |
| ) |
| logger.info("Training completed!") |
| |
| cleanup_distributed() |
|
|
| if __name__ == '__main__': |
| main() |
|
|