|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Tests that `notebook_launcher` works in several common notebook setups: rendezvous backends, fault tolerance, child-process monitoring, and problematic imports.
|
|
""" |
|
|
|
|
|
import os |
|
|
import time |
|
|
from multiprocessing import Queue |
|
|
|
|
|
from pytest import mark, raises |
|
|
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError |
|
|
|
|
|
from accelerate import PartialState, notebook_launcher |
|
|
from accelerate.test_utils import require_bnb |
|
|
from accelerate.utils import is_bnb_available |
|
|
|
|
|
|
|
|
def basic_function():
    """Print the ``PartialState`` seen by the current process."""
    state = PartialState()
    print(f"PartialState:\n{state}")
|
|
|
|
|
|
|
|
def tough_nut_function(queue: Queue):
    """Fail repeatedly until a shared countdown taken from ``queue`` reaches zero.

    Takes one counter value from ``queue``. While the value is positive, the
    decremented value is put back and a ``RuntimeError`` is raised. When the
    value is zero, the current ``PartialState`` is printed instead. If nothing
    can be taken from the queue, the function returns without doing anything.

    Args:
        queue: Queue shared with the launching process, seeded with the number
            of failures to simulate.

    Raises:
        RuntimeError: While the counter taken from ``queue`` is still positive.
    """
    # Imported locally: the ``queue`` parameter shadows the stdlib module name.
    from queue import Empty

    # One non-blocking ``get`` replaces the original ``empty()`` + blocking ``get()``
    # pair. With several workers sharing this queue, a sibling could drain it
    # between those two calls and leave this worker blocked forever.
    try:
        trial = queue.get_nowait()
    except Empty:
        return

    if trial > 0:
        # Hand the decremented counter to the next attempt before failing.
        queue.put(trial - 1)
        raise RuntimeError("The nut hasn't cracked yet! Try again.")

    print(f"PartialState:\n{PartialState()}")
|
|
|
|
|
|
|
|
def bipolar_sleep_function(sleep_sec: int):
    """Make even-indexed processes fail immediately while odd-indexed ones sleep.

    Args:
        sleep_sec: Number of seconds odd-indexed processes sleep.

    Raises:
        RuntimeError: On every process whose ``process_index`` is even.
    """
    process_index = PartialState().process_index
    is_even = process_index % 2 == 0
    if is_even:
        raise RuntimeError("I'm an even process. I don't like to sleep.")
    time.sleep(sleep_sec)
|
|
|
|
|
|
|
|
# Number of processes passed to every ``notebook_launcher`` call in this file.
# Read from the ACCELERATE_NUM_PROCESSES environment variable; defaults to 1.
NUM_PROCESSES = int(os.getenv("ACCELERATE_NUM_PROCESSES", "1"))
|
|
|
|
|
|
|
|
def test_can_initialize():
    """Launch ``basic_function`` with ``NUM_PROCESSES`` processes and default settings."""
    no_args = ()
    notebook_launcher(basic_function, no_args, num_processes=NUM_PROCESSES)
|
|
|
|
|
|
|
|
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test static rendezvous backends")
def test_static_rdzv_backend():
    """Launch ``basic_function`` using the ``static`` rendezvous backend."""
    notebook_launcher(
        basic_function,
        (),
        num_processes=NUM_PROCESSES,
        rdzv_backend="static",
    )
|
|
|
|
|
|
|
|
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test c10d rendezvous backends")
def test_c10d_rdzv_backend():
    """Launch ``basic_function`` using the ``c10d`` rendezvous backend."""
    notebook_launcher(
        basic_function,
        (),
        num_processes=NUM_PROCESSES,
        rdzv_backend="c10d",
    )
|
|
|
|
|
|
|
|
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test fault tolerance")
def test_fault_tolerant(max_restarts: int = 3):
    """Launch ``tough_nut_function`` with a failure budget equal to ``max_restarts``.

    A shared queue is seeded with ``max_restarts``. ``tough_nut_function``
    decrements it and raises while it is positive. The same value is forwarded
    to ``notebook_launcher`` as ``max_restarts``.

    Args:
        max_restarts: Number of simulated failures, and the restart limit given
            to the launcher.
    """
    failures_left = Queue()
    failures_left.put(max_restarts)
    notebook_launcher(
        tough_nut_function,
        (failures_left,),
        num_processes=NUM_PROCESSES,
        max_restarts=max_restarts,
    )
|
|
|
|
|
|
|
|
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test monitoring")
def test_monitoring(monitor_interval: float = 0.01, sleep_sec: int = 100):
    """Check that one failing worker ends the launch without waiting for the sleepers.

    ``bipolar_sleep_function`` makes even-indexed processes raise immediately,
    while odd-indexed ones sleep for ``sleep_sec`` seconds. The launch must
    raise ``ChildFailedError`` carrying the even-process message. It must also
    finish in less than ``sleep_sec`` seconds.

    Args:
        monitor_interval: Forwarded to ``notebook_launcher`` as
            ``monitor_interval`` (presumably how often children are polled).
        sleep_sec: How long odd-indexed workers sleep. Also the upper bound on
            the allowed elapsed time.
    """
    # ``time.monotonic`` cannot jump with wall-clock adjustments (NTP, manual
    # changes), which could otherwise skew a ``time.time`` difference and flake
    # the timing assertion below.
    start_time = time.monotonic()
    with raises(ChildFailedError, match="I'm an even process. I don't like to sleep."):
        notebook_launcher(
            bipolar_sleep_function,
            (sleep_sec,),
            num_processes=NUM_PROCESSES,
            monitor_interval=monitor_interval,
        )
    elapsed = time.monotonic() - start_time
    assert elapsed < sleep_sec, "Monitoring did not stop the process in time."
|
|
|
|
|
|
|
|
@require_bnb
def test_problematic_imports():
    """Expect a RuntimeError when ``bitsandbytes`` is imported before launching.

    The error message must match "Please keep these imports". Both the import
    and the launch happen inside the ``raises`` context.
    """
    with raises(RuntimeError, match="Please keep these imports"):
        import bitsandbytes  # noqa: F401  (imported only for its side effects)

        notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES)
|
|
|
|
|
|
|
|
def main():
    """Run every check in this file in order, for direct (non-pytest) execution.

    Each check is preceded by a banner line. The ``bitsandbytes`` check runs
    only when ``is_bnb_available()`` is true. The process group is destroyed
    at the end when more than one process was used.

    NOTE(review): calling the tests directly bypasses their pytest ``skipif``
    marks, so the multi-process checks also run when NUM_PROCESSES < 2.
    Confirm this is intended.
    """
    checks = (
        ("Test basic notebook can be ran", test_can_initialize),
        ("Test static rendezvous backend", test_static_rdzv_backend),
        ("Test c10d rendezvous backend", test_c10d_rdzv_backend),
        ("Test fault tolerant", test_fault_tolerant),
        ("Test monitoring", test_monitoring),
    )
    for banner, check in checks:
        print(banner)
        check()

    if is_bnb_available():
        print("Test problematic imports (bnb)")
        test_problematic_imports()

    if NUM_PROCESSES > 1:
        PartialState().destroy_process_group()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly (outside pytest) to execute all checks via main().
    main()
|
|
|