| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| The abstract base class defining the interface for model training engines. |
| """ |
|
|
| from abc import abstractmethod |
| from contextlib import nullcontext |
| from typing import Any, Callable, ContextManager, Generator, Optional |
|
|
| import torch |
| from tensordict import TensorDict |
|
|
| from verl.utils.device import get_device_name |
| from verl.utils.tensordict_utils import maybe_fix_3d_position_ids |
|
|
|
|
class BaseEngine:
    """
    Common interface that every model training engine is expected to expose.

    The interface may still change before release. Concrete engines subclass ``BaseEngine``
    and override the hooks declared here. Only ``train_batch``, ``infer_batch``, ``to`` and
    ``disable_adapter`` carry shared logic; every other member raises ``NotImplementedError``.

    NOTE: this class does not use an ABC metaclass, so the ``@abstractmethod`` markers on the
    offload properties are not enforced when a subclass is instantiated.
    """

    def initialize(self):
        """
        Build or load the model, the optimizer and the learning rate scheduler.

        After this call the engine should hold everything it needs for training or evaluation.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def is_param_offload_enabled(self) -> bool:
        """True when parameter offloading is turned on for this engine."""
        raise NotImplementedError

    @property
    @abstractmethod
    def is_optimizer_offload_enabled(self) -> bool:
        """True when optimizer-state offloading is turned on for this engine."""
        raise NotImplementedError

    def train_mode(self, **kwargs):
        """
        Return a context manager that puts the engine and its model into training mode.

        Usage:
            with engine.train_mode():
                # code here runs in training mode
        """
        raise NotImplementedError

    def eval_mode(self, **kwargs):
        """
        Return a context manager that puts the engine and its model into evaluation mode.

        Usage:
            with engine.eval_mode():
                # code here runs in evaluation mode
        """
        raise NotImplementedError

    def optimizer_zero_grad(self):
        """
        Reset the gradients tracked by the optimizer.
        """
        raise NotImplementedError

    def optimizer_step(self):
        """
        Apply one optimizer update.

        ``train_batch`` stores the value returned here under the ``"grad_norm"`` metric.
        """
        raise NotImplementedError

    def lr_scheduler_step(self):
        """
        Move the learning rate scheduler forward by one step.

        Returns:
            current_lr (float or list[float]): The learning rate(s) after the step.
        """
        raise NotImplementedError

    def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False) -> Any:
        """
        Run the forward pass over a batch and, unless disabled, the backward pass too.

        Args:
            data: Batch to process, typically tensors plus metadata.
            loss_function: Loss to optimize. See `verl.workers.roles.utils.losses` for examples.
            forward_only: When True only the forward pass runs; when False both passes run.

        Returns:
            Any: Whatever the forward pass produced, for loss computation or other uses.
        """
        raise NotImplementedError

    def train_batch(self, data: TensorDict, loss_function: Callable) -> Any:
        """
        Execute one complete optimization step on a batch.

        The sequence is: fix up position ids, zero the gradients, run forward + backward,
        then step the optimizer.

        Args:
            data: Batch to train on, typically tensors plus metadata.
            loss_function: Callable that computes the loss and metrics from a batch and predictions.

        Returns:
            The object returned by ``forward_backward_batch``. On the model-parallel source rank
            that holds outputs, ``outputs["metrics"]["grad_norm"]`` is additionally filled in with
            the value returned by ``optimizer_step``.
        """
        maybe_fix_3d_position_ids(data)

        self.optimizer_zero_grad()
        step_outputs = self.forward_backward_batch(data, loss_function, forward_only=False)
        step_grad_norm = self.optimizer_step()

        # Only the rank that actually owns the outputs attaches the extra metric.
        if self.is_mp_src_rank_with_outputs():
            metrics = step_outputs["metrics"]
            assert "grad_norm" not in metrics
            metrics["grad_norm"] = step_grad_norm

        return step_outputs

    def infer_batch(self, data: TensorDict, loss_function: Optional[Callable] = None) -> Any:
        """
        Run a forward-only pass on a batch with gradient tracking disabled.

        Args:
            data: Batch to run inference on, typically tensors plus metadata.
            loss_function: Optional loss callable, forwarded unchanged to ``forward_backward_batch``.

        Returns:
            Any: The object returned by ``forward_backward_batch`` in forward-only mode.
        """
        maybe_fix_3d_position_ids(data)

        with torch.no_grad():
            return self.forward_backward_batch(data, loss_function, forward_only=True)

    def get_per_tensor_param(self) -> tuple[Generator[tuple[str, torch.Tensor], None, None], Optional[dict]]:
        """
        Expose the model parameters one tensor at a time, plus an optional peft config.

        Returns:
            Generator[tuple[str, torch.Tensor]]: Yields ``(parameter name, tensor)`` pairs.
            Optional[dict]: Optional peft config.
        """
        raise NotImplementedError

    def get_data_parallel_size(self):
        """Subclass hook for the data-parallel size; the base implementation is not provided."""
        raise NotImplementedError

    def get_data_parallel_rank(self):
        """Subclass hook for the data-parallel rank; the base implementation is not provided."""
        raise NotImplementedError

    def get_data_parallel_group(self):
        """Subclass hook for the data-parallel group; the base implementation is not provided."""
        raise NotImplementedError

    def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
        """
        Move model parameters, optimizer states, or both to the given device.

        The base implementation moves nothing. It only validates the flag combination:
        optimizer states or gradients may not be requested without the model.

        Args:
            device: Target device identifier.
            model: If True, move the model.
            optimizer: If True, move the optimizer states.
            grad: If True, move the gradient buffer.
        """
        assert model or not (optimizer or grad), "Model must be moved to device along with optimizer and grad"

    def save_checkpoint(
        self,
        local_path: str,
        hdfs_path: Optional[str] = None,
        global_step: int = 0,
        max_ckpt_to_keep: Optional[int] = None,
        **kwargs,
    ) -> None:
        """
        Persist the model, optimizer and scheduler states as a checkpoint.

        Args:
            local_path: Local filesystem path the checkpoint is written to.
            hdfs_path: Optional HDFS path the checkpoint is copied to.
            global_step: Integer training step number used for naming.
            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
            **kwargs: Arbitrary keyword arguments.
        """
        raise NotImplementedError

    def load_checkpoint(
        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
    ) -> None:
        """
        Restore the model, optimizer and scheduler states from a checkpoint.

        Args:
            local_path: Local filesystem path of the checkpoint.
            hdfs_path: Optional HDFS path where the checkpoint is stored.
            del_local_after_load: Whether to delete the local copy after loading.
            **kwargs: Arbitrary keyword arguments.
        """
        raise NotImplementedError

    def is_mp_src_rank_with_outputs(self):
        """
        Tell whether this rank is the first rank of the model parallel group that holds model outputs.
        """
        raise NotImplementedError

    def disable_adapter(self) -> ContextManager:
        """
        Context under which all adapters of the model are temporarily disabled (for LoRA).

        The base implementation returns a no-op context manager.
        """
        return nullcontext()
|
|
|
|
class BaseEngineCtx:
    def __init__(self, engine: BaseEngine, mode, **kwargs):
        """Base engine context that loads on entry and offloads on exit.

        Args:
            engine: The engine whose components are moved between devices.
            mode: Either "train" or "eval"; anything else fails the assertion.
            **kwargs: Only ``disable_auto_offload`` (default False) is read here. When it
                is true, entering and leaving the context move nothing.
        """
        self.engine = engine
        self.mode = mode
        assert mode in ("train", "eval")
        self.disable_auto_offload = kwargs.get("disable_auto_offload", False)

    def _context_switch(self, device):
        """Move the offloadable engine components to ``device`` according to the mode."""
        if self.disable_auto_offload:
            return

        engine = self.engine
        if self.mode == "train":
            # Training: gradients follow the parameters; optimizer states have their own flag.
            engine.to(
                device=device,
                model=engine.is_param_offload_enabled,
                optimizer=engine.is_optimizer_offload_enabled,
                grad=engine.is_param_offload_enabled,
            )
        elif self.mode == "eval":
            # Evaluation: only the parameters are moved.
            engine.to(device=device, model=engine.is_param_offload_enabled, optimizer=False, grad=False)

    def __enter__(self):
        """Load onto the current accelerator device and record the active mode on the engine."""
        target_device = get_device_name()
        self._context_switch(target_device)
        self.engine.mode = self.mode

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Offload back to the CPU and clear the mode recorded on the engine."""
        self._context_switch("cpu")
        self.engine.mode = None
|
|
|
|
class EngineRegistry:
    """
    Registry used to look up and instantiate training engine classes.

    Engine classes are stored in a nested dictionary keyed by model type, then backend,
    then device. The `register` decorator adds an engine class to the registry, and `new`
    creates an instance of the engine registered for the current device.
    """

    _engines = {}

    @classmethod
    def register(cls, model_type: str, backend: list[str] | str, device: list[str] | str = "cuda"):
        """
        Build a class decorator that records an engine class in the registry.

        The class is stored for every (backend, device) combination given. A combination
        that is already registered keeps its existing class.

        Args:
            model_type (str): The type of the model
            backend (list[str] | str): The backend(s) the engine implements for the model type
            device (list[str] | str): The device type(s) (e.g., "cuda", "npu", "cpu") this engine
                supports, default is "cuda"

        Returns:
            A decorator function that takes an engine class, registers it and returns it unchanged.
        """

        def decorator(engine_class):
            assert issubclass(engine_class, BaseEngine)
            per_backend = cls._engines.setdefault(model_type, {})

            backend_names = backend if isinstance(backend, list) else [backend]
            device_names = device if isinstance(device, list) else [device]
            for backend_name in backend_names:
                for device_name in device_names:
                    # setdefault keeps whichever class was registered first for this slot.
                    per_backend.setdefault(backend_name, {}).setdefault(device_name, engine_class)

            return engine_class

        return decorator

    @classmethod
    def get_engine_cls(cls, model_type: str, backend: str):
        """
        Look up the engine class registered for ``model_type`` and ``backend`` on the current device.

        The device is taken from ``get_device_name()``. An unknown model type, backend or
        device fails an assertion.
        """
        registry = cls._engines
        assert model_type in registry, f"Unknown model_type: {model_type}"
        per_backend = registry[model_type]
        assert backend in per_backend, f"Unknown backend: {backend}"
        device = get_device_name()
        per_device = per_backend[backend]
        assert device in per_device, (
            f"Unknown device: {device} for model_type: {model_type} and backend: {backend}"
        )
        return per_device[device]

    @classmethod
    def new(cls, model_type, backend, *args, **kwargs):
        """
        Create a new training engine instance for the given model type and backend.
        Args:
            model_type: The model type the engine was registered under.
            backend: The backend the engine was registered under.
            *args: Positional arguments forwarded to the engine constructor.
            **kwargs: Keyword arguments forwarded to the engine constructor.
        Returns:
            engine: An instance of the engine class resolved by `get_engine_cls`.
        Raises:
            AssertionError: If no engine is registered for the model type, backend or current device.
        """
        resolved_cls = cls.get_engine_cls(model_type, backend)
        return resolved_cls(*args, **kwargs)
|
|