"""
|
|
|
A lightweight one-file FSDP SFT Trainer
|
|
|
TODO(zhangchi.usc1992)
|
|
|
- Add calculation of mfu
|
|
|
- Add validation
|
|
|
"""

import os

os.environ["NCCL_DEBUG"] = "WARN"
os.environ["TOKENIZERS_PARALLELISM"] = "true"

import logging
import re
from contextlib import nullcontext

import hydra
import torch
import torch.distributed
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from peft import LoraConfig, TaskType, get_peft_model
from tensordict import TensorDict
from torch import nn, optim
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel

import verl.utils.hdfs_io as hdfs_io
from verl.utils.dataset import SFTDataset
from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.distributed import initialize_global_process_group
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn
from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup
from verl.utils.tracking import Tracking
from verl.utils.ulysses import (
    gather_outpus_and_unpad,
    get_ulysses_sequence_parallel_world_size,
    ulysses_pad_and_slice_inputs,
)
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN"))


def extract_step(path):
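    """Return the integer step parsed from a ``global_step_<n>`` checkpoint path, or None."""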
    match = re.search(r"global_step_(\d+)", path)
    if match:
        return int(match.group(1))
    return None


def convert_to_regular_types(obj):
    """Convert Hydra configs and other special types to regular Python types."""
    from omegaconf import DictConfig, ListConfig

    if isinstance(obj, (ListConfig, DictConfig)):
        return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj)
    elif isinstance(obj, (list, tuple)):
        return [convert_to_regular_types(x) for x in obj]
    elif isinstance(obj, dict):
        return {k: convert_to_regular_types(v) for k, v in obj.items()}
    return obj


class FSDPSFTTrainer:
    def __init__(
        self,
        config,
        device_mesh: DeviceMesh,
        ulysses_device_mesh: DeviceMesh,
        tokenizer,
        train_dataset: Dataset,
        val_dataset: Dataset,
    ):
        self.config = config
        self.device_mesh = device_mesh
        self.ulysses_device_mesh = ulysses_device_mesh
        self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
        self.tokenizer = tokenizer
        if self.config.data.chat_template is not None:
            raise ValueError("Applying a chat template from the config is not supported yet.")

        # Normalize the global batch size by the data-parallel size.
        self._normalize_config_bsz()

        # Sequence-parallel and remove-padding settings (both default to disabled).
        self.config.ulysses_sequence_parallel_size = getattr(self.config, "ulysses_sequence_parallel_size", 1)
        self.use_remove_padding = getattr(self.config, "use_remove_padding", False)
        if self.device_mesh.get_rank() == 0:
            print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}")
            print(f"Using remove padding: {self.use_remove_padding}")

        self._build_dataloader(train_dataset, val_dataset)
        self._build_model_optimizer()

        if self.device_mesh.get_rank() == 0:
            print(self.config)

    def _normalize_config_bsz(self):
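        """Convert the configured global batch size into a per-dp-rank batch size.

        For example, train_batch_size=256 with dp size 8 leaves 32 samples per
        rank per step, which must further divide evenly into micro batches.
        """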
        dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0)
        if self.device_mesh.get_rank() == 0:
            print(f"Normalize batch size by dp {dp_size}")

        assert self.config.data.train_batch_size % dp_size == 0, f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}"

        self.config.data.train_batch_size //= dp_size

        assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0

    def _build_dataloader(self, train_dataset, val_dataset):
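        """Build train/val dataloaders with rank-aware distributed samplers.

        With Ulysses sequence parallelism, all ranks in one SP group must read
        the same batch, so data is sharded along the dp mesh dimension only.
        """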
        config = self.config
        self.train_dataset, self.val_dataset = train_dataset, val_dataset

        if self.config.ulysses_sequence_parallel_size > 1:
            rank = self.ulysses_device_mesh.get_local_rank("dp")
            world_size = self.ulysses_device_mesh.size(0)
            if self.ulysses_device_mesh.get_rank() == 0:
                print(f"Using SP rank {rank} and size {world_size} for data distribution")
                print("Each dp rank gets different data; ranks within the same SP group see the same data")
        else:
            rank = self.device_mesh.get_rank()
            world_size = self.device_mesh.size()
            if self.device_mesh.get_rank() == 0:
                print(f"Using FSDP rank {rank} and size {world_size} for data distribution")

        self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True)
        self.train_dataloader = DataLoader(
            dataset=self.train_dataset,
            batch_size=config.data.train_batch_size,
            sampler=self.train_sampler,
            num_workers=8,
            pin_memory=True,
            drop_last=True,
        )

        self.val_sampler = DistributedSampler(self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True)
        self.val_dataloader = DataLoader(
            dataset=self.val_dataset,
            batch_size=config.data.micro_batch_size_per_gpu,
            sampler=self.val_sampler,
            num_workers=8,
            pin_memory=True,
            drop_last=True,
        )

    def _build_model_optimizer(self):
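        """Load the HF model, apply optional LoRA/Liger/monkey patches, wrap it with FSDP, and build the optimizer and LR scheduler."""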
        # Copy the checkpoint to local disk if it lives on remote storage.
        local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True)

        if self.config.model.get("external_lib", None) is not None:
            # Import an external package that registers custom model classes.
            import importlib

            importlib.import_module(self.config.model.external_lib)

        log_gpu_memory_usage("Before model allocation", logger=logger)

        trust_remote_code = self.config.model.trust_remote_code

        config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)
        if self.config.ulysses_sequence_parallel_size > 1:
            assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled"

        # Initialize weights on meta device where possible to avoid an extra full copy.
        init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh)

        with init_context():
            self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
                local_model_path,
                config=config,
                torch_dtype=torch.float32,
                attn_implementation="flash_attention_2",
                trust_remote_code=trust_remote_code,
            )

            if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1:
                from verl.models.transformers.monkey_patch import apply_monkey_patch

                apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size)

            # Apply the Liger kernel to the model instance if enabled.
            if self.config.model.get("use_liger", False):
                from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance

                _apply_liger_kernel_to_instance(model=self.model)

            if self.config.model.get("lora_rank", 0) > 0:
                self.model.enable_input_require_grads()
                # Convert the model to a PEFT (LoRA) model.
                lora_config = {
                    "task_type": TaskType.CAUSAL_LM,
                    "r": self.config.model.lora_rank,
                    "lora_alpha": self.config.model.lora_alpha,
                    "target_modules": convert_to_regular_types(self.config.model.target_modules),
                    "bias": "none",
                }
                self.model = get_peft_model(self.model, LoraConfig(**lora_config))

        if self.config.model.enable_gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

        log_gpu_memory_usage("After model allocation", logger=logger)
        mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)

        auto_wrap_policy = get_fsdp_wrap_policy(
            self.model,
            config=self.config.model.fsdp_config.wrap_policy,
            is_lora=self.config.model.get("lora_rank", 0) > 0,
        )
        if self.device_mesh.get_rank() == 0:
            print(auto_wrap_policy)

        if not self.config.model.fsdp_config.cpu_offload:
            cpu_offload = None
        else:
            cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)

        self.fsdp_model = FSDP(
            module=self.model,
            auto_wrap_policy=auto_wrap_policy,
            param_init_fn=init_fn,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mixed_precision,
            device_mesh=self.device_mesh,
            sync_module_states=True,
            device_id=torch.cuda.current_device(),
            cpu_offload=cpu_offload,
            use_orig_params=False,
        )

        log_gpu_memory_usage("After FSDP wrapping", logger=logger)

        self.optimizer = optim.AdamW(
            self.fsdp_model.parameters(),
            lr=self.config.optim.lr,
            betas=self.config.optim.betas,
            weight_decay=self.config.optim.weight_decay,
        )

        log_gpu_memory_usage("After initialize optimizer", logger=logger)

        self.steps_per_epoch = len(self.train_dataloader)
        self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs

        if self.device_mesh.get_rank() == 0:
            print(f"Number of steps/epoch {self.steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {self.total_steps}")

        num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio)

        if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine":
            self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps)
        elif self.config.optim.lr_scheduler == "wsd":
            self.lr_scheduler = get_wsd_schedule_with_warmup(optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps)
        else:
            raise ValueError(f"Unknown lr scheduler: {self.config.optim.lr_scheduler}")

    def _compute_loss_and_backward(self, batch, do_backward=True):
        """Compute the token-level loss, optionally with sequence parallelism and padding removal, and run backward unless disabled."""
        use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1

        # Move inputs to GPU and prepare the loss mask; the mask is shifted by one
        # to align with next-token prediction targets.
        input_ids = batch["input_ids"].cuda()
        attention_mask = batch["attention_mask"].cuda()
        position_ids = batch["position_ids"].cuda()
        loss_mask = batch.pop("loss_mask")[:, :-1].reshape(-1).cuda()
        loss_fct = nn.CrossEntropyLoss(reduction="none")

        # Context for sequence parallelism, if enabled.
        context = self.sharding_manager if use_sp else nullcontext()
        with context, torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            if not use_sp:
                # Standard forward pass without sequence parallelism: shift logits
                # and labels by one position for next-token prediction.
                labels = input_ids[:, 1:].contiguous()
                output = self.fsdp_model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False)
                logits = output.logits

                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels.contiguous()
                # Flatten the tokens.
                shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)
                loss = loss * loss_mask.to(loss.device)
            else:
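                # Remove-padding + Ulysses SP path:
                #   1) unpad all sequences in the batch into one packed token stream
                #   2) pad and slice the stream so each SP rank holds an equal chunk
                #   3) forward, compute per-token loss, gather across SP ranks,
                #      then re-pad back to (batch, seqlen) to apply the loss mask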
                batch_size, seqlen = input_ids.shape

                # Remove padding.
                input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # Unpad position_ids to align rotary embeddings.
                position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1)

                # Pad and slice inputs for sequence parallelism.
                input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())

                # Build next-token targets by rolling the packed stream left by one.
                input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)
                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size())
                input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # (total_nnz / sp_size,)

                # Forward pass on this rank's slice; no attention mask is needed
                # for the packed (varlen) representation.
                output = self.fsdp_model(
                    input_ids=input_ids_rmpad_sliced,
                    attention_mask=None,
                    position_ids=position_ids_rmpad_padded,
                    use_cache=False,
                )

                # Compute loss locally on this rank's slice.
                logits_rmpad = output.logits.squeeze(0)  # (total_nnz / sp_size, vocab_size)
                input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device)
                loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled)

                # Gather the loss from all SP ranks and strip the SP padding.
                loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size)

                # Re-pad back to (batch, seqlen) so the shifted loss mask applies.
                full_loss = pad_input(hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen)
                full_loss = full_loss.squeeze(-1)[:, :-1]
                full_loss = full_loss.reshape(-1)
                loss_mask = loss_mask.to(full_loss.device)
                loss = full_loss * loss_mask

            valid_token_this_rank = torch.sum(loss_mask)
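
            # With balance_dp_token, normalize by the global valid-token count so
            # every token counts equally across dp ranks, even when ranks hold
            # different numbers of unmasked tokens.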
            if self.config.data.balance_dp_token:
                torch.distributed.all_reduce(valid_token_this_rank)
                dp_size = self.ulysses_device_mesh.size("dp") if use_sp else torch.distributed.get_world_size()
            else:
                dp_size = 1

            # The dp_size factor compensates for the gradient averaging FSDP
            # performs across ranks, yielding a global token-mean loss.
            loss = torch.sum(loss) / (valid_token_this_rank + 1e-8) * dp_size

            if do_backward:
                loss.backward()
            return loss

    def training_step(self, batch: TensorDict):
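        """Run one optimizer step over ``batch``, split into micro batches for gradient accumulation, and return logging metrics."""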
        self.fsdp_model.train()

        log_gpu_memory_usage("Before optimizer zero_grad", logger=logger)

        self.optimizer.zero_grad()

        log_gpu_memory_usage("After optimizer zero_grad", logger=logger)

        micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu)
        n_micro_batches = len(micro_batches)
        step_loss = 0
        for micro_batch in micro_batches:
            # backward() already ran inside _compute_loss_and_backward, so this
            # division only scales the logged value; gradients accumulate as a
            # sum over micro batches.
            loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches
            step_loss += loss.item()

        grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)

        log_gpu_memory_usage("Before optimizer step", logger=logger)

        # Skip the update if the gradient norm is not finite (e.g. NaN/inf loss).
        if not torch.isfinite(grad_norm):
            print(f"WARN: grad_norm is not finite: {grad_norm}")
            self.optimizer.zero_grad()
        else:
            self.optimizer.step()

        log_gpu_memory_usage("After optimizer step", logger=logger)

        self.lr_scheduler.step()

        lr = self.lr_scheduler.get_last_lr()[0]

        log_gpu_memory_usage("After offload weights", logger=logger)

        step_loss = torch.tensor(step_loss).cuda()
        torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG)
        return {"train/loss": step_loss.detach().item(), "train/lr(1e-3)": lr * 1e3}

    def validation_step(self, batch: TensorDict):
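        """Compute the validation loss for one batch, averaged across ranks."""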
        self.fsdp_model.eval()
        with torch.no_grad():
            loss = self._compute_loss_and_backward(batch, do_backward=False)
            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
        return loss

    def save_checkpoint(self, step):
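        """Gather the full (unsharded) state dict to rank 0, save it in HF format, and optionally mirror it to HDFS."""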
        # Gather the FSDP full state dict, offloaded to CPU and on rank 0 only.
        from torch.distributed.fsdp import FullStateDictConfig, StateDictType

        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
            state_dict = self.fsdp_model.state_dict()

        path = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}")

        # Save the model and tokenizer on rank 0 only.
        if self.device_mesh.get_rank() == 0:
            os.makedirs(path, exist_ok=True)
            self.model.save_pretrained(path, state_dict=state_dict)
            self.tokenizer.save_pretrained(path)
            if self.config.trainer.default_hdfs_dir:
                hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
                hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)
        torch.distributed.barrier()

    def fit(self):
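        """Run the training loop, validating and checkpointing at the end of each epoch and at the final training step."""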
        rank = self.device_mesh.get_rank()

        # Tracking (e.g. wandb) is initialized on rank 0 only.
        if rank == 0:
            tracking = Tracking(
                project_name=self.config.trainer.project_name,
                experiment_name=self.config.trainer.experiment_name,
                default_backend=self.config.trainer.logger,
            )

        global_step = 0

        # The configured step budget, if set, overrides the epoch-based estimate.
        total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

        if self.config.trainer.total_training_steps is not None:
            total_training_steps = self.config.trainer.total_training_steps

        self.total_training_steps = total_training_steps
        print(f"Total training steps: {self.total_training_steps}")

        for epoch in range(self.config.trainer.total_epochs):
            self.train_sampler.set_epoch(epoch=epoch)
            for data in tqdm(
                self.train_dataloader,
                total=self.steps_per_epoch,
                desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}",
            ):
                global_step += 1
                data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda()
                metric = self.training_step(data)
                if rank == 0:
                    tracking.log(data=metric, step=global_step)

                # Early exit: validate and save once the step budget is reached.
                if global_step >= self.total_training_steps:
                    # Perform final validation.
                    val_losses = []
                    for val_data in self.val_dataloader:
                        val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda()
                        val_loss = self.validation_step(val_data)
                        val_losses.append(val_loss)
                    if rank == 0:
                        avg_val_loss = torch.mean(torch.stack(val_losses))
                        metric = {"val/loss": avg_val_loss.detach().item()}
                        tracking.log(data=metric, step=global_step)
                    torch.distributed.barrier()

                    # Save the final checkpoint.
                    self.save_checkpoint(step=global_step)
                    return

            # Validation at the end of each epoch.
            val_losses = []
            for data in self.val_dataloader:
                data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda()
                val_loss = self.validation_step(data)
                val_losses.append(val_loss)
            if rank == 0:
                val_loss = torch.mean(torch.stack(val_losses))
                metric = {"val/loss": val_loss.detach().item()}
                tracking.log(data=metric, step=global_step)
            torch.distributed.barrier()

            # Save a checkpoint at the end of each epoch.
            self.save_checkpoint(step=global_step)


@hydra.main(config_path="config", config_name="sft_trainer", version_base=None)
def main(config):
    local_rank, rank, world_size = initialize_global_process_group()

    # A 1-D mesh over all ranks for FSDP, plus a 2-D (dp, sp) mesh for Ulysses.
    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
    dp_size = world_size // config.ulysses_sequence_parallel_size
    ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp"))

    # Build the tokenizer and datasets before constructing the trainer.
    from verl.utils import hf_tokenizer

    local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
    tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
    train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
    val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)

    trainer = FSDPSFTTrainer(
        config=config,
        device_mesh=device_mesh,
        ulysses_device_mesh=ulysses_device_mesh,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
    )

    trainer.fit()


def create_sft_dataset(data_paths, data_config, tokenizer):
    """Create an SFT dataset, preferring a custom class, then multi-turn, then the default single-turn dataset."""
    # First check whether a custom dataset class is specified.
    if data_config.custom_cls.get("path", None):
        from verl.utils.import_utils import load_extern_type

        dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
    # Then check whether the multi-turn dataset should be used.
    elif data_config.get("multiturn", {}).get("enable", False):
        dataset_cls = MultiTurnSFTDataset
    # Default to the single-turn dataset.
    else:
        dataset_cls = SFTDataset

    dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config)
    return dataset


if __name__ == "__main__":
    main()