| | import os |
| | import argparse |
| | import torch |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader, DistributedSampler |
| | from model.gpt_model import GPTModel |
| | from data.dataset import TextDataset |
| | from data import utils |
| |
|
# Optional dependency: DeepSpeed enables ZeRO-style large-model training.
# When it is not installed the script falls back to plain PyTorch (or DDP),
# so `deepspeed` is set to None as a sentinel checked later in main().
try:
    import deepspeed
except ImportError:
    deepspeed = None


# DistributedDataParallel normally ships with torch; the guard keeps the
# script importable on stripped-down builds, with DDP = None as the sentinel.
try:
    from torch.nn.parallel import DistributedDataParallel as DDP
except ImportError:
    DDP = None
| |
|
def main():
    """Train the OpenGPT model from a YAML/JSON configuration file.

    Command-line arguments:
        --config: path to the configuration file (required).
        --local_rank: local rank for distributed training; when left at -1
            the LOCAL_RANK environment variable is used instead.

    Supports single-process, torch DDP, and DeepSpeed training. DeepSpeed is
    used when the config provides ``training.deepspeed_config`` and the
    package is installed; otherwise DDP is used when WORLD_SIZE > 1.
    """
    parser = argparse.ArgumentParser(description="Train the OpenGPT model.")
    parser.add_argument("--config", type=str, required=True, help="Path to configuration file (YAML/JSON).")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training.")
    args = parser.parse_args()

    config = utils.load_config(args.config)
    model_conf = config.get("model", {})
    train_conf = config.get("training", {})
    data_conf = config.get("data", {})

    # --- Distributed setup --------------------------------------------------
    local_rank = args.local_rank
    if local_rank == -1:
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
    distributed = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    if distributed:
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    def _is_main_process():
        # Only rank 0 logs and writes non-collective checkpoints.
        return not distributed or torch.distributed.get_rank() == 0

    utils.set_seed(train_conf.get("seed", 42))

    # --- Data ----------------------------------------------------------------
    train_dataset = TextDataset(data_conf["train_path"], data_conf["tokenizer_path"],
                                data_conf.get("block_size", 128))
    train_sampler = DistributedSampler(train_dataset) if distributed else None
    # DataLoader shuffles itself only when no sampler owns the ordering.
    train_loader = DataLoader(train_dataset, batch_size=train_conf.get("batch_size", 1),
                              sampler=train_sampler, shuffle=(train_sampler is None))

    # --- Model / optimizer -----------------------------------------------------
    vocab_size = model_conf["vocab_size"]
    model = GPTModel(vocab_size=vocab_size,
                     max_position_embeddings=model_conf.get("max_position_embeddings", 512),
                     n_layers=model_conf.get("n_layers", 12),
                     n_heads=model_conf.get("n_heads", 12),
                     hidden_dim=model_conf.get("embedding_dim", 768),
                     dropout=model_conf.get("dropout", 0.1)).to(device)

    # Optionally warm-start from an existing checkpoint (weights only).
    init_checkpoint = train_conf.get("init_checkpoint", "")
    if init_checkpoint:
        utils.load_checkpoint(model, optimizer=None, filepath=init_checkpoint, device=device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=train_conf.get("learning_rate", 5e-4),
                                  weight_decay=train_conf.get("weight_decay", 0.0))

    # AMP only makes sense on CUDA; the scaler guards fp16 gradient underflow.
    mixed_precision = train_conf.get("mixed_precision", False) and torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler() if mixed_precision else None

    # --- Parallelism wrapper ----------------------------------------------------
    ds_config_path = train_conf.get("deepspeed_config", None)
    use_deepspeed = bool(ds_config_path) and deepspeed is not None
    if use_deepspeed:
        # DeepSpeed takes ownership of the optimizer and the backward/step cycle.
        model, optimizer, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config_path)
    elif distributed and DDP is not None:
        model = DDP(model, device_ids=[local_rank])

    # --- Training loop -----------------------------------------------------------
    epochs = train_conf.get("epochs", 1)
    for epoch in range(epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)  # reshuffle each rank's shard per epoch
        model.train()
        total_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward + loss, under autocast when mixed precision is enabled.
            if mixed_precision:
                with torch.cuda.amp.autocast():
                    outputs = model(inputs)
                    loss = F.cross_entropy(outputs.view(-1, vocab_size), targets.view(-1))
            else:
                outputs = model(inputs)
                loss = F.cross_entropy(outputs.view(-1, vocab_size), targets.view(-1))

            # Backward/step: DeepSpeed drives the whole update itself; otherwise
            # the GradScaler handles fp16 scaling, or a plain optimizer step runs.
            if use_deepspeed:
                model.backward(loss)
                model.step()
            elif mixed_precision:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            else:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item()
            if batch_idx % 100 == 0 and _is_main_process():
                avg_loss = total_loss / (batch_idx + 1)
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}")

        # --- Checkpointing ---------------------------------------------------
        ckpt_dir = train_conf.get("checkpoint_dir", "checkpoints")
        if use_deepspeed:
            # DeepSpeed's save_checkpoint is a COLLECTIVE operation: every rank
            # must call it (previously only rank 0 did, which hangs multi-rank
            # saves). Rank 0 creates the directory first, then all ranks sync.
            if _is_main_process():
                os.makedirs(ckpt_dir, exist_ok=True)
            if distributed:
                torch.distributed.barrier()
            model.save_checkpoint(ckpt_dir, tag=f"epoch-{epoch+1}")
        elif _is_main_process():
            os.makedirs(ckpt_dir, exist_ok=True)
            # Unwrap DDP before saving so the state_dict has no "module."
            # prefix and the checkpoint reloads into a bare GPTModel.
            model_to_save = model.module if hasattr(model, "module") else model
            ckpt_path = os.path.join(ckpt_dir, f"epoch{epoch+1}.pt")
            utils.save_checkpoint(model_to_save, optimizer, ckpt_path)
            print(f"Checkpoint saved: {ckpt_path}")

    # Release NCCL resources cleanly on all ranks.
    if distributed:
        torch.distributed.destroy_process_group()
| |
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
| |
|