# OpenGPT / train.py
# Provenance: VolodymyrPugachov's Hugging Face upload ("Upload 17 files",
# commit 6810eb1, verified).
import os
import argparse
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler
from model.gpt_model import GPTModel
from data.dataset import TextDataset
from data import utils
# DeepSpeed is an optional dependency: when it is not installed, training
# falls back to plain PyTorch (optionally wrapped in DDP below).
try:
    import deepspeed
except ImportError:
    deepspeed = None
# Guard mirrors the DeepSpeed one; on any modern torch build this import
# succeeds, so DDP is normally available.
try:
    from torch.nn.parallel import DistributedDataParallel as DDP
except ImportError:
    DDP = None
def _compute_loss(model, inputs, targets, vocab_size):
    """Run a forward pass and return flattened cross-entropy over the vocabulary.

    Logits are expected to be (batch, seq, vocab); both logits and targets are
    flattened so F.cross_entropy sees (N, vocab) vs (N,).
    """
    logits = model(inputs)
    return F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))


def _apply_update(loss, model, optimizer, use_deepspeed, scaler):
    """Backward pass + parameter update for whichever backend is active.

    DeepSpeed owns gradient scaling/accumulation and the optimizer step; the
    GradScaler path is used for native AMP; otherwise a plain fp32 step.
    """
    if use_deepspeed:
        model.backward(loss)
        model.step()
    elif scaler is not None:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    else:
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


def main():
    """Train the OpenGPT model from a YAML/JSON configuration file.

    Supports single-process, DDP (torchrun env vars), and DeepSpeed training.
    CLI: --config (required), --local_rank (falls back to $LOCAL_RANK).
    """
    parser = argparse.ArgumentParser(description="Train the OpenGPT model.")
    parser.add_argument("--config", type=str, required=True, help="Path to configuration file (YAML/JSON).")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training.")
    args = parser.parse_args()

    # Load configuration sections; each falls back to an empty dict.
    config = utils.load_config(args.config)
    model_conf = config.get("model", {})
    train_conf = config.get("training", {})
    data_conf = config.get("data", {})

    # Distributed setup: prefer the CLI flag, fall back to the launcher env var.
    local_rank = args.local_rank
    if local_rank == -1:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
    distributed = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    if distributed:
        # nccl requires GPUs; fall back to gloo for CPU-only distributed runs.
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend, init_method="env://")
    if torch.cuda.is_available():
        device = torch.device("cuda", local_rank)
        torch.cuda.set_device(local_rank)
    else:
        device = torch.device("cpu")

    # Seed RNGs for reproducibility.
    utils.set_seed(train_conf.get("seed", 42))

    # Dataset / dataloader. DistributedSampler handles per-rank sharding and
    # shuffling, so DataLoader-level shuffle is only used without it.
    train_dataset = TextDataset(data_conf["train_path"], data_conf["tokenizer_path"],
                                data_conf.get("block_size", 128))
    train_sampler = DistributedSampler(train_dataset) if distributed else None
    train_loader = DataLoader(train_dataset, batch_size=train_conf.get("batch_size", 1),
                              sampler=train_sampler, shuffle=(train_sampler is None))

    # Initialize model on the target device.
    vocab_size = model_conf["vocab_size"]
    model = GPTModel(vocab_size=vocab_size,
                     max_position_embeddings=model_conf.get("max_position_embeddings", 512),
                     n_layers=model_conf.get("n_layers", 12),
                     n_heads=model_conf.get("n_heads", 12),
                     hidden_dim=model_conf.get("embedding_dim", 768),
                     dropout=model_conf.get("dropout", 0.1)).to(device)

    # Optionally warm-start from a pre-trained checkpoint to fine-tune.
    init_checkpoint = train_conf.get("init_checkpoint", "")
    if init_checkpoint:
        utils.load_checkpoint(model, optimizer=None, filepath=init_checkpoint, device=device)

    # Create optimizer (AdamW by default).
    optimizer = torch.optim.AdamW(model.parameters(), lr=train_conf.get("learning_rate", 5e-4),
                                  weight_decay=train_conf.get("weight_decay", 0.0))

    # Mixed precision only makes sense on CUDA.
    mixed_precision = train_conf.get("mixed_precision", False) and torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler() if mixed_precision else None

    # DeepSpeed takes over optimization when a config is provided and installed.
    use_deepspeed = False
    ds_config_path = train_conf.get("deepspeed_config", None)
    if ds_config_path and deepspeed is not None:
        use_deepspeed = True
        model, optimizer, _, _ = deepspeed.initialize(model=model, optimizer=optimizer,
                                                      config=ds_config_path)

    # Plain DDP only when DeepSpeed is not managing the model.
    if distributed and not use_deepspeed and DDP is not None:
        # device_ids is only valid for CUDA modules; omit it on CPU.
        ddp_kwargs = {"device_ids": [local_rank]} if torch.cuda.is_available() else {}
        model = DDP(model, **ddp_kwargs)

    # Rank 0 (or the single process) handles logging and checkpointing.
    is_main = (not distributed) or torch.distributed.get_rank() == 0

    epochs = train_conf.get("epochs", 1)
    for epoch in range(epochs):
        if distributed and train_sampler is not None:
            # Re-seed the sampler so each epoch sees a different shard order.
            train_sampler.set_epoch(epoch)
        model.train()
        total_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            if mixed_precision:
                with torch.cuda.amp.autocast():
                    loss = _compute_loss(model, inputs, targets, vocab_size)
            else:
                loss = _compute_loss(model, inputs, targets, vocab_size)
            _apply_update(loss, model, optimizer, use_deepspeed, scaler)
            total_loss += loss.item()
            # Periodic progress report, on the main process only.
            if batch_idx % 100 == 0 and is_main:
                avg_loss = total_loss / (batch_idx + 1)
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}")
        # Save an end-of-epoch checkpoint on the main process only.
        if is_main:
            ckpt_dir = train_conf.get("checkpoint_dir", "checkpoints")
            os.makedirs(ckpt_dir, exist_ok=True)
            if use_deepspeed:
                # DeepSpeed writes its own sharded checkpoint layout.
                model.save_checkpoint(ckpt_dir, tag=f"epoch-{epoch+1}")
            else:
                ckpt_path = os.path.join(ckpt_dir, f"epoch{epoch+1}.pt")
                utils.save_checkpoint(model, optimizer, ckpt_path)
                print(f"Checkpoint saved: {ckpt_path}")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()