| | |
| | |
| | |
| | import faulthandler |
| | import logging |
| | import multiprocessing |
| | import sys |
| | import tempfile |
| | import threading |
| | import time |
| | import traceback |
| | import types |
| | import unittest |
| | from enum import Enum |
| | from functools import wraps |
| | from typing import NamedTuple |
| | from unittest import TestCase |
| |
|
| | import torch |
| | from torch.multiprocessing import active_children |
| |
|
# Configure root logging at INFO so messages emitted by the test harness
# (skip reasons, subprocess tracebacks) are visible by default.
logging.basicConfig(level=logging.INFO)

# Module-level logger used by the multi-process test helpers in this file.
logger = logging.getLogger(__name__)
| |
|
| |
|
class TestSkip(NamedTuple):
    """Pair a subprocess exit code with the skip message it stands for."""

    # Exit code a subprocess returns to signal this skip condition.
    exit_code: int
    # Human-readable reason reported when the test is skipped.
    message: str
| |
|
| |
|
# Known skip conditions, keyed by a short identifier. Each entry maps to the
# exit code a subprocess uses to report the condition and the message shown
# when the parent process turns that exit code into a skipped test.
TEST_SKIPS = {
    'backend_unavailable': TestSkip(
        10,
        'Skipped because distributed backend is not available.',
    ),
    'no_cuda': TestSkip(11, 'CUDA is not available.'),
    'multi-gpu-2': TestSkip(12, 'Need at least 2 CUDA device'),
    'generic': TestSkip(
        13,
        'Test skipped at subprocess level, look at subprocess log for '
        'skip reason',
    ),
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
class MultiProcessTestCase(TestCase):
    """``TestCase`` whose test methods run in ``world_size`` subprocesses.

    ``__init__`` replaces the selected test method with a wrapper
    (see :meth:`join_or_run`). In the parent process, whose ``rank`` is
    ``MAIN_PROCESS_RANK``, the wrapper waits for the subprocesses and
    inspects their exit codes. In a subprocess it runs the real test body.

    Subprocesses are created by :meth:`_spawn_processes`, which is not
    called anywhere in this class.
    NOTE(review): presumably subclasses call it from ``setUp`` after
    ``super().setUp()`` -- confirm against callers.
    """

    # Rank assigned to the parent (launching) process in ``setUp``.
    MAIN_PROCESS_RANK = -1

    # Exit code a subprocess uses to report that the test body (or
    # ``prepare_subprocess``) raised an exception.
    # NOTE(review): this equals TEST_SKIPS['backend_unavailable'].exit_code
    # (both are 10). ``_check_return_codes`` checks for errors first, so an
    # exit with code 10 is always reported as an error -- confirm intended.
    TEST_ERROR_EXIT_CODE = 10

    def _should_stop_test_suite(self) -> bool:
        """Return ``False`` unconditionally.

        NOTE(review): not called within this class; presumably a hook
        consulted by an external test runner -- confirm.
        """
        return False

    def prepare_subprocess(self):
        """Hook invoked in every subprocess before the test method runs.

        The default implementation does nothing. An exception raised here
        makes the subprocess exit with ``TEST_ERROR_EXIT_CODE``.
        """

    @property
    def world_size(self) -> int:
        """Number of subprocesses started for each test."""
        return 2

    @property
    def timeout(self) -> int:
        """Seconds the parent waits before terminating the subprocesses."""
        return 1000

    def join_or_run(self, fn):
        """Wrap ``fn`` so it joins in the parent and runs in a subprocess.

        Args:
            fn: The bound test method to wrap.

        Returns:
            A method bound to ``self`` that calls ``_join_processes(fn)``
            when ``self.rank`` is ``MAIN_PROCESS_RANK`` and ``fn()``
            otherwise.
        """

        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                self._join_processes(fn)
            else:
                fn()

        return types.MethodType(wrapper, self)

    def __init__(self, method_name: str = 'runTest') -> None:
        """Create the test case and install the join-or-run wrapper.

        Args:
            method_name: Name of the test method this instance runs.
        """
        super().__init__(method_name)
        fn = getattr(self, method_name)
        # Shadow the class-level method on this instance with the wrapper.
        setattr(self, method_name, self.join_or_run(fn))

    def setUp(self) -> None:
        """Reset per-test state in the parent process."""
        super().setUp()
        # Test methods listed here are checked with ``_check_no_test_errors``
        # instead of ``_check_return_codes``.
        self.skip_return_code_checks = []
        self.processes = []
        self.rank = self.MAIN_PROCESS_RANK
        # Path of a temp file that is passed to every subprocess.
        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        # Maps subprocess pid -> parent end of the pipe to that subprocess.
        self.pid_to_pipe = {}

    def tearDown(self) -> None:
        """Terminate any remaining subprocesses and drop references."""
        super().tearDown()
        for p in self.processes:
            p.terminate()
        # Release the Process objects held by this instance.
        self.processes = []

    def _current_test_name(self) -> str:
        """Return the last dotted component of ``self.id()``."""
        return self.id().split('.')[-1]

    def _start_processes(self, proc) -> None:
        """Start ``world_size`` subprocesses that execute :meth:`_run`.

        Args:
            proc: Process factory (e.g. a multiprocessing context's
                ``Process`` class) accepting ``target``, ``name``, ``args``.
        """
        self.processes = []
        for rank in range(int(self.world_size)):
            parent_conn, child_conn = torch.multiprocessing.Pipe()
            process = proc(
                target=self.__class__._run,
                name='process ' + str(rank),
                args=(rank, self._current_test_name(), self.file_name,
                      child_conn),
            )
            process.start()
            self.pid_to_pipe[process.pid] = parent_conn
            self.processes.append(process)

    def _spawn_processes(self) -> None:
        """Start the subprocesses using the ``spawn`` start method."""
        proc = torch.multiprocessing.get_context('spawn').Process
        self._start_processes(proc)

    class Event(Enum):
        """Requests the parent can send to a subprocess over its pipe."""

        GET_TRACEBACK = 1

    @staticmethod
    def _event_listener(parent_pipe, signal_pipe, rank: int):
        """Serve parent requests until signalled to stop.

        Runs in a thread inside the subprocess (started by ``run_test``).

        Args:
            parent_pipe: Subprocess end of the pipe shared with the parent.
            signal_pipe: Receive end of the pipe ``run_test`` writes to once
                the test body has finished.
            rank: Rank of the subprocess (unused here).
        """
        while True:
            ready_pipes = multiprocessing.connection.wait(
                [parent_pipe, signal_pipe])

            if parent_pipe in ready_pipes:
                if parent_pipe.closed:
                    return

                event = parent_pipe.recv()

                if event == MultiProcessTestCase.Event.GET_TRACEBACK:
                    # faulthandler needs a real file, so dump to a temp file
                    # and send its content back to the parent.
                    with tempfile.NamedTemporaryFile(mode='r+') as tmp_file:
                        faulthandler.dump_traceback(tmp_file)
                        tmp_file.flush()
                        tmp_file.seek(0)
                        parent_pipe.send(tmp_file.read())

            if signal_pipe in ready_pipes:
                return

    @classmethod
    def _run(cls, rank: int, test_name: str, file_name: str,
             parent_pipe) -> None:
        """Subprocess entry point: build the test case and run the test.

        Args:
            rank: Rank assigned to this subprocess.
            test_name: Name of the test method to execute.
            file_name: Temp file path created by the parent in ``setUp``.
            parent_pipe: Subprocess end of the pipe shared with the parent.
        """
        self = cls(test_name)
        try:
            self.prepare_subprocess()
        except Exception:
            error = traceback.format_exc()
            logger.error(
                f'Caught exception in prepare_subprocess: \n{error} '
                f'exiting process {rank} with exit code: '
                f'{MultiProcessTestCase.TEST_ERROR_EXIT_CODE}')
            # The parent reads one message from the pipe for every process
            # that exits with TEST_ERROR_EXIT_CODE, so send the traceback.
            parent_pipe.send(error)
            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
        self.rank = rank
        self.file_name = file_name
        self.run_test(test_name, parent_pipe)

    def run_test(self, test_name: str, parent_pipe) -> None:
        """Run ``test_name`` in this subprocess and report the outcome.

        A ``unittest.SkipTest`` exits with the ``'generic'`` skip code. Any
        other ``Exception`` is logged, sent to the parent through
        ``parent_pipe`` and exits with ``TEST_ERROR_EXIT_CODE``.

        Args:
            test_name: Name of the test method to execute.
            parent_pipe: Subprocess end of the pipe shared with the parent.
        """
        # One-way pipe used only to tell the listener thread to stop.
        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(
            duplex=False)
        event_listener_thread = threading.Thread(
            target=MultiProcessTestCase._event_listener,
            args=(parent_pipe, signal_recv_pipe, self.rank),
            daemon=True,
        )
        event_listener_thread.start()

        try:
            getattr(self, test_name)()
        except unittest.SkipTest as se:
            logger.info(f'Process {self.rank} skipping test {test_name} for '
                        f'following reason: {str(se)}')
            sys.exit(TEST_SKIPS['generic'].exit_code)
        except Exception:
            error = traceback.format_exc()
            logger.error(
                f'Caught exception: \n{error} exiting '
                f'process {self.rank} with exit code: '
                f'{MultiProcessTestCase.TEST_ERROR_EXIT_CODE}')
            # Hand the traceback to the parent so it can be reported there.
            parent_pipe.send(error)
            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
        finally:
            # Runs on every path (including sys.exit): stop the listener
            # thread before closing the pipe it is waiting on.
            signal_send_pipe.send(None)
            event_listener_thread.join()
            parent_pipe.close()

    def _get_timedout_process_traceback(self) -> None:
        """Ask every still-running subprocess for its traceback and log it."""
        pipes = []
        for i, process in enumerate(self.processes):
            if process.exitcode is None:
                pipe = self.pid_to_pipe[process.pid]
                try:
                    pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK)
                    pipes.append((i, pipe))
                except ConnectionError as e:
                    logger.error(
                        'Encountered error while trying to get traceback '
                        f'for process {i}: {e}')

        # Collect the replies, waiting at most 5 seconds per process.
        for rank, pipe in pipes:
            try:
                if pipe.poll(5):
                    if pipe.closed:
                        logger.info(
                            f'Pipe closed for process {rank}, cannot retrieve '
                            'traceback')
                        continue

                    # Named ``tb`` to avoid shadowing the traceback module.
                    tb = pipe.recv()
                    logger.error(f'Process {rank} timed out with traceback: '
                                 f'\n\n{tb}')
                else:
                    logger.error('Could not retrieve traceback for timed out '
                                 f'process: {rank}')
            except ConnectionError as e:
                logger.error(
                    'Encountered error while trying to get traceback for '
                    f'process {rank}: {e}')

    def _join_processes(self, fn) -> None:
        """Wait for the subprocesses, then validate their exit codes.

        Polls every 0.1 seconds until one subprocess exits with
        ``TEST_ERROR_EXIT_CODE`` (the remaining children are terminated),
        all subprocesses have exited, or ``timeout`` seconds have elapsed
        (tracebacks are collected and all subprocesses terminated).

        Args:
            fn: The original test method; if it is listed in
                ``skip_return_code_checks`` only test errors are checked.
        """
        start_time = time.time()
        subprocess_error = False
        try:
            while True:
                # Stop early as soon as any subprocess reports a test error.
                for i, p in enumerate(self.processes):
                    if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
                        print(
                            f'Process {i} terminated with exit code '
                            f'{p.exitcode}, terminating remaining processes.')
                        for ac in active_children():
                            ac.terminate()
                        subprocess_error = True
                        break
                if subprocess_error:
                    break

                # All subprocesses finished on their own.
                if all(p.exitcode is not None for p in self.processes):
                    break

                elapsed = time.time() - start_time
                if elapsed > self.timeout:
                    self._get_timedout_process_traceback()
                    print(f'Timing out after {self.timeout} seconds and '
                          'killing subprocesses.')
                    for p in self.processes:
                        p.terminate()
                    break

                time.sleep(0.1)

            elapsed_time = time.time() - start_time

            if fn in self.skip_return_code_checks:
                self._check_no_test_errors(elapsed_time)
            else:
                self._check_return_codes(elapsed_time)
        finally:
            # Close the parent ends of all pipes whatever the outcome.
            for pipe in self.pid_to_pipe.values():
                pipe.close()

    def _check_no_test_errors(self, elapsed_time) -> None:
        """Checks that we didn't have any errors thrown in the child
        processes.

        Args:
            elapsed_time: Seconds spent waiting; used in the error message.

        Raises:
            RuntimeError: If a subprocess has no exit code yet.
            AssertionError: If a subprocess exited with
                ``TEST_ERROR_EXIT_CODE``.
        """
        for i, p in enumerate(self.processes):
            if p.exitcode is None:
                raise RuntimeError(
                    f'Process {i} timed out after {elapsed_time} seconds')
            self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)

    def _check_return_codes(self, elapsed_time) -> None:
        """Checks the return codes of all spawned processes, and skips tests
        if they returned a return code indicating a skipping condition.

        Args:
            elapsed_time: Seconds spent waiting; used in the error message.

        Raises:
            RuntimeError: If a subprocess exited with
                ``TEST_ERROR_EXIT_CODE`` or has no exit code.
            unittest.SkipTest: If the first subprocess exited with a code
                listed in ``TEST_SKIPS``, or if any subprocess exited with
                another non-zero code.
        """
        first_process = self.processes[0]

        # Report test errors first: they carry the subprocess traceback and
        # are more helpful than a generic "terminated or timed out" message.
        errored_processes = [
            (i, p) for i, p in enumerate(self.processes)
            if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE
        ]
        if errored_processes:
            error = ''
            for i, process in errored_processes:
                pipe = self.pid_to_pipe[process.pid]
                # The subprocess may have exited without sending anything;
                # never block indefinitely or crash while reading the pipe.
                try:
                    if pipe.poll(5):
                        error_message = pipe.recv()
                    else:
                        error_message = '<no traceback received>'
                except (EOFError, OSError) as e:
                    error_message = f'<could not read traceback: {e!r}>'
                error += (
                    'Process {} exited with error code {} and exception:\n{}\n'
                    .format(i, MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
                            error_message))
            raise RuntimeError(error)

        # No test error: make sure every subprocess actually exited.
        for i, p in enumerate(self.processes):
            if p.exitcode is None:
                raise RuntimeError(
                    f'Process {i} terminated or timed out after '
                    f'{elapsed_time} seconds')

        for skip in TEST_SKIPS.values():
            if first_process.exitcode == skip.exit_code:
                raise unittest.SkipTest(skip.message)

        # A remaining non-zero exit code is neither a test error nor a known
        # skip code, so the subprocess was aborted some other way. Skip
        # rather than fail; a clean run (all exit codes 0) must pass.
        if any(p.exitcode != 0 for p in self.processes):
            self.skipTest(f'Skip test {self._testMethodName} due to '
                          'the program abort')

    @property
    def is_master(self) -> bool:
        """Whether this process has rank 0."""
        return self.rank == 0
| |
|