# tinyvic/train.py
"""
VicAI Training Script
Distributed training with FSDP/DDP support.
"""
import argparse
import os
import time
from contextlib import nullcontext
from pathlib import Path
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from model import VicAIModel, VicAIConfig, create_vicai_5b
from tokenizer import ByteLevelBPETokenizer, BPETokenizer
from dataset import (
WikipediaDataset,
TextFileDataset,
MixedDataset,
create_sample_corpus,
)
from utils import (
get_logger,
load_checkpoint,
save_checkpoint,
get_lr_scheduler,
estimate_loss,
configure_optimizers,
)
def setup_distributed():
    """Initialize torch.distributed from launcher environment variables.

    Reads RANK / WORLD_SIZE / LOCAL_RANK (as set by torchrun); falls back
    to a single-process configuration when they are absent.  The NCCL
    process group is created and the CUDA device pinned only when more
    than one process is participating.

    Returns:
        Tuple of ints ``(rank, world_size, local_rank)``.
    """
    env = os.environ
    if 'RANK' in env and 'WORLD_SIZE' in env:
        rank, world_size = int(env['RANK']), int(env['WORLD_SIZE'])
        local_rank = int(env.get('LOCAL_RANK', 0))
    else:
        # Plain (non-launcher) invocation: single process.
        rank, world_size, local_rank = 0, 1, 0
    if world_size > 1:
        dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
        torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank
def cleanup_distributed():
    """Tear down the process group, a no-op when none was initialized."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def get_data_loader(dataset, batch_size, world_size, rank, shuffle=True, num_workers=4):
    """Create a (possibly distributed) data loader.

    Args:
        dataset: Map-style dataset to iterate.
        batch_size: Per-process batch size.
        world_size: Total number of processes; > 1 enables sharding
            via DistributedSampler.
        rank: This process's rank (used by the sampler).
        shuffle: Whether to shuffle (handled by the sampler when
            distributed, by the DataLoader otherwise).
        num_workers: Worker processes for data loading (default 4,
            previously hard-coded).

    Returns:
        Tuple ``(loader, sampler)``; ``sampler`` is None when not
        distributed.
    """
    if world_size > 1:
        sampler = DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
        )
    else:
        sampler = None
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        # BUG FIX: previously `shuffle` was dropped in the single-process
        # path (DataLoader defaults to shuffle=False), so data was never
        # shuffled without a DistributedSampler.  DataLoader forbids
        # shuffle=True together with a sampler, hence the guard.
        shuffle=shuffle if sampler is None else False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    return loader, sampler
def train_step(model, batch, optimizer, scaler, device, use_amp, grad_clip=None):
    """Run one optimization step on a single batch.

    Args:
        model: Module returning a dict with a 'loss' entry when called as
            ``model(input_ids, targets=labels)``.
        batch: Dict with 'input_ids' and 'labels' tensors.
        optimizer: Optimizer over the model's parameters.
        scaler: GradScaler (only used when ``use_amp`` is True).
        device: Target device for the batch tensors.
        use_amp: Enable CUDA automatic mixed precision.
        grad_clip: Optional max gradient norm.  BUG FIX: the script parses
            ``--grad-clip`` but clipping was never applied; pass
            ``args.grad_clip`` here to enable it (default None preserves
            the old behavior).

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    input_ids = batch['input_ids'].to(device)
    labels = batch['labels'].to(device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        outputs = model(input_ids, targets=labels)
        loss = outputs['loss']
    if use_amp:
        scaler.scale(loss).backward()
        if grad_clip is not None:
            # Gradients must be unscaled before clipping by norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
    return loss.item()
def train(
    model,
    train_loader,
    val_loader,
    optimizer,
    lr_scheduler,
    scaler,
    device,
    args,
    logger,
):
    """Main training loop.

    Runs up to ``args.max_steps`` optimizer steps across at most
    ``args.max_epochs`` passes over ``train_loader``, logging, evaluating
    and checkpointing at the configured intervals.

    Returns:
        The number of steps actually performed.
    """
    best_val_loss = float('inf')
    step = 0
    model.train()
    train_iterator = iter(train_loader)
    for epoch in range(args.max_epochs):
        # Reshuffle shards per epoch when using a DistributedSampler.
        if hasattr(train_loader.sampler, 'set_epoch'):
            train_loader.sampler.set_epoch(epoch)
        epoch_start_time = time.time()
        while step < args.max_steps:
            try:
                batch = next(train_iterator)
            except StopIteration:
                # Loader exhausted before the step budget: restart it.
                train_iterator = iter(train_loader)
                batch = next(train_iterator)
            # Training step
            loss = train_step(model, batch, optimizer, scaler, device, args.use_amp)
            lr_scheduler.step()
            step += 1
            # Logging
            if step % args.log_interval == 0 and args.rank == 0:
                lr = optimizer.param_groups[0]['lr']
                logger.info(
                    f"Step {step}/{args.max_steps} | "
                    f"Loss: {loss:.4f} | LR: {lr:.2e}"
                )
            # Evaluation — every rank participates so the all_reduce
            # inside evaluate() does not deadlock.
            if step % args.eval_interval == 0:
                val_loss = evaluate(model, val_loader, device, args.use_amp)
                if args.rank == 0:
                    logger.info(f"Validation loss: {val_loss:.4f}")
                    # Save best model
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        save_checkpoint(
                            model,
                            optimizer,
                            scaler,
                            step,
                            val_loss,
                            args.output_dir / 'best_model.pt',
                        )
                        logger.info(f"Saved best model with loss {val_loss:.4f}")
                # BUG FIX: evaluate() switches the model to eval mode;
                # restore train mode on EVERY rank (previously only rank 0
                # did, leaving other ranks with dropout/eval semantics).
                model.train()
            # Regular checkpointing
            if step % args.save_interval == 0 and args.rank == 0:
                save_checkpoint(
                    model,
                    optimizer,
                    scaler,
                    step,
                    loss,
                    args.output_dir / f'checkpoint_step_{step}.pt',
                )
                logger.info(f"Saved checkpoint at step {step}")
            if step >= args.max_steps:
                break
        epoch_time = time.time() - epoch_start_time
        if args.rank == 0:
            logger.info(f"Epoch {epoch + 1} completed in {epoch_time:.2f}s")
        # BUG FIX: stop the epoch loop once the step budget is exhausted;
        # previously the remaining epochs spun through empty iterations,
        # each logging a bogus "Epoch completed" line.
        if step >= args.max_steps:
            break
    return step
def evaluate(model, data_loader, device, use_amp, max_batches=100):
    """Evaluate the model on a validation loader.

    Args:
        model: Module returning a dict with a 'loss' entry.
        data_loader: Iterable of batches with 'input_ids' and 'labels'.
        device: Device to move batch tensors to.
        use_amp: Enable CUDA autocast during the forward pass.
        max_batches: Cap on evaluated batches (default 100, previously
            hard-coded).

    Returns:
        Mean loss over the evaluated batches, averaged across ranks when
        distributed; ``inf`` when the loader yields no batches.

    Note: leaves the model in eval mode — the caller must restore
    ``model.train()``.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(input_ids, targets=labels)
                loss = outputs['loss']
            total_loss += loss.item()
            num_batches += 1
            if num_batches >= max_batches:  # Limit eval batches
                break
    # BUG FIX: an empty loader previously raised ZeroDivisionError.
    if num_batches == 0:
        return float('inf')
    avg_loss = total_loss / num_batches
    # Average across all processes
    if dist.is_initialized():
        loss_tensor = torch.tensor([avg_loss], device=device)
        dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
        avg_loss = loss_tensor.item()
    return avg_loss
def main():
    """Parse CLI arguments and run the full training pipeline.

    Pipeline: distributed setup -> tokenizer load (or train from a sample
    corpus) -> model construction (optionally FSDP/DDP-wrapped and
    torch.compile'd) -> dataset/loader creation -> optimizer, scheduler
    and AMP scaler setup -> optional checkpoint resume -> training ->
    final checkpoint save on rank 0.
    """
    parser = argparse.ArgumentParser(description='Train VicAI')
    # Model args
    parser.add_argument('--vocab-size', type=int, default=32000)
    parser.add_argument('--dim', type=int, default=4096)
    parser.add_argument('--n-layers', type=int, default=32)
    parser.add_argument('--n-heads', type=int, default=32)
    parser.add_argument('--n-kv-heads', type=int, default=8)
    parser.add_argument('--hidden-dim', type=int, default=14336)
    # Training args
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--max-seq-len', type=int, default=2048)
    parser.add_argument('--max-steps', type=int, default=100000)
    parser.add_argument('--max-epochs', type=int, default=10)
    parser.add_argument('--learning-rate', type=float, default=3e-4)
    parser.add_argument('--min-lr', type=float, default=3e-5)
    parser.add_argument('--warmup-steps', type=int, default=2000)
    parser.add_argument('--weight-decay', type=float, default=0.1)
    # NOTE(review): --grad-clip is parsed but never applied anywhere in
    # this script (train_step does no clipping) — confirm intent.
    parser.add_argument('--grad-clip', type=float, default=1.0)
    parser.add_argument('--beta1', type=float, default=0.9)
    parser.add_argument('--beta2', type=float, default=0.95)
    # Data args
    parser.add_argument('--train-data', type=str, default='data/train.txt')
    parser.add_argument('--val-data', type=str, default='data/val.txt')
    parser.add_argument('--tokenizer-path', type=str, default='tokenizer.pkl')
    # System args
    parser.add_argument('--output-dir', type=str, default='checkpoints')
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--eval-interval', type=int, default=1000)
    parser.add_argument('--save-interval', type=int, default=5000)
    parser.add_argument('--log-interval', type=int, default=100)
    # NOTE(review): action='store_true' with default=True means AMP can
    # never be disabled from the command line.
    parser.add_argument('--use-amp', action='store_true', default=True)
    parser.add_argument('--use-fsdp', action='store_true', default=False)
    parser.add_argument('--compile', action='store_true', default=False)
    args = parser.parse_args()
    # Setup — rank/world/local_rank are stashed on args so downstream
    # helpers only need the one namespace object.
    args.rank, args.world_size, args.local_rank = setup_distributed()
    args.is_distributed = args.world_size > 1
    # Create output directory (rank 0 only, to avoid a mkdir race).
    args.output_dir = Path(args.output_dir)
    if args.rank == 0:
        args.output_dir.mkdir(parents=True, exist_ok=True)
    # Logger — only rank 0 gets a file handler.
    logger = get_logger('vicai_train', args.output_dir / 'train.log' if args.rank == 0 else None)
    if args.rank == 0:
        logger.info(f"Starting VicAI training with {args.world_size} GPUs")
        logger.info(f"Arguments: {args}")
    # Device
    device = torch.device(f'cuda:{args.local_rank}' if torch.cuda.is_available() else 'cpu')
    # Load tokenizer
    if os.path.exists(args.tokenizer_path):
        logger.info(f"Loading tokenizer from {args.tokenizer_path}")
        tokenizer = ByteLevelBPETokenizer()
        tokenizer.load(args.tokenizer_path)
    else:
        logger.warning(f"Tokenizer not found at {args.tokenizer_path}, creating default")
        tokenizer = ByteLevelBPETokenizer(vocab_size=args.vocab_size)
        # Train on sample data
        if args.rank == 0:
            sample_file = create_sample_corpus(num_articles=100)
            with open(sample_file, 'r') as f:
                texts = f.read().split('<|endoftext|>')
            tokenizer.train([t for t in texts if t.strip()])
            tokenizer.save(args.tokenizer_path)
        if args.is_distributed:
            # Wait for rank 0 to finish training/saving, then other
            # ranks load the identical tokenizer from disk.
            dist.barrier()
            if args.rank != 0:
                tokenizer.load(args.tokenizer_path)
    # Create model — vocab size comes from the actual tokenizer, which
    # may differ from --vocab-size when a pre-trained tokenizer is loaded.
    logger.info("Creating model...")
    config = VicAIConfig(
        vocab_size=len(tokenizer),
        dim=args.dim,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        n_kv_heads=args.n_kv_heads,
        hidden_dim=args.hidden_dim,
        max_seq_len=args.max_seq_len,
        dropout=0.0,
    )
    if args.rank == 0:
        logger.info(f"Model config: {config.__dict__}")
        logger.info(f"Model parameters: ~{config.num_parameters / 1e9:.2f}B")
    model = VicAIModel(config)
    # Parallelism wrapper: FSDP > DDP > single-device, in that order.
    if args.use_fsdp and args.is_distributed:
        model = FSDP(model, device_id=device)
    elif args.is_distributed:
        model = DDP(model, device_ids=[args.local_rank])
    else:
        model = model.to(device)
    if args.compile and hasattr(torch, 'compile'):
        logger.info("Compiling model...")
        model = torch.compile(model)
    # Create datasets — fall back to streaming Wikipedia when the local
    # text files are missing; val falls back to train if absent.
    logger.info("Creating datasets...")
    if os.path.exists(args.train_data):
        train_dataset = TextFileDataset(args.train_data, tokenizer, args.max_seq_len)
        val_dataset = TextFileDataset(args.val_data, tokenizer, args.max_seq_len) if os.path.exists(args.val_data) else train_dataset
    else:
        logger.warning("Training data not found, using Wikipedia streaming dataset")
        train_dataset = WikipediaDataset(tokenizer, max_length=args.max_seq_len)
        val_dataset = WikipediaDataset(tokenizer, max_length=args.max_seq_len)
    train_loader, train_sampler = get_data_loader(train_dataset, args.batch_size, args.world_size, args.rank)
    val_loader, _ = get_data_loader(val_dataset, args.batch_size, args.world_size, args.rank, shuffle=False)
    # Optimizer
    optimizer = configure_optimizers(model, args)
    # Learning rate scheduler
    lr_scheduler = get_lr_scheduler(optimizer, args)
    # Gradient scaler for AMP
    scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
    # Resume from checkpoint
    start_step = 0
    if args.resume:
        logger.info(f"Resuming from {args.resume}")
        # NOTE(review): start_step is recovered here but never passed to
        # train(), so resuming restores weights/optimizer state yet
        # restarts step counting (and the LR schedule) from 0 — confirm.
        start_step = load_checkpoint(model, optimizer, scaler, args.resume, device)
    # Training
    logger.info("Starting training...")
    final_step = train(
        model,
        train_loader,
        val_loader,
        optimizer,
        lr_scheduler,
        scaler,
        device,
        args,
        logger,
    )
    # Save final model
    if args.rank == 0:
        save_checkpoint(
            model,
            optimizer,
            scaler,
            final_step,
            0.0,
            args.output_dir / 'final_model.pt',
        )
        logger.info("Training completed!")
    cleanup_distributed()
# Script entry point: run training when invoked directly.
if __name__ == '__main__':
    main()