LeTue09's picture
initial clean commit
1faccd4
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
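Abstract base class defining the interface for model training engines, together with the
engine context manager that handles device load/offload (``BaseEngineCtx``) and the registry
used to look up and instantiate engine implementations (``EngineRegistry``).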
"""
from abc import abstractmethod
from contextlib import nullcontext
from typing import Any, Callable, ContextManager, Generator, Optional
import torch
from tensordict import TensorDict
from verl.utils.device import get_device_name
from verl.utils.tensordict_utils import maybe_fix_3d_position_ids
class BaseEngine:
    """
    Interface that every model training engine implements.

    The interface may still change before release. Concrete engines subclass ``BaseEngine`` and
    override the hooks below; most of them simply raise ``NotImplementedError`` here.

    Note: ``BaseEngine`` does not use ``ABCMeta``, so the ``@abstractmethod`` markers on the
    offload properties document intent but are not enforced when a subclass is instantiated.
    """

    def initialize(self):
        """
        Build or load the model, optimizer and learning-rate scheduler.

        After this call the engine should hold everything it needs for training or evaluation.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def is_param_offload_enabled(self) -> bool:
        """True when parameter offloading is turned on for this engine."""
        raise NotImplementedError

    @property
    @abstractmethod
    def is_optimizer_offload_enabled(self) -> bool:
        """True when optimizer-state offloading is turned on for this engine."""
        raise NotImplementedError

    def train_mode(self, **kwargs):
        """
        Return a context manager that switches the engine and its model into training mode.

        Example:
            with engine.train_mode():
                ...  # code here runs in training mode
        """
        raise NotImplementedError

    def eval_mode(self, **kwargs):
        """
        Return a context manager that switches the engine and its model into evaluation mode.

        Example:
            with engine.eval_mode():
                ...  # code here runs in evaluation mode
        """
        raise NotImplementedError

    def optimizer_zero_grad(self):
        """Clear the gradients tracked by the optimizer."""
        raise NotImplementedError

    def optimizer_step(self):
        """
        Apply a single optimizer update.

        Returns:
            The value that ``train_batch`` records as the ``grad_norm`` metric.
        """
        raise NotImplementedError

    def lr_scheduler_step(self):
        """
        Move the learning-rate scheduler forward by one step.

        Returns:
            current_lr (float or list[float]): The learning rate(s) after the step.
        """
        raise NotImplementedError

    def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False) -> Any:
        """
        Run a forward pass on ``data`` and, unless ``forward_only`` is set, a backward pass too.

        Args:
            data: Batch to process, typically tensors plus metadata.
            loss_function: Loss to optimize. See `verl.workers.roles.utils.losses` for examples.
            forward_only: When True only the forward pass runs; otherwise forward and backward run.

        Returns:
            Any: Forward-pass outputs, usable for loss computation or other downstream purposes.
        """
        raise NotImplementedError

    def train_batch(self, data: TensorDict, loss_function: Callable) -> Any:
        """
        Run one full training step on ``data``.

        The step is: position-id fix-up via ``maybe_fix_3d_position_ids``, zero the gradients,
        forward + backward through ``forward_backward_batch``, then ``optimizer_step``.

        Args:
            data: Training batch, typically tensors plus metadata.
            loss_function: Callable that computes the loss and metrics from a batch and predictions.

        Returns:
            The outputs of ``forward_backward_batch``. On the model-parallel source rank
            (``is_mp_src_rank_with_outputs()``), ``outputs["metrics"]["grad_norm"]`` is set to
            the value returned by ``optimizer_step``.
        """
        maybe_fix_3d_position_ids(data)
        self.optimizer_zero_grad()
        result = self.forward_backward_batch(data, loss_function, forward_only=False)
        norm = self.optimizer_step()
        if self.is_mp_src_rank_with_outputs():
            metrics = result["metrics"]
            # The engine itself owns this key; the loss function must not produce it.
            assert "grad_norm" not in metrics
            metrics["grad_norm"] = norm
        return result

    def infer_batch(self, data: TensorDict, loss_function: Optional[Callable] = None) -> Any:
        """
        Run a forward-only pass on ``data`` with autograd disabled.

        Args:
            data: Inference batch, typically tensors plus metadata.
            loss_function: Optional callable passed through to ``forward_backward_batch``.

        Returns:
            Any: The outputs of ``forward_backward_batch`` with ``forward_only=True``.
        """
        # Same position-id fix-up that train_batch applies before its forward pass.
        maybe_fix_3d_position_ids(data)
        with torch.no_grad():
            return self.forward_backward_batch(data, loss_function, forward_only=True)

    def get_per_tensor_param(self) -> tuple[Generator[tuple[str, torch.Tensor], None, None], Optional[dict]]:
        """
        Expose the model parameters one tensor at a time, plus an optional PEFT config.

        Returns:
            Generator[tuple[str, torch.Tensor]]: Yields ``(parameter_name, tensor)`` pairs.
            Optional[dict]: The PEFT config, if any.
        """
        raise NotImplementedError

    def get_data_parallel_size(self):
        """Return the data-parallel size for this engine (implementation-specific)."""
        raise NotImplementedError

    def get_data_parallel_rank(self):
        """Return this process's data-parallel rank (implementation-specific)."""
        raise NotImplementedError

    def get_data_parallel_group(self):
        """Return the data-parallel process group (implementation-specific)."""
        raise NotImplementedError

    def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
        """
        Move the model, optimizer states and/or gradient buffers to ``device``.

        This base version only validates the flags: optimizer states and gradients may only be
        moved together with the model. Subclasses are expected to perform the actual transfer.

        Args:
            device: Target device identifier.
            model: Whether to move the model.
            optimizer: Whether to move the optimizer states.
            grad: Whether to move the gradient buffer.
        """
        assert model or not (optimizer or grad), "Model must be moved to device along with optimizer and grad"

    def save_checkpoint(
        self,
        local_path: str,
        hdfs_path: Optional[str] = None,
        global_step: int = 0,
        max_ckpt_to_keep: Optional[int] = None,
        **kwargs,
    ) -> None:
        """
        Write model, optimizer and scheduler state to a checkpoint.

        Args:
            local_path: Local filesystem path for the checkpoint.
            hdfs_path: Optional HDFS path to copy the checkpoint to.
            global_step: Training step number used for naming.
            max_ckpt_to_keep: Upper bound on how many recent checkpoints are retained.
            **kwargs: Extra, implementation-specific options.
        """
        raise NotImplementedError

    def load_checkpoint(
        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
    ) -> None:
        """
        Restore model, optimizer and scheduler state from a checkpoint.

        Args:
            local_path: Local filesystem path of the checkpoint.
            hdfs_path: Optional HDFS path the checkpoint is stored under.
            del_local_after_load: Whether the local copy is removed after loading.
            **kwargs: Extra, implementation-specific options.
        """
        raise NotImplementedError

    def is_mp_src_rank_with_outputs(self):
        """
        Tell whether this rank is the first rank in its model-parallel group that holds model outputs.
        """
        raise NotImplementedError

    def disable_adapter(self) -> ContextManager:
        """
        Context manager that temporarily disables all LoRA adapters on the model.

        The base version has no adapters and returns a no-op ``nullcontext()``.
        """
        return nullcontext()
class BaseEngineCtx:
    """
    Context manager that loads/offloads engine state and tracks the engine's mode.

    On entry, the parts of the engine selected by its offload flags are moved to the device
    returned by ``get_device_name()``, and ``engine.mode`` is set to this context's mode. On exit,
    they are moved back to ``"cpu"`` and ``engine.mode`` is reset to ``None``. Exceptions raised
    inside the ``with`` body are not suppressed.
    """

    def __init__(self, engine: BaseEngine, mode, **kwargs):
        """Base engine context that handles loading and offloading.

        Args:
            engine: The engine whose state is moved and whose ``mode`` attribute is updated.
            mode: Either ``"train"`` or ``"eval"``; anything else fails the assertion below.
            **kwargs: Optional ``disable_auto_offload`` (bool, default ``False``). When True, no
                device moves are performed and only ``engine.mode`` is updated. Any other keys
                are ignored.
        """
        self.engine = engine
        self.mode = mode
        assert self.mode in ("train", "eval")
        self.disable_auto_offload = kwargs.pop("disable_auto_offload", False)

    def _context_switch(self, device):
        """Move the offload-enabled parts of the engine to ``device``, unless auto offload is disabled."""
        if self.disable_auto_offload:
            return
        if self.mode == "eval":
            # Evaluation never moves optimizer states or gradients; only the (offloaded) params.
            self.engine.to(device=device, model=self.engine.is_param_offload_enabled, optimizer=False, grad=False)
        elif self.mode == "train":
            self.engine.to(
                device=device,
                model=self.engine.is_param_offload_enabled,
                optimizer=self.engine.is_optimizer_offload_enabled,
                # Gradients follow the parameter offload setting.
                grad=self.engine.is_param_offload_enabled,
            )

    def __enter__(self):
        self._context_switch(get_device_name())
        self.engine.mode = self.mode
        # Return the context itself so ``with engine.train_mode() as ctx`` binds something useful.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            self._context_switch("cpu")
        finally:
            # Reset the mode even if offloading back to CPU raises, so the engine never keeps
            # reporting the mode of a context that has already been exited.
            self.engine.mode = None
class EngineRegistry:
"""
A registry for managing and instantiating different types of training engines.
This class uses a dictionary to store engine classes, mapping a string key to each class.
It provides a decorator `register` to add new engines to the registry and a `new` method
to create an instance of a registered engine.
"""
_engines = {}
@classmethod
def register(cls, model_type: str, backend: list[str] | str, device: list[str] | str = "cuda"):
"""
A class method decorator that registers an engine class with a given key.
This allows for dynamic instantiation of engine classes by their registered key.
Args:
model_type (str): The type of the model
backend (list[str] | str): The backend to use for the model type
device (list[str] | str): The device type (e.g., "cuda", "npu", "cpu") this engine supports,
default is "cuda"
Returns:
A decorator function that takes an engine class and registers it.
"""
def decorator(engine_class):
assert issubclass(engine_class, BaseEngine)
if model_type not in cls._engines:
cls._engines[model_type] = {}
backends = backend if isinstance(backend, list) else [backend]
devices = device if isinstance(device, list) else [device]
for current_backend in backends:
for current_device in devices:
if current_backend not in cls._engines[model_type]:
cls._engines[model_type][current_backend] = {}
if current_device not in cls._engines[model_type][current_backend]:
cls._engines[model_type][current_backend][current_device] = engine_class
return engine_class
return decorator
@classmethod
def get_engine_cls(cls, model_type: str, backend: str):
assert model_type in cls._engines, f"Unknown model_type: {model_type}"
assert backend in cls._engines[model_type], f"Unknown backend: {backend}"
device = get_device_name()
assert device in cls._engines[model_type][backend], (
f"Unknown device: {device} for model_type: {model_type} and backend: {backend}"
)
return cls._engines[model_type][backend][device]
@classmethod
def new(cls, model_type, backend, *args, **kwargs):
"""
Function to create a new training engine instance based on the provided config.
Args:
key: A configuration object containing the engine key and other settings.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
engine: An instance of the training engine corresponding to the config.
Raises:
NotImplementedError: If the engine key in the config does not match any known engines.
"""
engine_cls = cls.get_engine_cls(model_type, backend)
return engine_cls(*args, **kwargs)