"""
Training script for a LLaMA model with tensor parallelism, launched via torchrun.

Example commands:
torchrun --nproc_per_node 4 train.py --tp_size 4 --micro_batch_size 4 --gradient_accumulation_steps 8 --seq_len 128 --max_tokens 40960 --num_proc 16 --run_name tp_naive --use_wandb
torchrun --nproc_per_node 4 train.py --tp_size 4 --micro_batch_size 4 --gradient_accumulation_steps 8 --seq_len 128 --max_tokens 40960 --num_proc 16
"""
import os
import time
import wandb
import datetime
import torch
import torch.nn.functional as F
import torch.distributed as dist
import argparse
from torch.optim import AdamW
from transformers import AutoConfig
# import lovely_tensors as lt
# lt.monkey_patch()
from model import Llama
from dataloader import MicroBatchDataLoader
import process_group_manager as pgm
from process_group_manager import setup_process_group_manager
from utils import set_all_seed, print, to_readable_format
from tensor_parallel import apply_tensor_parallel
def train_step(model, dataloader, device):
    """Run the forward/backward passes for one optimizer step.

    Pulls ``dataloader.grad_acc_steps`` micro-batches from ``dataloader``,
    computes the mean cross-entropy loss of each one (scaled by
    ``1 / grad_acc_steps``) and back-propagates it, so gradients accumulate
    on whatever parameters produced the logits. The optimizer is not touched
    here; the caller is responsible for ``zero_grad()`` and ``step()``.

    Args:
        model: Callable invoked as ``model(input_ids=...)``. Its output must
            be flattenable to ``(batch_size * seq_len, -1)``.
        dataloader: Iterator (used with ``next``) yielding dicts with
            ``"input_ids"`` and ``"target_ids"`` tensors, and exposing a
            ``grad_acc_steps`` attribute. ``input_ids`` must be 2-D.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        float: Sum of the scaled micro-batch losses, i.e. the average of the
        per-micro-batch mean losses.
    """
    grad_acc_steps = dataloader.grad_acc_steps
    acc_loss = 0.0
    for _ in range(grad_acc_steps):
        # get the next batch
        batch = next(dataloader)
        input_ids = batch["input_ids"].to(device)
        target_ids = batch["target_ids"].to(device)

        outputs = model(input_ids=input_ids)

        # compute the loss
        batch_size, seq_len = input_ids.shape
        # reshape (rather than view) so that non-contiguous logits are
        # accepted; for contiguous tensors this is still a zero-copy view.
        logits = outputs.reshape(seq_len * batch_size, -1)
        targets = target_ids.reshape(-1)
        # Scale each micro-batch loss so the accumulated gradient corresponds
        # to the average over all accumulation steps.
        loss = F.cross_entropy(logits, targets, reduction='mean') / grad_acc_steps
        loss.backward()
        acc_loss += loss.item()
    return acc_loss
if __name__ == "__main__":
    # Script entry point: parse CLI options, initialise torch.distributed and
    # the process-group manager, build the (optionally tensor-parallel) model,
    # then train until `--max_tokens` tokens have been consumed.
    parser = argparse.ArgumentParser(description="Training script for LLaMA model")

    # Environment arguments
    parser.add_argument("--omp_num_threads", type=str, default="1")
    parser.add_argument("--tokenizers_parallelism", type=str, default="false")

    # Model arguments
    parser.add_argument("--model_name", type=str, default="HuggingFaceTB/SmolLM-360M-Instruct")
    parser.add_argument("--num_hidden_layers", type=int, default=32)
    parser.add_argument("--num_attention_heads", type=int, default=16)
    parser.add_argument("--num_key_value_heads", type=int, default=4)

    # Dataset arguments
    parser.add_argument("--dataset_name", type=str, default="roneneldan/TinyStories")
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--num_proc", type=int, default=4)

    # Training arguments
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--seq_len", type=int, default=32)
    parser.add_argument("--micro_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    # argparse does not run `type` on non-string defaults, so the default must
    # already be an int (the previous `1e6` silently produced a float).
    parser.add_argument("--max_tokens", type=int, default=1_000_000)

    # Distributed training arguments
    parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size")
    parser.add_argument("--dp_size", type=int, default=1, help="Data Parallel size")
    parser.add_argument("--pp_size", type=int, default=1, help="Pipeline Parallel size")
    parser.add_argument("--pp_engine", type=str, default="afab", choices=["1f1b", "afab"])

    # Logging arguments
    parser.add_argument("--run_name", type=str, default="default_run")
    parser.add_argument("--use_wandb", action="store_true")

    args = parser.parse_args()

    # Set environment variables
    os.environ["OMP_NUM_THREADS"] = args.omp_num_threads
    os.environ["TOKENIZERS_PARALLELISM"] = args.tokenizers_parallelism
    os.environ["DEVICE"] = "cpu"

    # Rank information is read from the environment (set by the launcher,
    # e.g. torchrun as in the example commands at the top of the file).
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # torch.cuda.set_device(local_rank)
    device = torch.device("cpu", local_rank)
    # Pick the backend from the device actually used for the tensors: NCCL
    # only handles CUDA tensors, so choosing it merely because CUDA happens to
    # be available would break collectives on this CPU run.
    backend = "nccl" if device.type == "cuda" else "gloo"
    dtype = torch.bfloat16

    dist.init_process_group(
        rank=global_rank,
        world_size=world_size,
        backend=backend,
        init_method="env://",
        timeout=datetime.timedelta(minutes=2),
    )
    setup_process_group_manager(dp_size=args.dp_size, pp_size=args.pp_size, tp_size=args.tp_size)

    # Only one rank (first TP rank, first DP rank, last PP stage) logs/prints.
    is_wandb_rank = (
        pgm.process_group_manager.tp_rank == 0
        and pgm.process_group_manager.dp_rank == 0
        and pgm.process_group_manager.pp_is_last_stage
    )

    set_all_seed(args.seed)

    if is_wandb_rank and args.use_wandb:
        wandb.init(
            project="picotron_tutorial",
            name=f"{args.run_name}_{pgm.process_group_manager}",
            config={
                "tensor_parallel_size": pgm.process_group_manager.tp_world_size,
                "pipeline_parallel_size": pgm.process_group_manager.pp_world_size,
                "data_parallel_size": pgm.process_group_manager.dp_world_size,
                "model": args.model_name,
                "learning_rate": args.learning_rate,
                "seed": args.seed,
            },
        )

    # Start from the pretrained model's config and override the architecture
    # knobs exposed on the command line.
    model_config = AutoConfig.from_pretrained(args.model_name)
    model_config.num_hidden_layers = args.num_hidden_layers
    model_config.num_attention_heads = args.num_attention_heads
    model_config.num_key_value_heads = args.num_key_value_heads
    model_config.max_position_embeddings = args.seq_len

    model = Llama(config=model_config)

    if pgm.process_group_manager.tp_world_size > 1:
        model = apply_tensor_parallel(model)

    model.to(dtype).to(device)
    model.train()

    dist.barrier()

    optimizer = AdamW(model.parameters(), lr=args.learning_rate)

    dist.barrier()

    # Create dataloader
    dataloader = MicroBatchDataLoader(
        seq_len=args.seq_len,
        micro_batch_size=args.micro_batch_size,
        grad_acc_steps=args.gradient_accumulation_steps,
        dataset_name=args.dataset_name,
        tokenizer_name=args.model_name,
        max_tokens=args.max_tokens,
        num_workers=args.num_workers,
        num_proc=args.num_proc,
    )

    tokens_per_step = dataloader.global_batch_size * args.seq_len

    if pgm.process_group_manager.global_rank == 0:
        print("Tokens per step:", to_readable_format(tokens_per_step), is_print_rank=is_wandb_rank)

    trained_token, step = 0, 0

    dist.barrier()

    # Training loop
    while trained_token < args.max_tokens:
        step_start_time = time.time()

        optimizer.zero_grad()
        loss = train_step(model, dataloader, device)
        optimizer.step()

        step_duration = time.time() - step_start_time
        trained_token += tokens_per_step
        step += 1

        tokens_per_second = tokens_per_step / step_duration
        memory_usage_gb = torch.cuda.memory_reserved() / 1e9

        print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
              f"Global batch size (with seq_len): {to_readable_format(tokens_per_step)}, "
              f"Tokens/s: {to_readable_format(tokens_per_second)}, "
              f"Tokens/s/GPU: {to_readable_format(tokens_per_second / world_size)}, "
              f"Tokens: {to_readable_format(trained_token)}{('/' + to_readable_format(args.max_tokens))}, "
              f"Memory usage: {memory_usage_gb:.2f}GB"
              , is_print_rank=is_wandb_rank
        )

        if is_wandb_rank and args.use_wandb:
            wandb.log({
                "loss": loss,
                "tokens_per_step": tokens_per_step,
                "tokens_per_second": tokens_per_second,
                "memory_usage": memory_usage_gb,
                # Cumulative token count (previously logged tokens_per_step
                # here, which made the curve a constant).
                "trained_tokens": trained_token,
            })

    if is_wandb_rank and args.use_wandb:
        wandb.finish()

    dist.destroy_process_group()