Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| # | |
| # This source code is licensed under the BSD license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import concurrent | |
| import gc | |
| import multiprocessing | |
| import os | |
| import signal | |
| from tempfile import NamedTemporaryFile, _TemporaryFileWrapper | |
| from typing import Dict, List, Tuple | |
| import torch | |
| class SafeMpContext(multiprocessing.context.BaseContext): | |
| def __init__(self) -> None: | |
| self.mp_context = multiprocessing.get_context("spawn") | |
| self.processes: List[multiprocessing.context.SpawnProcess] = [] | |
| def Process(self, *args, **kwargs) -> multiprocessing.context.SpawnProcess: | |
| p = self.mp_context.Process(*args, **kwargs) | |
| p.daemon = True | |
| self.processes.append(p) | |
| return p | |
| def kill_all_processes(self): | |
| for p in self.processes: | |
| p.terminate() | |
| p.join(1) | |
| # (https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Process.exitcode) | |
| # Even though the python documentation seems to say that after joining the exitcode should | |
| # become set, this is not what we have observed in practice. We therefore loop until it | |
| # becomes set. | |
| while p.exitcode is None: | |
| p.kill() | |
| p.join() | |
| assert p.exitcode is not None, f"{p} is still alive" | |
| def log_bad_exit_codes(self): | |
| for rank, p in enumerate(self.processes): | |
| if p.exitcode == 0: | |
| continue | |
| if p.exitcode < 0: | |
| try: | |
| signal_desc = f" (signal {signal.Signals(-p.exitcode).name})" | |
| except ValueError: | |
| signal_desc = " (unrecognized signal)" | |
| else: | |
| signal_desc = "" | |
| print( | |
| f"Child process for rank #{rank} with PID {p.pid} exited with code {p.exitcode}{signal_desc}" | |
| ) | |
| def __getattr__(self, name: str): | |
| return getattr(self.mp_context, name) | |
| def __enter__(self): | |
| return self | |
| def __exit__(self, exc_type, exc_val, exc_tb): | |
| self.kill_all_processes() | |
| self.log_bad_exit_codes() | |
| def init_process_group(init_method: str, rank: int, world_size: int): | |
| torch._C._set_print_stack_traces_on_fatal_signal(True) | |
| if torch.cuda.device_count() >= world_size: | |
| backend = "nccl" | |
| torch.cuda.set_device(rank) | |
| else: | |
| # Use Gloo instead of NCCL so that we can run on a single GPU | |
| backend = "gloo" | |
| torch.distributed.init_process_group( | |
| backend=backend, | |
| world_size=world_size, | |
| rank=rank, | |
| init_method=init_method, | |
| ) | |
| def _launch_subprocesses_fn_wrapper( | |
| init_method: str, | |
| rank: int, | |
| world_size: int, | |
| parent_env_vars: Dict[str, str], | |
| user_fn, | |
| args, | |
| kwargs, | |
| ): | |
| # This function initializes the environment for spawned subprocesses by capturing and applying the current | |
| # environment variables from the parent process. By clearing and then updating `os.environ` with `parent_env_vars`, | |
| # we ensure that each spawned subprocess starts with an environment that mirrors the parent process at the time | |
| # of job submission. This approach guarantees consistency across subprocesses, reflecting the latest state of the | |
| # parent's environment variables even/especially when reusing the subprocesses for subsequent job executions. | |
| os.environ.clear() | |
| os.environ.update(parent_env_vars) | |
| # Check if the process group is already initialized | |
| if not torch.distributed.is_initialized(): | |
| init_process_group(init_method, rank, world_size) | |
| try: | |
| return user_fn(*args, **kwargs) | |
| finally: | |
| # should free all memory used by PyTorch in the subprocesses | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| # Global dictionary to keep track of executors and temporary files | |
| EXECUTORS_AND_FILES: Dict[ | |
| int, Tuple[_TemporaryFileWrapper, concurrent.futures.ProcessPoolExecutor] | |
| ] = {} | |
| def get_global_pool_allocator( | |
| world_size: int, | |
| ) -> Tuple[_TemporaryFileWrapper, concurrent.futures.ProcessPoolExecutor]: | |
| global EXECUTORS_AND_FILES | |
| if world_size not in EXECUTORS_AND_FILES: | |
| rdv = NamedTemporaryFile(mode="w+b", buffering=-1, delete=False) | |
| mp_context = SafeMpContext() | |
| executor = concurrent.futures.ProcessPoolExecutor( | |
| max_workers=world_size, mp_context=mp_context | |
| ) | |
| # Add the executor and temporary file to the global list | |
| EXECUTORS_AND_FILES[world_size] = (rdv, executor) | |
| else: | |
| rdv, executor = EXECUTORS_AND_FILES[world_size] | |
| return rdv, executor | |
| class ProcessPoolExecutorManager: | |
| def __init__(self, world_size: int): | |
| self.world_size = world_size | |
| def __enter__(self): | |
| # when you start a subprocess you want to free memory used by PyTorch in the main process, | |
| # so the subprocess can have memory | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| self.rdv, self.executor = get_global_pool_allocator(self.world_size) | |
| return self | |
| def submit(self, fn, *args, **kwargs): | |
| return self.executor.submit(fn, *args, **kwargs) | |
| def __exit__(self, exc_type, exc_val, exc_tb): | |
| # One of the subprocesses jobs has failed | |
| if exc_val: | |
| # We want to avoid killing the processes while the executor was thinking that they were | |
| # still up and healthy (as this may have unintended consequences, such as the executor | |
| # restarting the processes, or reporting spurious errors). | |
| # Set the internal state of the executor and call cancel() on each issued task that is | |
| # not executing | |
| self.executor.shutdown(wait=False, cancel_futures=True) | |
| # Kill all remaining subprocesses | |
| mp_context = self.executor._mp_context | |
| mp_context.kill_all_processes() | |
| mp_context.log_bad_exit_codes() | |
| # We want to wait for all the futures to complete, so we need to shutdown twice | |
| self.executor.shutdown(wait=True) | |
| # Close the temporary file | |
| self.rdv.close() | |
| # Remove the executor from the global list. | |
| # This will recreate it next time a test is requiring this world_size | |
| assert self.world_size in EXECUTORS_AND_FILES | |
| del EXECUTORS_AND_FILES[self.world_size] | |
| print( | |
| f"Shutdown and remove the executor after subprocesses error. Executors cnt: {len(EXECUTORS_AND_FILES)}" | |
| ) | |
| def launch_subprocesses(world_size: int, fn, *args, **kwargs): | |
| # This custom manager allows each test execution to enter/exit the following context. | |
| # When entering the context, it creates/reuses a new/existing ProcessPoolExecutor with the given world size. | |
| # The context also allows to detect an exception upon exit, in which case it will kill all spawned processes, | |
| # delete the manager, recreate the manager upon following request and respawn processes. | |
| with ProcessPoolExecutorManager(world_size) as manager: | |
| futures = [ | |
| manager.submit( | |
| _launch_subprocesses_fn_wrapper, | |
| init_method=f"file://{manager.rdv.name}", | |
| rank=rank, | |
| world_size=world_size, | |
| parent_env_vars=dict(os.environ), | |
| user_fn=fn, | |
| args=args, | |
| kwargs=kwargs, | |
| ) | |
| for rank in range(world_size) | |
| ] | |
| done, _ = concurrent.futures.wait( | |
| futures, return_when=concurrent.futures.FIRST_EXCEPTION | |
| ) | |
| for f in done: | |
| f.result() | |