"""
Run a single LLaMA training step on random token data, launched with torchrun.

Usage:
torchrun --nproc_per_node 2 train.py --tp_size 2 --run_name process_group_manager --use_wandb
torchrun --nproc_per_node 2 train.py --tp_size 2
"""
import os
import wandb
import datetime
import torch
import torch.nn.functional as F
import torch.distributed as dist
import argparse
from torch.optim import AdamW
from transformers import AutoConfig
from model import Llama
import process_group_manager as pgm
from process_group_manager import setup_process_group_manager
from utils import set_all_seed, print
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Command-line arguments
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(description="Training script for LLaMA model")
    # Environment arguments
    parser.add_argument("--omp_num_threads", type=str, default="1")
    parser.add_argument("--tokenizers_parallelism", type=str, default="false")
    # Model arguments
    parser.add_argument("--model_name", type=str, default="HuggingFaceTB/SmolLM-360M-Instruct")
    parser.add_argument("--num_hidden_layers", type=int, default=32)
    parser.add_argument("--num_attention_heads", type=int, default=16)
    parser.add_argument("--num_key_value_heads", type=int, default=4)
    # Training arguments
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--seq_len", type=int, default=32)
    parser.add_argument("--micro_batch_size", type=int, default=1)
    # Distributed training arguments
    parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size")
    parser.add_argument("--dp_size", type=int, default=1, help="Data Parallel size")
    parser.add_argument("--pp_size", type=int, default=1, help="Pipeline Parallel size")
    parser.add_argument("--pp_engine", type=str, default="afab", choices=["1f1b", "afab"])
    parser.add_argument(
        "--use_cpu",
        action="store_true",
        help="Train on CPU with the gloo backend (also used automatically when CUDA is unavailable)",
    )
    # Logging arguments
    parser.add_argument("--run_name", type=str, default="default_run")
    parser.add_argument("--use_wandb", action="store_true")
    args = parser.parse_args()

    # ------------------------------------------------------------------
    # Environment and distributed setup
    # ------------------------------------------------------------------
    # NOTE(review): torch is already imported at this point, so the OpenMP
    # runtime may have read OMP_NUM_THREADS before this assignment; consider
    # torch.set_num_threads(...) if this flag must take effect. Confirm.
    os.environ["OMP_NUM_THREADS"] = args.omp_num_threads
    os.environ["TOKENIZERS_PARALLELISM"] = args.tokenizers_parallelism

    # Rank and world size come from the environment variables torchrun sets.
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Device and backend must be chosen together: NCCL only communicates CUDA
    # tensors, while gloo is the CPU backend. Mixing them (NCCL with a CPU
    # model) breaks every collective.
    use_cpu = args.use_cpu or not torch.cuda.is_available()
    if use_cpu:
        os.environ["DEVICE"] = "cpu"
        backend = "gloo"
        # A CPU device has no per-rank index, so local_rank is not used here.
        device = torch.device("cpu")
    else:
        os.environ["DEVICE"] = "cuda"
        backend = "nccl"
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    dtype = torch.bfloat16

    dist.init_process_group(
        rank=global_rank,
        world_size=world_size,
        backend=backend,
        init_method="env://",
        timeout=datetime.timedelta(minutes=2),
    )
    setup_process_group_manager(dp_size=args.dp_size, pp_size=args.pp_size, tp_size=args.tp_size)

    # Exactly one rank logs to wandb: TP rank 0, DP rank 0, on the last pipeline stage.
    is_wandb_rank = (
        pgm.process_group_manager.tp_rank == 0
        and pgm.process_group_manager.dp_rank == 0
        and pgm.process_group_manager.pp_is_last_stage
    )
    set_all_seed(args.seed)

    if is_wandb_rank and args.use_wandb:
        wandb.init(
            project="picotron_tutorial",
            name=f"{args.run_name}_{pgm.process_group_manager}",
            config={
                "tensor_parallel_size": pgm.process_group_manager.tp_world_size,
                "pipeline_parallel_size": pgm.process_group_manager.pp_world_size,
                "data_parallel_size": pgm.process_group_manager.dp_world_size,
                "model": args.model_name,
                "learning_rate": args.learning_rate,
                "seed": args.seed,
            },
        )

    # ------------------------------------------------------------------
    # Model and optimizer
    # ------------------------------------------------------------------
    # Start from the pretrained config, overriding the size-related fields from the CLI.
    model_config = AutoConfig.from_pretrained(args.model_name)
    model_config.num_hidden_layers = args.num_hidden_layers
    model_config.num_attention_heads = args.num_attention_heads
    model_config.num_key_value_heads = args.num_key_value_heads
    model_config.max_position_embeddings = args.seq_len

    model = Llama(config=model_config)
    model.to(dtype).to(device)
    model.train()

    dist.barrier()
    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
    dist.barrier()

    # ------------------------------------------------------------------
    # One training step on random token ids
    # ------------------------------------------------------------------
    # Both tensors have shape (micro_batch_size, seq_len).
    input_ids = torch.randint(0, model_config.vocab_size, (args.micro_batch_size, args.seq_len), device=device)
    target_ids = torch.randint(0, model_config.vocab_size, (args.micro_batch_size, args.seq_len), device=device)

    optimizer.zero_grad()
    outputs = model(input_ids=input_ids)

    # Flatten for cross_entropy: targets -> (batch * seq_len,), logits ->
    # (N, vocab_size). Assumes the model returns logits whose last dimension is
    # vocab_size. reshape (unlike view) also tolerates non-contiguous logits.
    target_ids = target_ids.reshape(-1)
    outputs = outputs.reshape(-1, model_config.vocab_size)
    loss = F.cross_entropy(outputs, target_ids)

    loss.backward()
    optimizer.step()

    # Read the scalar once and reuse it for every log sink.
    loss_value = loss.item()
    print(f"Loss: {loss_value:.4f}", is_print_rank=(global_rank == 0))
    print(f"[rank {pgm.process_group_manager.global_rank}], Loss: {loss_value:.4f}")

    if is_wandb_rank and args.use_wandb:
        wandb.log({"loss": loss_value})
        wandb.finish()

    dist.destroy_process_group()