| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from verl.single_controller.base.worker import Worker |
|
|
|
|
class DPEngineWorker(Worker):
    """Worker base class that forwards calls by name to a wrapped engine.

    ``init`` and ``add_engine`` raise ``NotImplementedError`` here, so a
    concrete subclass must provide them. The forwarding helpers read
    ``self._engine`` (and ``get_model_size_on_rank_zero`` reads
    ``self._model``), neither of which is assigned in this class.
    NOTE(review): presumably a subclass sets these in ``init``/``add_engine``
    — confirm against the concrete implementations.
    """

    def __init__(self, *args, **kwargs):
        # Explicit base-class call (not super()) is kept deliberately so the
        # MRO behaviour of any multiply-inheriting subclass stays unchanged.
        Worker.__init__(self, *args, **kwargs)

    def init(self):
        """Initialise the worker. Must be implemented by a subclass."""
        raise NotImplementedError

    def add_engine(self, model, dp_config):
        """Attach an engine for ``model`` using ``dp_config``. Must be implemented by a subclass."""
        raise NotImplementedError

    @staticmethod
    def _trace_call(entry_point, method_name):
        """Emit a DEBUG-level trace for a forwarded call.

        Replaces an unconditional ``print`` so the trace no longer writes to
        stdout on every call and can be enabled via logging configuration.
        """
        import logging  # local import, matching this file's lazy-import style

        # Lazy %-style args: the message is only formatted if DEBUG is enabled.
        logging.getLogger(__name__).debug("%s called with method=%s", entry_point, method_name)

    def execute_engine(self, method_name, *args, **kwargs):
        """Call ``self._engine.<method_name>(*args, **kwargs)`` and return its result.

        Raises:
            AttributeError: if ``self._engine`` is unset or has no attribute
                named ``method_name``.
        """
        self._trace_call("execute_engine", method_name)
        func = getattr(self._engine, method_name)
        return func(*args, **kwargs)

    def execute_module(self, method_name, *args, **kwargs):
        """Call ``self._engine.module.<method_name>(*args, **kwargs)`` and return its result.

        Raises:
            AttributeError: if ``self._engine`` / ``self._engine.module`` is
                unset or has no attribute named ``method_name``.
        """
        self._trace_call("execute_module", method_name)
        func = getattr(self._engine.module, method_name)
        return func(*args, **kwargs)

    def get_model_size_on_rank_zero(self):
        """Return the model size on global rank 0 and ``None`` on every other rank.

        Returns:
            On rank 0, the ``(module_size, module_size_scale)`` pair produced by
            ``verl.utils.model.get_model_size(self._model)``; otherwise ``None``.

        NOTE(review): ``self._model`` is not assigned in this class (the
        forwarding methods use ``self._engine``); presumably a subclass sets it
        — confirm. ``torch.distributed.get_rank`` presumably requires the
        default process group to be initialised before this is called.
        """
        # Heavy imports stay local so importing this module does not pull in torch.
        import torch
        from verl.utils.model import get_model_size

        if torch.distributed.get_rank() != 0:
            return None
        module_size, module_size_scale = get_model_size(self._model)
        return module_size, module_size_scale
|
|