|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import faulthandler |
|
|
import logging |
|
|
import multiprocessing |
|
|
import sys |
|
|
import tempfile |
|
|
import threading |
|
|
import time |
|
|
import traceback |
|
|
import types |
|
|
import unittest |
|
|
from enum import Enum |
|
|
from functools import wraps |
|
|
from typing import NamedTuple |
|
|
from unittest import TestCase |
|
|
|
|
|
import torch |
|
|
from torch.multiprocessing import active_children |
|
|
|
|
|
# Configure the root logger with level INFO when this module is imported.
# ``basicConfig`` does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module; used by the helpers below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class TestSkip(NamedTuple):
    """Describe one reason a subprocess may skip a test.

    Attributes are documented inline; instances are plain tuples and compare
    by value.
    """

    # Exit status the subprocess terminates with to signal this skip reason.
    exit_code: int
    # Human-readable text used as the ``unittest.SkipTest`` message.
    message: str


# Maps a skip reason to the exit code / message pair that identifies it.
#
# Every exit code must be unique and must differ from
# ``MultiProcessTestCase.TEST_ERROR_EXIT_CODE`` (10): the parent process
# checks for the error exit code first, so a skip sharing that value would be
# reported as a test error instead of a skip. 'backend_unavailable' used to
# be 10 and has been moved to 72 for that reason.
TEST_SKIPS = {
    'backend_unavailable':
    TestSkip(72, 'Skipped because distributed backend is not available.'),
    'no_cuda':
    TestSkip(11, 'CUDA is not available.'),
    'multi-gpu-2':
    TestSkip(12, 'Need at least 2 CUDA device'),
    'generic':
    TestSkip(
        13, 'Test skipped at subprocess level, look at subprocess log for '
        'skip reason'),
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiProcessTestCase(TestCase):
    """``TestCase`` whose test methods run in several subprocesses.

    Every test method is wrapped by :meth:`join_or_run`. In the launching
    process (``rank == MAIN_PROCESS_RANK``) the wrapped method only waits for
    the subprocesses and inspects their exit codes; in a subprocess it runs
    the real test body.

    Subprocesses are created by :meth:`_spawn_processes`, which nothing in
    this class calls. NOTE(review): presumably subclasses call it from their
    own ``setUp`` after ``super().setUp()`` -- confirm against callers.
    """

    # Rank value that marks the launching (parent) process.
    MAIN_PROCESS_RANK = -1

    # Exit code a subprocess uses to report that ``prepare_subprocess`` or
    # the test body raised an exception.
    # NOTE(review): exit codes listed in ``TEST_SKIPS`` should stay distinct
    # from this value, because error exit codes are checked before skips.
    TEST_ERROR_EXIT_CODE = 10

    def _should_stop_test_suite(self) -> bool:
        """Return ``False`` unconditionally.

        NOTE(review): nothing in this class calls this method; it is
        presumably a hook read by an external test runner -- confirm.
        """
        return False

    def prepare_subprocess(self):
        """Hook executed in every subprocess before the test body.

        The default implementation does nothing. An exception raised here
        makes the subprocess exit with ``TEST_ERROR_EXIT_CODE``.
        """
        pass

    @property
    def world_size(self) -> int:
        """Number of subprocesses started for each test."""
        return 2

    @property
    def timeout(self) -> int:
        """Seconds the parent waits before terminating the subprocesses."""
        return 1000

    def join_or_run(self, fn):
        """Wrap ``fn`` so it behaves according to the current process role.

        Args:
            fn: Callable taking no arguments (a bound test method).

        Returns:
            A method bound to ``self``. In the parent process it joins the
            subprocesses via :meth:`_join_processes`; otherwise it calls
            ``fn``.
        """

        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                self._join_processes(fn)
            else:
                fn()

        return types.MethodType(wrapper, self)

    def __init__(self, method_name: str = 'runTest') -> None:
        """Create the test case and wrap the selected test method.

        Args:
            method_name: Name of the test method this instance runs.
        """
        super().__init__(method_name)
        # ``TestCase.__init__`` tolerates a missing method only for the
        # default name 'runTest'; in that case there is nothing to wrap.
        fn = getattr(self, method_name, None)
        if fn is not None:
            setattr(self, method_name, self.join_or_run(fn))

    def setUp(self) -> None:
        """Initialise the per-test bookkeeping of the parent process."""
        super().setUp()
        # Test methods listed here only get the "no error exit code" check
        # instead of the full return-code check.
        self.skip_return_code_checks = []
        self.processes = []
        self.rank = self.MAIN_PROCESS_RANK
        # Only the path is needed, so close the handle right away instead of
        # leaving it to the garbage collector. The file itself is kept
        # (``delete=False``) and its name is handed to every subprocess.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            self.file_name = tmp_file.name
        # Maps subprocess pid -> parent end of the pipe to that subprocess.
        self.pid_to_pipe = {}

    def tearDown(self) -> None:
        """Terminate every subprocess and drop the references to them."""
        super().tearDown()
        for process in self.processes:
            process.terminate()
        # Release the process objects so they are not kept alive by this
        # test case instance.
        self.processes = []

    def _current_test_name(self) -> str:
        """Return the method name, i.e. the last component of ``self.id()``."""
        return self.id().split('.')[-1]

    def _start_processes(self, proc) -> None:
        """Start ``world_size`` subprocesses that execute :meth:`_run`.

        Args:
            proc: Process factory accepting ``target``, ``name`` and
                ``args`` keyword arguments (e.g. a ``Process`` class).
        """
        self.processes = []
        for rank in range(int(self.world_size)):
            parent_conn, child_conn = torch.multiprocessing.Pipe()
            process = proc(
                target=self.__class__._run,
                name='process ' + str(rank),
                args=(rank, self._current_test_name(), self.file_name,
                      child_conn),
            )
            process.start()
            self.pid_to_pipe[process.pid] = parent_conn
            self.processes.append(process)

    def _spawn_processes(self) -> None:
        """Start the subprocesses using the 'spawn' start method."""
        proc = torch.multiprocessing.get_context('spawn').Process
        self._start_processes(proc)

    class Event(Enum):
        """Requests the parent process can send to a subprocess."""

        GET_TRACEBACK = 1

    @staticmethod
    def _event_listener(parent_pipe, signal_pipe, rank: int):
        """Serve parent requests until told to stop (runs in a thread).

        Args:
            parent_pipe: Connection to the parent process.
            signal_pipe: Receiving end of the pipe used by :meth:`run_test`
                to stop this listener.
            rank: Rank of the subprocess (currently unused).
        """
        while True:
            ready_pipes = multiprocessing.connection.wait(
                [parent_pipe, signal_pipe])

            if parent_pipe in ready_pipes:
                if parent_pipe.closed:
                    return

                try:
                    event = parent_pipe.recv()
                    if event == MultiProcessTestCase.Event.GET_TRACEBACK:
                        # faulthandler writes to a real file descriptor, so
                        # go through a temporary file and send its content.
                        with tempfile.NamedTemporaryFile(
                                mode='r+') as tmp_file:
                            faulthandler.dump_traceback(tmp_file)
                            tmp_file.flush()
                            tmp_file.seek(0)
                            parent_pipe.send(tmp_file.read())
                except (EOFError, OSError):
                    # The parent closed its end of the pipe; nothing more
                    # can be served, so stop quietly instead of crashing the
                    # listener thread.
                    return

            if signal_pipe in ready_pipes:
                return

    @classmethod
    def _run(cls, rank: int, test_name: str, file_name: str,
             parent_pipe) -> None:
        """Entry point of a subprocess.

        Args:
            rank: Rank assigned to this subprocess.
            test_name: Name of the test method to execute.
            file_name: Path of the temporary file created by the parent.
            parent_pipe: Connection to the parent process.
        """
        self = cls(test_name)
        try:
            self.prepare_subprocess()
        except Exception:
            error_message = traceback.format_exc()
            logger.error(
                f'Caught exception: \n{error_message} exiting '
                f'process {rank} with exit code: '
                f'{MultiProcessTestCase.TEST_ERROR_EXIT_CODE}')
            # The parent reads one message from the pipe for every process
            # that exits with TEST_ERROR_EXIT_CODE, so report the failure.
            try:
                parent_pipe.send(error_message)
            except OSError:
                pass
            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
        self.rank = rank
        self.file_name = file_name
        self.run_test(test_name, parent_pipe)

    def run_test(self, test_name: str, parent_pipe) -> None:
        """Run the test method inside a subprocess.

        Exits with ``TEST_SKIPS['generic'].exit_code`` if the test raises
        ``unittest.SkipTest`` and with ``TEST_ERROR_EXIT_CODE`` (after
        sending the traceback to the parent) on any other exception.

        Args:
            test_name: Name of the test method to execute.
            parent_pipe: Connection to the parent process.
        """
        # One-way pipe used only to tell the listener thread to stop.
        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(
            duplex=False)
        event_listener_thread = threading.Thread(
            target=MultiProcessTestCase._event_listener,
            args=(parent_pipe, signal_recv_pipe, self.rank),
            daemon=True,
        )
        event_listener_thread.start()

        try:
            getattr(self, test_name)()
        except unittest.SkipTest as se:
            logger.info(f'Process {self.rank} skipping test {test_name} for '
                        f'following reason: {str(se)}')
            sys.exit(TEST_SKIPS['generic'].exit_code)
        except Exception:
            error_message = traceback.format_exc()
            logger.error(
                f'Caught exception: \n{error_message} exiting '
                f'process {self.rank} with exit code: '
                f'{MultiProcessTestCase.TEST_ERROR_EXIT_CODE}')
            # Hand the traceback to the parent so it can report it.
            parent_pipe.send(error_message)
            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
        finally:
            # Runs on every path, including the ``sys.exit`` calls above:
            # stop the listener thread, then close the parent connection.
            signal_send_pipe.send(None)
            event_listener_thread.join()
            parent_pipe.close()

    def _get_timedout_process_traceback(self) -> None:
        """Log the traceback of every subprocess that is still running."""
        pipes = []
        for i, process in enumerate(self.processes):
            if process.exitcode is None:
                pipe = self.pid_to_pipe[process.pid]
                try:
                    pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK)
                    pipes.append((i, pipe))
                except ConnectionError as e:
                    logger.error(
                        'Encountered error while trying to get traceback '
                        f'for process {i}: {e}')

        # Collect the answers, waiting at most 5 seconds per process.
        for rank, pipe in pipes:
            try:
                if pipe.poll(5):
                    if pipe.closed:
                        logger.info(
                            f'Pipe closed for process {rank}, cannot retrieve '
                            'traceback')
                        continue

                    # Not named ``traceback`` to avoid shadowing the module.
                    process_traceback = pipe.recv()
                    logger.error(f'Process {rank} timed out with traceback: '
                                 f'\n\n{process_traceback}')
                else:
                    logger.error('Could not retrieve traceback for timed out '
                                 f'process: {rank}')
            except ConnectionError as e:
                logger.error(
                    'Encountered error while trying to get traceback for '
                    f'process {rank}: {e}')

    def _join_processes(self, fn) -> None:
        """Wait for the subprocesses, then validate their exit codes.

        Polls every 0.1 seconds until one process reports an error, all
        processes have exited, or ``timeout`` seconds have passed. The parent
        ends of all pipes are closed afterwards in every case.

        Args:
            fn: The original (unwrapped) test method; looked up in
                ``skip_return_code_checks`` to pick the kind of check.
        """
        start_time = time.time()
        try:
            while True:
                # Look for a process that exited because of an exception.
                errored = next(
                    ((i, p) for i, p in enumerate(self.processes)
                     if p.exitcode ==
                     MultiProcessTestCase.TEST_ERROR_EXIT_CODE), None)
                if errored is not None:
                    i, p = errored
                    print(f'Process {i} terminated with exit code '
                          f'{p.exitcode}, terminating remaining processes.')
                    for child in active_children():
                        child.terminate()
                    break

                # All processes finished.
                if all(p.exitcode is not None for p in self.processes):
                    break

                # Give up once the timeout is exceeded.
                if time.time() - start_time > self.timeout:
                    self._get_timedout_process_traceback()
                    print(f'Timing out after {self.timeout} seconds and '
                          'killing subprocesses.')
                    for p in self.processes:
                        p.terminate()
                    break

                time.sleep(0.1)

            elapsed_time = time.time() - start_time

            if fn in self.skip_return_code_checks:
                self._check_no_test_errors(elapsed_time)
            else:
                self._check_return_codes(elapsed_time)
        finally:
            # Close all pipes.
            for pipe in self.pid_to_pipe.values():
                pipe.close()

    def _check_no_test_errors(self, elapsed_time) -> None:
        """Check that no subprocess timed out or exited with the error code.

        Args:
            elapsed_time: Seconds spent waiting, used in the error message.

        Raises:
            RuntimeError: If a subprocess has no exit code yet.
        """
        for i, p in enumerate(self.processes):
            if p.exitcode is None:
                raise RuntimeError(
                    f'Process {i} timed out after {elapsed_time} seconds')
            self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)

    def _recv_error_message(self, process) -> str:
        """Return the traceback an errored subprocess sent, if any."""
        pipe = self.pid_to_pipe[process.pid]
        try:
            # The subprocess has already exited, so anything it sent is
            # available now; never block waiting for more.
            if pipe.poll():
                return pipe.recv()
        except (EOFError, OSError):
            # The subprocess exited without sending anything, or the pipe
            # is no longer usable.
            pass
        return '<no traceback was received from the subprocess>'

    def _check_return_codes(self, elapsed_time) -> None:
        """Validate the exit codes of all subprocesses.

        Args:
            elapsed_time: Seconds spent waiting, used in the error message.

        Raises:
            RuntimeError: If a subprocess exited with
                ``TEST_ERROR_EXIT_CODE`` or has no exit code yet.
            unittest.SkipTest: If the first subprocess exited with a code
                listed in ``TEST_SKIPS``, or if any subprocess exited with
                another non-zero code.
        """
        # Raises IndexError when no subprocess was started.
        first_process = self.processes[0]

        # Report real test errors first: they carry a more helpful message
        # than "terminated or timed out".
        errored_processes = [
            (i, p) for i, p in enumerate(self.processes)
            if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE
        ]
        if errored_processes:
            error = ''.join(
                'Process {} exited with error code {} and exception:\n{}\n'
                .format(i, MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
                        self._recv_error_message(process))
                for i, process in errored_processes)
            raise RuntimeError(error)

        # No process reported an error; now look for timeouts.
        for i, p in enumerate(self.processes):
            if p.exitcode is None:
                raise RuntimeError(
                    f'Process {i} terminated or timed out after '
                    f'{elapsed_time} seconds')

        # Every process exited cleanly: the test passed.
        if all(p.exitcode == 0 for p in self.processes):
            return

        for skip in TEST_SKIPS.values():
            if first_process.exitcode == skip.exit_code:
                raise unittest.SkipTest(skip.message)

        # A remaining non-zero exit code may not be caused by the tested
        # function (for example the process may have been killed by a
        # signal), so the test is skipped rather than failed.
        self.skipTest(f'Skip test {self._testMethodName} due to '
                      'the program abort')

    @property
    def is_master(self) -> bool:
        """Whether this process has rank 0."""
        return self.rank == 0
|
|
|