import functools
import logging
import os
from contextlib import nullcontext
from copy import deepcopy
from functools import partial
from itertools import chain

import torch
from codetiming import Timer
from omegaconf import DictConfig, open_dict
from tensordict import NonTensorData, TensorDict
from torch.distributed.device_mesh import init_device_mesh

from verl.checkpoint_engine import CheckpointEngineRegistry
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
from verl.utils import tensordict_utils as tu
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.device import get_device_name, set_expandable_segments
from verl.utils.distributed import initialize_global_process_group_ray
from verl.utils.flops_counter import FlopsCounter
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.metric.utils import Metric
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage
from verl.utils.py_functional import append_to_dict
from verl.utils.tensordict_utils import maybe_fix_3d_position_ids
from verl.utils.torch_functional import allgather_dict_into_dict
from verl.workers.config import ActorConfig, HFModelConfig, MtpConfig, RolloutConfig, TrainingWorkerConfig
from verl.workers.rollout.base import BaseRollout, get_rollout_class
from verl.workers.utils.losses import ppo_loss

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


def _with_routing_replay_flag(enabled: bool):
    """Decorator that sets the "enable_routing_replay" flag on the data TensorDict
    before the wrapped method runs, when routing replay is enabled on the worker."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, data: TensorDict, *args, **kwargs):
            if self.enable_routing_replay:
                tu.assign_non_tensor_data(data, "enable_routing_replay", enabled)
            return func(self, data, *args, **kwargs)

        return wrapper

    return decorator
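

# Illustrative use of the flag decorator (sketch; `my_method` is a placeholder --
# the real call sites are the compute_log_prob/update_actor methods further down):
#
#     @_with_routing_replay_flag(enabled=True)
#     def my_method(self, data: TensorDict) -> TensorDict:
#         # when self.enable_routing_replay is True, `data` now carries the
#         # non-tensor entry "enable_routing_replay" = True
#         ...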


class TrainingWorker(Worker, DistProfilerExtension):
    """
    TrainingWorker provides a Tinker-like API (https://thinkingmachines.ai/tinker/) as a RayWorkerGroup
    to a single controller. For now it only exposes coarser-grained APIs rather than Tinker's exact
    API surface; a closer match can be added in the future.
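
    Example (illustrative sketch; `cfg` is a TrainingWorkerConfig and `my_loss_fn` a loss
    function matching the engine's expected signature -- both are placeholders; calls are
    normally dispatched through a RayWorkerGroup, shown here as direct calls for brevity):

        worker = TrainingWorker(config=cfg)
        worker.reset()                  # initialize the engine
        worker.set_loss_fn(my_loss_fn)
        metrics_td = worker.train_mini_batch(batch_td)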
| """ |

    def __init__(self, config: TrainingWorkerConfig):
        Worker.__init__(self)

        from verl.workers.engine import BaseEngine, EngineRegistry

        initialize_global_process_group_ray(timeout_second=None)

        self.config = config
        self.model_config = self.config.model_config
        self.engine_config = self.config.engine_config
        self.optimizer_config = self.config.optimizer_config
        self.checkpoint_config = self.config.checkpoint_config
        self.device_name = get_device_name()

        if self.engine_config is None:
            assert self.optimizer_config is None
            if self.config.auto_select_engine_optim_fn is None:
                raise ValueError(
                    "engine_config is not provided and auto_select_engine_optim_fn is not set. "
                    "Cannot determine engine backend."
                )

            self.engine_config, self.optimizer_config = self.config.auto_select_engine_optim_fn(
                self.model_config, self.device_name
            )

        # mirror model-level flags into the engine config
        self.engine_config.use_remove_padding = self.model_config.use_remove_padding
        self.engine_config.use_fused_kernels = self.model_config.use_fused_kernels

        # resolve the profiler config and the per-tool config, if any
        self.profiler_config = self.config.profiler_config
        if self.profiler_config is not None:
            self.profiler_tool_config = self.profiler_config.tool_config.get(self.profiler_config.tool, {})
        else:
            self.profiler_tool_config = None

        DistProfilerExtension.__init__(
            self, DistProfiler(rank=self.rank, config=self.profiler_config, tool_config=self.profiler_tool_config)
        )

        self.engine: BaseEngine = EngineRegistry.new(
            model_type=self.config.model_type,
            backend=self.engine_config.strategy,
            model_config=self.model_config,
            engine_config=self.engine_config,
            optimizer_config=self.optimizer_config,
            checkpoint_config=self.checkpoint_config,
        )

        # register the device-mesh dispatch/collect info used by the nd dispatch functions
        self._register_dispatch_collect_info(
            mesh_name="train",
            dp_rank=self.engine.get_data_parallel_rank(),
            is_collect=self.engine.is_mp_src_rank_with_outputs(),
        )

        self.flops_counter = FlopsCounter(self.model_config.hf_config)

        self.loss_fn = None

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def to(self, device, model=True, optimizer=True, grad=True):
        """Manually load/offload the model, optimizer and gradient states between CPU and device."""
        assert device in ["cpu", "device"]

        if device == "device":
            device = get_device_name()

        self.engine.to(device=device, model=model, optimizer=optimizer, grad=grad)
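
    # Illustrative load/offload pattern (sketch; "device" resolves to the accelerator
    # reported by get_device_name()):
    #
    #     worker.to("cpu")       # offload model, optimizer and grads to host memory
    #     worker.to("device")    # bring everything back before the next step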

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def set_loss_fn(self, loss_fn):
        self.loss_fn = loss_fn

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def reset(self):
        """
        Reset the model engine to its initial state. If the engine is not initialized yet,
        initialize it; otherwise reload the checkpoint and reset internal states.
        """
        self.engine.initialize()

    def _postprocess_output(self, output, *, global_token_num, delta_time, forward_only, images_seqlens):
        """Aggregate the engine output of one batch across data-parallel ranks.

        Args:
            output: a dictionary containing loss, model_output and metrics
            global_token_num: per-sequence token counts of the global batch, used for MFU estimation
            delta_time: wall-clock seconds spent in the forward (and backward) pass
            forward_only: whether only a forward pass was run, which scales the MFU estimate
            images_seqlens: optional per-image sequence lengths for multimodal FLOPs estimation

        Returns:
            A TensorDict containing the model outputs, with the aggregated metrics attached
            under the non-tensor key "metrics".
        """
        metrics: dict = output.pop("metrics")

        # sum the per-micro-batch losses, then average across the data-parallel group
        loss = torch.sum(torch.tensor(output.pop("loss"), device=self.device_name))
        dp_group = self.engine.get_data_parallel_group()
        if dp_group is not None:
            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG, group=dp_group)
        loss = loss.item()

        # pop the scalar metrics that should not be gathered per-rank
        grad_norm = metrics.pop("grad_norm", None)
        lr = metrics.pop("lr", None)

        # gather the remaining metrics from every data-parallel rank
        if dp_group is not None:
            final_metrics = allgather_dict_into_dict(data=metrics, group=dp_group)
        else:
            final_metrics = metrics
        final_metrics["loss"] = loss
        if grad_norm is not None:
            final_metrics["grad_norm"] = grad_norm
        if lr is not None:
            final_metrics["lr"] = lr

        # flatten the gathered per-rank MTP losses and average them
        for k, v in final_metrics.items():
            if k.startswith("mtp_losses"):
                flatten_v = [sublist[0] for sublist in v]
                final_metrics[k] = sum(flatten_v) / len(flatten_v)

        # estimate MFU; a forward-only pass does roughly 1/3 of the FLOPs of forward + backward
        if global_token_num is not None:
            estimated_flops, promised_flops = self.flops_counter.estimate_flops(
                global_token_num, delta_time, images_seqlens=images_seqlens
            )
            final_metrics["mfu"] = estimated_flops / promised_flops / torch.distributed.get_world_size()
            if forward_only:
                final_metrics["mfu"] /= 3.0

        model_output = output.pop("model_output", {})

        final_output = tu.get_tensordict(tensor_dict=model_output, non_tensor_dict={"metrics": final_metrics})
        return final_output

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"), blocking=False)
    def train_mini_batch(self, data: TensorDict) -> TensorDict:
        """Split a batch into mini-batches and train on them for one or more epochs.

        Args:
            data: the batch TensorDict; non-tensor keys such as "mini_batch_size" (or
                "num_mini_batch"), "epochs", "seed" and "dataloader_kwargs" control the split.

        Returns:
            A TensorDict with aggregated training metrics on the mp src rank, otherwise None.
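
        Example (illustrative; the key names mirror the `tu.pop` calls below):

            tu.assign_non_tensor(batch, num_mini_batch=4, epochs=1, seed=42)
            metrics_td = worker.train_mini_batch(batch)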
| """ |
        maybe_fix_3d_position_ids(data)
        batch_size_per_dp = data.shape[0]
        disable_auto_offload = tu.pop(data, key="disable_auto_offload", default=False)
        mini_batch_size = tu.pop(data, key="mini_batch_size", default=None)
        num_mini_batch = tu.pop(data, key="num_mini_batch", default=None)
        epochs = tu.pop(data, key="epochs", default=1)
        seed = tu.pop(data, key="seed", default=42)
        dataloader_kwargs = tu.pop(data, key="dataloader_kwargs", default={})

        assert mini_batch_size is not None or num_mini_batch is not None

        if mini_batch_size is None:
            assert batch_size_per_dp % num_mini_batch == 0, f"Got {batch_size_per_dp=} and {num_mini_batch=}"
            mini_batch_size_per_gpu = batch_size_per_dp // num_mini_batch
        else:
            assert mini_batch_size % self.engine.get_data_parallel_size() == 0, (
                f"Got {mini_batch_size=} and {self.engine.get_data_parallel_size()=}"
            )
            mini_batch_size_per_gpu = mini_batch_size // self.engine.get_data_parallel_size()

        # build a mini-batch iterator over the batch for the requested number of epochs
        dataloader = tu.make_iterator(
            data,
            mini_batch_size=mini_batch_size_per_gpu,
            epochs=epochs,
            seed=seed + self.engine.get_data_parallel_rank(),
            dataloader_kwargs=dataloader_kwargs,
        )

        with (
            self.engine.train_mode(disable_auto_offload=disable_auto_offload),
            Timer(name="train_batch", logger=None),
        ):
            output_lst = []
            total_num_iterations = data.shape[0] // mini_batch_size_per_gpu * epochs

            for batch_idx, mini_batch_td in enumerate(dataloader):
                # per-sequence token counts of the local mini-batch
                global_token_num = mini_batch_td["input_ids"].offsets().diff().tolist()
                # gather token counts from all data-parallel ranks for FLOPs/MFU accounting
                global_token_num_output = [None] * self.engine.get_data_parallel_size()
                torch.distributed.all_gather_object(
                    global_token_num_output, global_token_num, self.engine.get_data_parallel_group()
                )
                global_token_num = [x for xs in global_token_num_output for x in xs]
                tu.assign_non_tensor(
                    mini_batch_td,
                    global_token_num=NonTensorData(global_token_num),
                    update_lr_scheduler=batch_idx == total_num_iterations - 1,
                    disable_auto_offload=True,
                )
                actor_output = self.train_batch(mini_batch_td)
                output_lst.append(actor_output)

        if self.engine.is_mp_src_rank_with_outputs():
            actor_output = [tu.get(output, "metrics") for output in output_lst]
            metrics = {}
            for output in actor_output:
                for key, val in output.items():
                    # aggregate list-valued metrics across mini-batches
                    if isinstance(val, list):
                        output[key] = (
                            Metric.aggregate_dp(val)
                            if isinstance(val[0], Metric)
                            else list(chain.from_iterable(val))
                        )
                append_to_dict(metrics, output)

            output = tu.get_tensordict(tensor_dict={}, non_tensor_dict={"metrics": metrics}).cpu()
        else:
            output = None
        return output

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"), blocking=False)
    def train_batch(self, data: TensorDict) -> TensorDict:
        """Run forward + backward + optimizer step on a single batch using the configured loss."""
        assert self.loss_fn is not None, "loss function can't be None when calling train_batch"
        assert not self.engine_config.forward_only, "Can't run `train_batch` when forward_only is set in the engine config."

        global_token_num = tu.get(data, key="global_token_num")
        disable_auto_offload = tu.get(data, key="disable_auto_offload", default=False)
        images_seqlens = tu.get(data, key="images_seqlens", default=None)

        # fall back to engine/model config values for batching keys the caller did not set
        default_keys = dict(
            use_remove_padding=self.model_config.use_remove_padding,
            use_dynamic_bsz=self.engine_config.use_dynamic_bsz,
            max_token_len_per_gpu=self.engine_config.max_token_len_per_gpu,
            micro_batch_size_per_gpu=self.engine_config.micro_batch_size_per_gpu,
            use_fused_kernels=self.engine_config.use_fused_kernels,
        )

        for key, val in default_keys.items():
            if key not in data.keys():
                tu.assign_non_tensor(data, **{key: val})

        with (
            self.engine.train_mode(disable_auto_offload=disable_auto_offload),
            Timer(name="train_batch", logger=None) as timer,
        ):
            output = self.engine.train_batch(data, loss_function=self.loss_fn)

        delta_time = timer.last

        # optionally step the LR scheduler (set for the last mini-batch of an update)
        update_lr_scheduler = tu.get(data, key="update_lr_scheduler", default=False)
        if update_lr_scheduler:
            lr = self.engine.lr_scheduler_step()
        else:
            lr = None

        if self.engine.is_mp_src_rank_with_outputs():
            # model outputs are not returned from training; keep only the metrics
            output.pop("model_output")
            if lr is not None:
                output["metrics"]["lr"] = lr
            final_output = self._postprocess_output(
                output,
                global_token_num=global_token_num,
                delta_time=delta_time,
                forward_only=False,
                images_seqlens=images_seqlens,
            ).cpu()
        else:
            final_output = None

        return final_output

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"), blocking=False)
    def infer_batch(self, data: TensorDict) -> TensorDict:
        """Run a forward-only pass on a single batch, optionally computing the loss."""
        global_token_num = tu.get(data, key="global_token_num")
        compute_loss = tu.get(data, key="compute_loss", default=True)
        disable_auto_offload = tu.get(data, key="disable_auto_offload", default=False)
        no_lora_adapter = tu.pop(data, key="no_lora_adapter", default=False)
        images_seqlens = tu.get(data, key="images_seqlens", default=None)

        # fall back to the inference batching keys from the engine/model config
        default_keys = dict(
            use_remove_padding=self.model_config.use_remove_padding,
            use_dynamic_bsz=self.engine_config.use_dynamic_bsz,
            max_token_len_per_gpu=self.engine_config.infer_max_token_len_per_gpu,
            micro_batch_size_per_gpu=self.engine_config.infer_micro_batch_size_per_gpu,
            use_fused_kernels=self.engine_config.use_fused_kernels,
        )

        for key, val in default_keys.items():
            if key not in data.keys():
                tu.assign_non_tensor(data, **{key: val})

        # compute the loss only when requested; otherwise run a plain forward pass
        loss_function = self.loss_fn if compute_loss else None

        with (
            self.engine.eval_mode(disable_auto_offload=disable_auto_offload),
            Timer(name="eval_batch", logger=None) as timer,
        ):
            # optionally run with the LoRA adapter disabled, e.g. to score with the base model
            adapter_ctx = self.engine.disable_adapter() if no_lora_adapter else nullcontext()
            with adapter_ctx:
                output = self.engine.infer_batch(data, loss_function=loss_function)
        delta_time = timer.last

        if self.engine.is_mp_src_rank_with_outputs():
            final_output = self._postprocess_output(
                output,
                global_token_num=global_token_num,
                delta_time=delta_time,
                forward_only=True,
                images_seqlens=images_seqlens,
            ).cpu()
        else:
            final_output = None

        return final_output

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        return self.engine.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
        return self.engine.load_checkpoint(local_path, hdfs_path, del_local_after_load)
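
    # Typical checkpoint round-trip (sketch; paths and step numbers are placeholders):
    #
    #     worker.save_checkpoint("/tmp/ckpts/global_step_10", global_step=10, max_ckpt_to_keep=3)
    #     worker.load_checkpoint("/tmp/ckpts/global_step_10")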


class ActorRolloutRefWorker(Worker, DistProfilerExtension):
    """Hybrid worker that includes the actor model, rollout, and an optional ref model.
    For a standalone actor or rollout, use ActorWorker or BaseRollout respectively.

    NOTE: ActorRolloutRefWorker no longer supports SPMD mode; it runs in native server mode only.
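
    Example (illustrative; `ppo_config` is a placeholder DictConfig with the actor/rollout/ref
    sections consumed by init_model below):

        worker = ActorRolloutRefWorker(config=ppo_config, role="actor_rollout")
        worker.init_model()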
| """ |

    def __init__(self, config: DictConfig, role: str, **kwargs):
        Worker.__init__(self)
        self.config = config
        self.role = role
        self.actor: TrainingWorker = None
        self.ref: TrainingWorker = None
        self.rollout: BaseRollout = None
        assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]
        self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
        self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
        self._is_ref = self.role in ["ref", "actor_rollout_ref"]

        if self._is_actor:
            omega_profiler_config = config.actor.get("profiler", {})
        elif self._is_rollout:
            # rollout-only workers take their profiler settings from the rollout config
            omega_profiler_config = config.rollout.get("profiler", {})
        else:
            omega_profiler_config = config.ref.get("profiler", {})

        profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
        if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
            tool_config = omega_conf_to_dataclass(
                omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
            )
        else:
            tool_config = None

        # routing replay requires the megatron backend with router replay enabled
        self.enable_routing_replay = (
            self.config.actor.strategy == "megatron" and self.config.actor.megatron.router_replay.mode != "disabled"
        )

        DistProfilerExtension.__init__(
            self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)
        )

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def set_loss_fn(self, loss_fn):
        self.actor.set_loss_fn(loss_fn=loss_fn)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def to(self, device, model=True, optimizer=True, grad=True):
        """Manually load/offload the actor's model, optimizer and gradient states."""
        self.actor.to(device=device, model=model, optimizer=optimizer, grad=grad)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model)

        # build the reference policy worker, if requested
        if "ref" in self.role:
            # map the ref log-prob batching options onto the ActorConfig field names
            with open_dict(self.config.ref):
                self.config.ref.ppo_mini_batch_size = self.config.actor.ppo_mini_batch_size
                self.config.ref.ppo_micro_batch_size = self.config.ref.pop("log_prob_micro_batch_size", None)
                self.config.ref.ppo_micro_batch_size_per_gpu = self.config.ref.pop(
                    "log_prob_micro_batch_size_per_gpu", None
                )
                self.config.ref.use_dynamic_bsz = self.config.ref.pop("log_prob_use_dynamic_bsz", False)
                self.config.ref.ppo_max_token_len_per_gpu = self.config.ref.pop("log_prob_max_token_len_per_gpu", None)
            ref_config: ActorConfig = omega_conf_to_dataclass(self.config.ref)

            # the ref model shares the model config but runs with MTP disabled
            ref_config.model_config = deepcopy(model_config)
            ref_config.model_config.mtp = MtpConfig(enable=False)

            ref_training_config = TrainingWorkerConfig(
                model_type="language_model",
                model_config=ref_config.model_config,
                engine_config=ref_config.engine,
                optimizer_config=ref_config.optim,
                checkpoint_config=ref_config.checkpoint,
            )

            # the ref worker only runs inference; set the inference batching knobs
            ref_training_config.engine_config.use_dynamic_bsz = self.config.ref.use_dynamic_bsz
            ref_training_config.engine_config.infer_max_token_len_per_gpu = self.config.ref.ppo_max_token_len_per_gpu
            ref_training_config.engine_config.infer_micro_batch_size_per_gpu = (
                self.config.ref.ppo_micro_batch_size_per_gpu
            )
            ref_training_config.engine_config.use_remove_padding = model_config.use_remove_padding

            self.ref = TrainingWorker(config=ref_training_config)
            self.ref.reset()
            self.set_dispatch_collect(mesh_name="ref", **self.ref.get_dispatch_collect())

        # build the actor worker, if requested
        if "actor" in self.role:
            actor_config: ActorConfig = omega_conf_to_dataclass(self.config.actor)
            actor_config.model_config = model_config
            actor_training_config = TrainingWorkerConfig(
                model_type="language_model",
                model_config=actor_config.model_config,
                engine_config=actor_config.engine,
                optimizer_config=actor_config.optim,
                checkpoint_config=actor_config.checkpoint,
            )

            assert self.config.actor.use_dynamic_bsz == self.config.rollout.log_prob_use_dynamic_bsz

            # training uses the PPO batching knobs; inference (log-prob) uses the rollout ones
            actor_training_config.engine_config.use_dynamic_bsz = self.config.actor.use_dynamic_bsz
            actor_training_config.engine_config.infer_max_token_len_per_gpu = (
                self.config.rollout.log_prob_max_token_len_per_gpu
            )
            actor_training_config.engine_config.infer_micro_batch_size_per_gpu = (
                self.config.rollout.log_prob_micro_batch_size_per_gpu
            )
            actor_training_config.engine_config.max_token_len_per_gpu = self.config.actor.ppo_max_token_len_per_gpu
            actor_training_config.engine_config.micro_batch_size_per_gpu = (
                self.config.actor.ppo_micro_batch_size_per_gpu
            )
            actor_training_config.engine_config.use_remove_padding = model_config.use_remove_padding

            if self.config.actor.use_dynamic_bsz:
                assert self.config.rollout.log_prob_max_token_len_per_gpu is not None
                assert self.config.actor.ppo_max_token_len_per_gpu is not None
            else:
                assert self.config.rollout.log_prob_micro_batch_size_per_gpu is not None
                assert self.config.actor.ppo_micro_batch_size_per_gpu is not None

            self.loss_fn = partial(ppo_loss, config=actor_config)
            self.actor = TrainingWorker(config=actor_training_config)
            self.actor.reset()
            self.actor.set_loss_fn(self.loss_fn)
            self.set_dispatch_collect(mesh_name="actor", **self.actor.get_dispatch_collect())

        # build the rollout engine, if requested
        if "rollout" in self.role:
            rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)

            # derive the rollout device mesh: each replica spans
            # tensor-parallel x data-parallel x pipeline-parallel ranks
            infer_tp = rollout_config.tensor_model_parallel_size * rollout_config.data_parallel_size
            infer_pp = rollout_config.pipeline_model_parallel_size
            infer_world_size = infer_tp * infer_pp
            assert self.world_size % infer_world_size == 0, (
                f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
            )
            dp = self.world_size // infer_world_size
            rollout_device_mesh = init_device_mesh(
                get_device_name(), mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
            )
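
            # Worked sizing example (assumed numbers): with world_size=8,
            # tensor_model_parallel_size=2, data_parallel_size=2 and
            # pipeline_model_parallel_size=1, we get infer_tp=4, infer_world_size=4
            # and dp=2, i.e. a (2, 4, 1) mesh hosting two rollout replicas.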

            rollout_cls: type[BaseRollout] = get_rollout_class(rollout_config.name, rollout_config.mode)
            self.rollout = rollout_cls(
                config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh
            )

            # a "dummy" load_format means the base weights still need an initial sync
            self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format
            self.layered_summon = self.config.rollout.get("layered_summon", False)
            self.peft_merge: bool = model_config.lora.get("merge", False)

            # the checkpoint engine streams trainer weights to the rollout;
            # only the actor role owns weights to send
            if "actor" in self.role:
                checkpoint_engine_config = omega_conf_to_dataclass(self.config.rollout.checkpoint_engine)
                backend = checkpoint_engine_config.backend
                bucket_size = checkpoint_engine_config.update_weights_bucket_megabytes << 20  # MiB -> bytes
                engine_kwargs = checkpoint_engine_config.engine_kwargs.get(backend, {})
                self.checkpoint_engine = CheckpointEngineRegistry.new(
                    backend, is_master=(torch.distributed.get_rank() == 0), bucket_size=bucket_size, **engine_kwargs
                )

        # release any memory cached during model construction
        aggressive_empty_cache(force_sync=True)

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="ref"))
    @DistProfiler.annotate(color="olive", role="ref_compute_log_prob")
    @_with_routing_replay_flag(enabled=False)
    def compute_ref_log_prob(self, data: TensorDict) -> TensorDict:
        output = self.ref.infer_batch(data=data)
        return output.cpu() if output is not None else None

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
    @DistProfiler.annotate(color="blue", role="actor_compute_log_prob")
    @_with_routing_replay_flag(enabled=True)
    def compute_log_prob(self, data: TensorDict) -> TensorDict:
        output = self.actor.infer_batch(data)
        return output.cpu() if output is not None else None

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
    @DistProfiler.annotate(color="red", role="actor_update")
    @_with_routing_replay_flag(enabled=True)
    def update_actor(self, data: TensorDict) -> TensorDict:
        output = self.actor.train_mini_batch(data=data)
        return output.cpu() if output is not None else None

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
        assert "actor" in self.role, "load_checkpoint only supports the actor role"
        self.actor.load_checkpoint(local_path, hdfs_path, del_local_after_load)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        assert "actor" in self.role, "save_checkpoint only supports the actor role"
        self.actor.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    async def update_weights(self, global_steps: int = None):
        """Update weights from the trainer to the rollout.

        1. For sync training with colocated trainer and rollout, update the rollout directly
           from the model engine:
           - before update_weights: the rollout should be in sleep mode.
           - after update_weights: the rollout should be awake.
        2. For async training with a disaggregated trainer and rollout, only send weights
           through the checkpoint engine.
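
        Illustrative colocated flow from the driver side (sketch; assumes the rollout
        was put to sleep after the previous generation step):

            await worker.update_weights(global_steps=step)
            # the rollout now holds fresh weights and a resumed kv_cache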
| """ |

        # async/disaggregated path: stream the weights through the checkpoint engine
        if self.config.rollout.checkpoint_engine.backend != "naive":
            per_tensor_param, _ = self.actor.engine.get_per_tensor_param()
            await self.checkpoint_engine.send_weights(per_tensor_param)
            return

        set_expandable_segments(False)
        log_gpu_memory_usage("Before resume weights", logger=logger)

        # wake up the rollout weight buffers before writing into them
        if self.config.rollout.free_cache_engine:
            await self.rollout.resume(tags=["weights"])
            log_gpu_memory_usage("After resume weights", logger=logger)

        # stream the trainer parameters tensor-by-tensor into the rollout
        per_tensor_param, peft_config = self.actor.engine.get_per_tensor_param(
            layered_summon=self.layered_summon, base_sync_done=True
        )

        await self.rollout.update_weights(
            per_tensor_param, peft_config=peft_config, base_sync_done=True, global_steps=global_steps
        )

        do_lora_base_sync = False
        if not self.peft_merge and peft_config is not None:
            # for unmerged LoRA, keep the base weights resident across sleep/wake
            # (sleep level 1) so the full base model only needs to be synced once
            self.rollout.sleep_level = 1

            do_lora_base_sync = (not self.base_sync_done) or (
                self.rollout.sleep_level != 1 and self.config.rollout.free_cache_engine
            )

        if do_lora_base_sync:
            per_tensor_base_params, _ = self.actor.engine.get_per_tensor_param(
                layered_summon=self.layered_summon, base_sync_done=False
            )
            await self.rollout.update_weights(per_tensor_base_params, peft_config=peft_config, base_sync_done=False)

        log_gpu_memory_usage("After update_weights", logger=logger)

        # the rollout now holds the weights; offload the trainer model to CPU
        self.actor.engine.to("cpu", model=True, optimizer=False, grad=False)
        aggressive_empty_cache(force_sync=True)

        # restore the kv_cache once the weights are in place
        if self.config.rollout.free_cache_engine:
            await self.rollout.resume(tags=["kv_cache"])
            log_gpu_memory_usage("After resume kv_cache", logger=logger)

        self.base_sync_done = True
        set_expandable_segments(True)

    @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)
    def execute_checkpoint_engine(self, method: str, *args, **kwargs):
        """Execute a checkpoint engine method by name.

        Args:
            method (str): Checkpoint engine method name.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        return getattr(self.checkpoint_engine, method)(*args, **kwargs)
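
    # Illustrative passthrough call (the method name "save" and its kwargs are
    # hypothetical; use whatever the configured checkpoint engine backend exposes):
    #
    #     worker.execute_checkpoint_engine("save", step=global_steps)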