| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Base building blocks for worker groups: ResourcePool, ClassWithInitArgs and WorkerGroup. |
| | """ |
| |
|
| | import logging |
| | import signal |
| | import threading |
| | import time |
| | from typing import Any, Callable, Dict, List, Optional |
| |
|
| | from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn |
| |
|
| |
|
class ResourcePool:
    """Per-node process counts plus a few placement settings.

    The pool is backed by a list with one entry per node, where each entry is
    the number of processes on that node. The list passed to the constructor
    is stored by reference, not copied.
    """

    def __init__(
        self, process_on_nodes: Optional[Any] = None, max_colocate_count: int = 10, n_gpus_per_node: int = 8
    ) -> None:
        """Create a pool.

        Args:
            process_on_nodes: Per-node process counts; a fresh empty list is
                used when omitted.
            max_colocate_count: Stored as-is; not used inside this class.
            n_gpus_per_node: Stored as-is; not used inside this class.
        """
        self._store = [] if process_on_nodes is None else process_on_nodes
        self.max_colocate_count = max_colocate_count
        self.n_gpus_per_node = n_gpus_per_node

    def add_node(self, process_count):
        """Append one more node that runs ``process_count`` processes."""
        self._store.append(process_count)

    @property
    def world_size(self):
        """Total number of processes summed over every node."""
        return sum(self._store)

    def __call__(self) -> Any:
        """Return the underlying per-node process-count list."""
        return self._store

    @property
    def store(self):
        """The underlying per-node process-count list."""
        return self._store

    def local_world_size_list(self) -> List[int]:
        """Return, for every process in node order, the size of its node."""
        return [node_size for node_size in self._store for _ in range(node_size)]

    def local_rank_list(self) -> List[int]:
        """Return, for every process in node order, its rank within its node."""
        return [rank for node_size in self._store for rank in range(node_size)]
| |
|
| |
|
class ClassWithInitArgs:
    """Deferred constructor call.

    Remembers a class together with the positional and keyword arguments it
    should be built with; calling the instance performs the construction.
    It is used to instantiate the remote class.
    """

    def __init__(self, cls, *args, **kwargs) -> None:
        """Record ``cls`` and the arguments to construct it with later."""
        self.cls, self.args, self.kwargs = cls, args, kwargs

    def __call__(self) -> Any:
        """Build and return a new ``cls`` instance from the stored arguments."""
        constructor = self.cls
        return constructor(*self.args, **self.kwargs)
| |
|
| |
|
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
    """Poll the workers forever and raise SIGABRT whenever one is not alive.

    The function never returns on its own: after each full pass over
    ``workers`` it sleeps for ``gap_time`` seconds and starts again.

    Args:
        workers: Worker handles to watch.
        is_alive: Called with one worker; a falsy result marks it as not alive.
        gap_time: Seconds to sleep between two passes.
    """
    while True:
        for candidate in workers:
            if is_alive(candidate):
                continue
            # Report first, then raise the signal; remaining workers are still checked.
            logging.warning(f"Worker {candidate} is not alive, sending signal to main thread")
            signal.raise_signal(signal.SIGABRT)

        time.sleep(gap_time)
| |
|
| |
|
class WorkerGroup:
    """Base class for a group of workers.

    Keeps the worker handles and their names, the master address/port, and an
    optional background thread that watches worker liveness. Subclasses must
    implement ``_is_worker_alive``.
    """

    def __init__(self, resource_pool: Optional["ResourcePool"], **kwargs) -> None:
        """Set up an empty group.

        Args:
            resource_pool: Pool whose call result is stored as the process
                dispatch configuration, or ``None``.
            **kwargs: Accepted and ignored by this base class.
        """
        # A missing pool is recorded as "initialised with detached workers".
        self._is_init_with_detached_workers = resource_pool is None

        # NOTE: the attribute name (including its spelling) is kept unchanged
        # because code outside this class may read it.
        if resource_pool is not None:
            self._procecss_dispatch_config = resource_pool()
        else:
            self._procecss_dispatch_config = None

        self._workers = []
        self._worker_names = []

        self._master_addr = None
        self._master_port = None

        # Created lazily by start_worker_aliveness_check().
        self._checker_thread: Optional[threading.Thread] = None

    def _is_worker_alive(self, worker):
        """Return whether ``worker`` is alive; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.")

    def _block_until_all_workers_alive(self, poll_interval: float = 1) -> None:
        """Block until every worker in the group reports alive.

        Liveness is judged by truthiness, the same rule the background checker
        applies, so a falsy result such as ``None`` counts as "not alive".

        Args:
            poll_interval: Seconds to sleep between two polling rounds.
        """
        while not all(self._is_worker_alive(worker) for worker in self._workers):
            time.sleep(poll_interval)

    def start_worker_aliveness_check(self, every_n_seconds=1) -> None:
        """Wait for all workers to be alive, then start the liveness monitor.

        Args:
            every_n_seconds: Seconds between two monitoring passes.
        """
        # Make sure every worker is up before monitoring starts; otherwise the
        # monitor would immediately report workers that are still starting.
        self._block_until_all_workers_alive()

        # The monitor loops forever, so it runs as a daemon thread: a
        # non-daemon thread would keep the interpreter from ever exiting.
        self._checker_thread = threading.Thread(
            target=check_workers_alive,
            args=(self._workers, self._is_worker_alive, every_n_seconds),
            daemon=True,
        )
        self._checker_thread.start()

    @property
    def world_size(self):
        """Number of workers currently held by the group."""
        return len(self._workers)

    def _bind_worker_method(self, user_defined_cls, func_generator):
        """Bind the decorated methods of ``user_defined_cls`` onto this group.

        Every callable attribute of ``user_defined_cls`` that carries
        ``MAGIC_ATTR`` is turned into a group-level function by
        ``func_generator`` and set on ``self`` under the same name.

        Args:
            user_defined_cls: Class whose attributes are scanned.
            func_generator: Called as ``func_generator(self, method_name,
                dispatch_fn=..., collect_fn=..., execute_fn=..., blocking=...)``;
                its return value becomes the bound attribute.

        Raises:
            AssertionError: If the attached metadata is malformed or the
                execute function is not callable.
            AttributeError: If the group has no attribute named after the
                resolved execute function.
            KeyError: If the metadata lacks ``execute_mode`` or ``blocking``.
            ValueError: If the generated function cannot be set on the group.
        """
        for method_name in dir(user_defined_cls):
            try:
                method = getattr(user_defined_cls, method_name)
            except Exception:
                # Attributes that cannot be read from the class are skipped.
                continue

            # Explicit check (not an assert) so it still works under ``python -O``.
            if not callable(method):
                continue

            if not hasattr(method, MAGIC_ATTR):
                continue

            # The method was registered through the decorator; read its metadata.
            attribute = getattr(method, MAGIC_ATTR)
            assert isinstance(attribute, dict), f"attribute must be a dictionary. Got {type(attribute)}"
            assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key"

            dispatch_mode = attribute["dispatch_mode"]
            execute_mode = attribute["execute_mode"]
            blocking = attribute["blocking"]

            # Resolve the dispatch/collect pair: either a predefined mode or a
            # user-supplied dict holding both functions.
            if isinstance(dispatch_mode, Dispatch):
                predefined = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode)
                dispatch_fn = predefined["dispatch_fn"]
                collect_fn = predefined["collect_fn"]
            else:
                assert isinstance(dispatch_mode, dict)
                assert "dispatch_fn" in dispatch_mode
                assert "collect_fn" in dispatch_mode
                dispatch_fn = dispatch_mode["dispatch_fn"]
                collect_fn = dispatch_mode["collect_fn"]

            # Resolve the name of the group method that performs the execution.
            execute_config = get_predefined_execute_fn(execute_mode=execute_mode)
            wg_execute_fn_name = execute_config["execute_fn_name"]

            # The execute function is looked up on this group instance.
            try:
                execute_fn = getattr(self, wg_execute_fn_name)
                assert callable(execute_fn), "execute_fn must be callable"
            except Exception:
                print(f"execute_fn {wg_execute_fn_name} is invalid")
                raise

            # Build the group-level function and bind it under the same name.
            func = func_generator(
                self,
                method_name,
                dispatch_fn=dispatch_fn,
                collect_fn=collect_fn,
                execute_fn=execute_fn,
                blocking=blocking,
            )

            try:
                setattr(self, method_name, func)
            except Exception as err:
                # Keep the original failure attached as the cause.
                raise ValueError(f"Fail to set method_name {method_name}") from err
| |
|