import asyncio
from abc import ABC, abstractmethod
from typing import Any, Generator, TypedDict

import ray
import torch

from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.utils.distributed import initialize_global_process_group_ray
from verl.utils.ray_utils import auto_await
from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig
from verl.workers.rollout import BaseRollout, RolloutReplica, get_rollout_class


class TensorMeta(TypedDict):
    """Metadata for one tensor packed into a flat weight bucket."""

    name: str
    shape: torch.Size
    dtype: torch.dtype
    offset: int
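
# Illustrative example (the tensor name and values are hypothetical): a TensorMeta
# records where one tensor lives inside a flat, contiguous weight bucket:
#
#     meta: TensorMeta = {
#         "name": "model.layers.0.self_attn.q_proj.weight",
#         "shape": torch.Size([4096, 4096]),
#         "dtype": torch.bfloat16,
#         "offset": 0,  # byte offset of this tensor within the bucket
#     }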


class CheckpointEngineRegistry:
    """Checkpoint engine registry."""

    _registry: dict[str, type["CheckpointEngine"]] = {}

    @staticmethod
    def register(backend: str):
        """Register a checkpoint engine class under the given backend name.

        Args:
            backend: The backend name of the checkpoint engine.
        """

        def wrapper(cls: type["CheckpointEngine"]):
            CheckpointEngineRegistry._registry[backend] = cls
            return cls

        return wrapper

    @classmethod
    def get(cls, backend: str) -> type["CheckpointEngine"]:
        """Get the checkpoint engine class.

        Args:
            backend: The backend name of the checkpoint engine.

        Returns:
            The checkpoint engine class.
        """
        return cls._registry[backend]

    @classmethod
    def new(cls, backend: str, *args, **kwargs) -> "CheckpointEngine":
        """Create a new checkpoint engine instance.

        Args:
            backend: The backend name of the checkpoint engine.
            *args: Variable length arguments passed to the checkpoint engine constructor.
            **kwargs: Arbitrary keyword arguments passed to the checkpoint engine constructor.

        Returns:
            A new checkpoint engine instance.
        """
        if backend not in cls._registry:
            raise ValueError(f"Checkpoint engine {backend} not registered")
        return cls._registry[backend](*args, **kwargs)
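
# Example usage (an illustrative sketch; the backend name is only for demonstration):
#
#     @CheckpointEngineRegistry.register("nccl")
#     class NCCLCheckpointEngine(CheckpointEngine):
#         ...
#
#     engine = CheckpointEngineRegistry.new("nccl", bucket_size=512 << 20)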


class CheckpointEngine(ABC):
    """CheckpointEngine is an abstraction to transfer weights from the trainer to the rollout.

    In the trainer process:
    >>> trainer = EngineRegistry.new(...)  # FSDP, Megatron, VeOmni, TorchTitan, ...
    >>> engine = CheckpointEngine.new(...)  # NCCLCheckpointEngine, NIXLCheckpointEngine, ...
    >>> await engine.send_weights(trainer.get_per_tensor_param())

    In the rollout process:
    >>> engine = CheckpointEngine.new(...)
    >>> server_adapter = ServerAdapter()
    >>> await server_adapter.update_weights(engine.get_weights())  # update weights via CUDA IPC
    """

    @abstractmethod
    def prepare(self) -> dict[str, Any]:
        """Prepare the checkpoint engine before each send_weights/receive_weights step.

        1. Allocate the weight bucket.
        2. [Optional] Register the weight bucket for RDMA.
        3. Return metadata used to build the communication topology: master ip:port, RDMA registration
           descriptors, etc.

        Returns:
            A dictionary containing this worker's metadata.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def build_topology(
        cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]
    ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]:
        """Build the communication topology between all workers.

        Args:
            trainer_world_size: The world size of the trainer worker group.
            rollout_world_size: The world size of the rollout replica.
            metadata: A list of the metadata returned by `prepare` from all workers.

        Returns:
            A tuple of two dictionaries containing the communication topology for the trainer and
            rollout worker groups. Each dict value should be a list whose length equals the world
            size of the worker group, to be dispatched to `init_process_group`.

            ```
            world_size = rollout.world_size + trainer.world_size
            kwargs = {
                "rank": list(range(world_size)),
                "world_size": [world_size] * world_size,
                "master_metadata": [metadata[0]] * world_size,
            }
            ```
        """
        raise NotImplementedError

    @abstractmethod
    def init_process_group(self, **kwargs):
        """Initialize the process group for the checkpoint engine.

        Args:
            **kwargs: Keyword arguments from `build_topology`.
        """
        raise NotImplementedError

    @abstractmethod
    def finalize(self):
        """Finalize the checkpoint engine after each send_weights/receive_weights step.

        1. Free the weight bucket.
        2. [Optional] Deregister the weight bucket for RDMA.
        3. [Optional] Destroy the process group.
        """
        raise NotImplementedError

    @abstractmethod
    async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
        """Send the weights of the model.

        Args:
            weights: A generator that yields the name of each weight tensor and the tensor itself.
        """
        raise NotImplementedError

    @abstractmethod
    async def receive_weights(self) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Receive the weights of the model.

        Yields:
            A tuple of the name of each weight tensor and the tensor itself.
        """
        raise NotImplementedError
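
# Per-step driver sketch for a distributed backend (illustrative; names like `all_engines`,
# `t_ws`, and `r_ws` are hypothetical, and in practice each call below runs on its own Ray
# worker, coordinated by CheckpointEngineManager further down):
#
#     metadata = [engine.prepare() for engine in all_engines]
#     trainer_kwargs, rollout_kwargs = EngineCls.build_topology(t_ws, r_ws, metadata)
#     engine.init_process_group(**{k: v[rank] for k, v in side_kwargs.items()})  # per rank
#     await trainer_engine.send_weights(model_engine.get_per_tensor_param())  # trainer side
#     await server_adapter.update_weights(rollout_engine.receive_weights())   # rollout side
#     engine.finalize()                                                       # all workers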


class CheckpointEngineWithCache(CheckpointEngine):
    """Checkpoint engine with a local cache: shm, disk, etc. This allows weights to be synchronized
    without interrupting ongoing rollout requests (partial rollout). Once its in-flight requests are
    exhausted, the rollout can load the new weights from the local cache.

    Laminar: https://arxiv.org/abs/2510.12633
    """

    @abstractmethod
    async def get_weights(self) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get the weights of the model from the local cache.

        Yields:
            A tuple of the name of each weight tensor and the tensor itself.
        """
        raise NotImplementedError
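
# Sketch of the partial-rollout flow this enables (assumed usage, not a fixed API): new
# weights stream into the local cache while in-flight requests keep generating, and the
# rollout swaps them in once those requests drain:
#
#     await engine.receive_weights()                              # fill the local cache
#     ...                                                         # requests keep running
#     await server_adapter.update_weights(engine.get_weights())   # swap in from the cache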


@CheckpointEngineRegistry.register("naive")
class ColocatedCheckpointEngine(CheckpointEngine):
    """Checkpoint engine for a trainer and rollout colocated on the same GPU.

    In the trainer process:
    >>> engine = ColocatedCheckpointEngine()
    >>> trainer = Trainer()
    >>> server_adapter = ServerAdapter()
    >>> engine.send_weights(trainer.get_per_tensor_param())
    >>> server_adapter.update_weights(engine.receive_weights())
    """

    def __init__(self, bucket_size: int, is_master: bool = False) -> None:
        self.bucket_size = bucket_size
        self.is_master = is_master

    def prepare(self):
        raise NotImplementedError

    def init_process_group(self, **kwargs):
        raise NotImplementedError

    def finalize(self):
        raise NotImplementedError

    @classmethod
    def build_topology(cls, *args, **kwargs):
        raise NotImplementedError

    def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
        """Send the weights of the model.

        Args:
            weights: A generator that yields the name of each weight tensor and the tensor itself.
        """
        self.weights = weights

    def receive_weights(self) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Receive the weights of the model.

        Yields:
            A tuple of the name of each weight tensor and the tensor itself.
        """
        yield from self.weights
        self.weights = None


class CheckpointEngineWorker(Worker):
    """CheckpointEngineWorker is colocated with the inference engine's WorkerProc on the same GPU.

    Args:
        rollout_config: The rollout configuration.
        model_config: The model configuration.
        server_adapter: The server adapter used to update weights.
    """

    def __init__(
        self,
        rollout_config: RolloutConfig,
        model_config: HFModelConfig,
        server_adapter: BaseRollout | None = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        self.rollout_config = rollout_config
        self.model_config = model_config

        # Build the checkpoint engine for the configured backend; the bucket size is
        # configured in megabytes, so shift left by 20 to convert to bytes.
        self.server_adapter: BaseRollout = server_adapter
        backend = self.rollout_config.checkpoint_engine.backend
        bucket_size = self.rollout_config.checkpoint_engine.update_weights_bucket_megabytes << 20
        engine_kwargs = self.rollout_config.checkpoint_engine.engine_kwargs.get(backend, {})
        self.checkpoint_engine: CheckpointEngine = CheckpointEngineRegistry.new(
            backend, bucket_size=bucket_size, **engine_kwargs
        )
        self.extra_rollout_args = args
        self.extra_rollout_kwargs = kwargs
        if self.server_adapter is None:
            self.server_adapter = get_rollout_class(self.rollout_config.name, self.rollout_config.mode)(
                *self.extra_rollout_args,
                config=self.rollout_config,
                model_config=self.model_config,
                device_mesh=None,
                **self.extra_rollout_kwargs,
            )

        initialize_global_process_group_ray(timeout_second=None, backend="cpu:gloo")

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    async def update_weights(self, global_steps: int | None = None):
        """Receive weights from the checkpoint engine and apply them to the rollout server."""
        weights = self.checkpoint_engine.receive_weights()
        await self.server_adapter.update_weights(weights, global_steps=global_steps)

    @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)
    def execute_checkpoint_engine(self, method: str, *args, **kwargs):
        """Invoke an arbitrary method on the underlying checkpoint engine."""
        return getattr(self.checkpoint_engine, method)(*args, **kwargs)
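
    # With Dispatch.DP_COMPUTE, list-typed arguments are split element-wise across the
    # worker group, one element per rank. Illustrative fan-out (mirrors
    # `build_process_group` below; `worker_group` is a RayWorkerGroup over these workers):
    #
    #     refs = worker_group.execute_checkpoint_engine(["prepare"] * worker_group.world_size)
    #     metadata = ray.get(refs)  # one `prepare` metadata dict per rank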

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def get_replica_rank(self) -> int:
        """Get the replica rank from the underlying rollout server adapter."""
        return self.server_adapter.replica_rank

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def is_leader_rank(self) -> bool:
        """Get the leader-rank flag from the underlying rollout server adapter."""
        return self.server_adapter.is_leader_rank


_worker_cls = ray.remote(CheckpointEngineWorker)


class CheckpointEngineManager:
    """Checkpoint engine manager that coordinates weight synchronization between the trainer and rollout replicas.

    - ME: model engine (FSDP, MCore, VeOmni, ...), which exports the full-tensor generator `get_per_tensor_param`
    - CE: checkpoint engine (NCCL, NIXL, ...)

    In the trainer, the model engine and the checkpoint engine live in the same process.
    In the rollout, the checkpoint engine and the rollout worker live in separate processes
    and update weights via CUDA IPC.

    ```
    ┌────────┬────────┬─────┬────────┐  ┌───────────────────┬───────────────────┐
    │ ┌────┐ │ ┌────┐ │     │ ┌────┐ │  │     Replica 0     │     Replica 1     │
    │ │ ME0│ │ │ ME1│ │     │ │ MEn│ │  ├────┬────┬────┬────┼────┬────┬────┬────┤
    │ └──┬─┘ │ └────┘ │ ... │ └────┘ │  │ 0  │ 1  │ 2  │ 3  │ 0  │ 1  │ 2  │ 3  │
    │    v   |        |     |        |  └──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┘
    | ┌──┴─┐ │ ┌────┐ │     │ ┌────┐ │     ^    ^    ^   cuda ipc   ^    ^    ^
    │ │ CE │ │ │ CE │ │     │ │ CE │ │  ┌──┴─┬──┴─┬──┴─┬──┴─┬──┴─┬──┴─┬──┴─┬──┴─┐
    │ └──┬─┘ │ └────┘ │     │ └────┘ │  │ CE │ CE │ CE │ CE │ CE │ CE │ CE │ CE |
    └────┼───┴────────┴─────┴────────┘  └──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┴──┬─┘
         v                                 |    |    |    |    |    |    |    |
         └─────────(nccl/nixl/..)──────────┴────┴────┴────┴────┴────┴────┴────┘
    ```

    Args:
        config: The checkpoint engine config.
        trainer: The trainer worker group.
        replicas: The list of rollout replicas.
    """

    def __init__(
        self,
        config: CheckpointEngineConfig,
        trainer: RayWorkerGroup,
        replicas: list[RolloutReplica],
    ) -> None:
        self.config = config
        self.backend = config.backend
        self.backend_cls = CheckpointEngineRegistry.get(config.backend)
        self.trainer = trainer
        self.replicas = replicas
    def build_process_group(self, rollout: RayWorkerGroup):
        """Build the process group for the trainer and all rollout replica workers."""
        trainer = self.trainer

        # 1. Run `prepare` on every trainer and rollout worker and gather the returned
        # metadata (bucket allocation, RDMA registration, master ip:port, ...).
        metadata = ray.get(
            trainer.execute_checkpoint_engine(["prepare"] * trainer.world_size)
            + rollout.execute_checkpoint_engine(["prepare"] * rollout.world_size)
        )

        # 2. Build the communication topology; each kwargs list must have exactly one
        # element per rank of its worker group.
        trainer_kwargs, rollout_kwargs = self.backend_cls.build_topology(
            trainer.world_size, rollout.world_size, metadata
        )
        for k, v in trainer_kwargs.items():
            assert len(v) == trainer.world_size, f"trainer_kwargs[{k}] must have length of {trainer.world_size}"
        for k, v in rollout_kwargs.items():
            assert len(v) == rollout.world_size, f"rollout_kwargs[{k}] must have length of {rollout.world_size}"

        trainer_kwargs["method"] = ["init_process_group"] * trainer.world_size
        rollout_kwargs["method"] = ["init_process_group"] * rollout.world_size

        # 3. Initialize the process group on every worker with its slice of the topology.
        ray.get(
            trainer.execute_checkpoint_engine(**trainer_kwargs) + rollout.execute_checkpoint_engine(**rollout_kwargs)
        )
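
    # For example (illustrative, following the NCCL-style layout sketched in
    # `CheckpointEngine.build_topology`'s docstring): with trainer.world_size == 2 and
    # rollout.world_size == 4, the trainer side might receive
    #
    #     {"rank": [0, 1], "world_size": [6, 6], "master_metadata": [metadata[0]] * 2}
    #
    # and the rollout side the ranks [2, 3, 4, 5], so each worker's `init_process_group`
    # gets scalar values for its own rank.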

    def add_replicas(self, replicas: list[RolloutReplica]):
        """Add rollout replicas to the manager for elastic scale-up; the process group will be rebuilt.

        Args:
            replicas: The list of rollout replicas to add.
        """
        self.replicas.extend(replicas)

    def remove_replicas(self, replicas: list[RolloutReplica]):
        """Remove rollout replicas from the manager for elastic scale-down; the process group will be rebuilt.

        Args:
            replicas: The list of rollout replicas to remove.
        """
        replicas_set = set(replicas)
        self.replicas = [r for r in self.replicas if r not in replicas_set]

    @auto_await
    async def sleep_replicas(self):
        """Put all rollout replicas to sleep: free weight and kv_cache device memory."""
        await asyncio.gather(*[r.sleep() for r in self.replicas])

    @auto_await
    async def wake_up_replicas(self):
        """Wake up all rollout replicas: restore weight and kv_cache device memory."""
        await asyncio.gather(*[r.wake_up() for r in self.replicas])

    @auto_await
    async def update_weights(self, global_steps: int | None = None):
        """Update weights from the trainer to the rollout replicas.

        Args:
            global_steps: The global step of the trainer.
        """
        # With the colocated "naive" backend, the trainer hands weights to the rollout
        # in-process, so no cross-process synchronization is needed.
        if self.backend == "naive":
            ray.get(self.trainer.update_weights(global_steps=global_steps))
            return

        # 1. Abort all ongoing requests so rollout workers are idle during the sync.
        await asyncio.gather(*[r.abort_all_requests() for r in self.replicas])

        # 2. Gather the workers of all replicas into a single worker group.
        workers = []
        for replica in self.replicas:
            workers.extend(replica.workers)
        rollout = RayWorkerGroup(worker_handles=workers, ray_cls_with_init=RayClassWithInitArgs(cls=_worker_cls))
        trainer = self.trainer

        # 3. Free rollout weight and kv_cache device memory before the transfer.
        await self.sleep_replicas()

        # 4. Build the process group between trainer and rollout workers.
        self.build_process_group(rollout)

        # 5. Transfer weights: the trainer sends, the rollout receives and applies them.
        ray.get(trainer.update_weights(global_steps=global_steps) + rollout.update_weights(global_steps=global_steps))

        # 6. Finalize the checkpoint engines: free buckets, deregister RDMA, etc.
        ray.get(
            trainer.execute_checkpoint_engine(["finalize"] * trainer.world_size)
            + rollout.execute_checkpoint_engine(["finalize"] * rollout.world_size)
        )

        # 7. Restore rollout weight and kv_cache device memory.
        await self.wake_up_replicas()

        # 8. Resume serving generation requests.
        await asyncio.gather(*[r.resume_generation() for r in self.replicas])
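
# End-to-end usage sketch (illustrative; constructing `trainer_group` and `replicas` is
# elided, and the backend name is hypothetical):
#
#     config = CheckpointEngineConfig(backend="nccl")
#     manager = CheckpointEngineManager(config=config, trainer=trainer_group, replicas=replicas)
#     for step in range(total_steps):
#         ...  # run one training step
#         manager.update_weights(global_steps=step)  # sync new weights to all replicas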