| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Test file to ensure that in general certain situational setups for notebooks work. |
| """ |
|
|
| import os |
| import time |
|
|
| from pytest import mark, raises |
| from torch.distributed.elastic.multiprocessing.errors import ChildFailedError |
|
|
| from accelerate import PartialState, notebook_launcher |
| from accelerate.test_utils import require_bnb |
| from accelerate.utils import is_bnb_available, is_xpu_available |
|
|
|
|
def basic_function():
    """Print the current ``PartialState`` under a ``PartialState:`` heading."""
    state = PartialState()
    print(f"PartialState:\n{state}")
|
|
|
|
def tough_nut_function(queue):
    """Fail until the counter stored in ``queue`` has been counted down to zero.

    Args:
        queue: A queue holding a single integer counter.

    Raises:
        RuntimeError: If the counter read from ``queue`` is still positive. The
            counter is put back, decremented by one, before raising.
    """
    if not queue.empty():
        attempts_left = queue.get()
        if attempts_left > 0:
            # Hand the decremented counter back so a later call gets one step closer.
            queue.put(attempts_left - 1)
            raise RuntimeError("The nut hasn't cracked yet! Try again.")

        # Counter exhausted: behave like the basic function and report the state.
        state = PartialState()
        print(f"PartialState:\n{state}")
|
|
|
|
def bipolar_sleep_function(sleep_sec: int):
    """Sleep on odd-indexed processes and fail immediately on even-indexed ones.

    Args:
        sleep_sec: Number of seconds an odd-indexed process sleeps.

    Raises:
        RuntimeError: If ``PartialState().process_index`` is even.
    """
    is_even_process = PartialState().process_index % 2 == 0
    if not is_even_process:
        time.sleep(sleep_sec)
        return
    raise RuntimeError("I'm an even process. I don't like to sleep.")
|
|
|
|
# Number of processes handed to `notebook_launcher`; read from the
# ACCELERATE_NUM_PROCESSES environment variable, defaulting to 1 when unset.
NUM_PROCESSES = int(os.getenv("ACCELERATE_NUM_PROCESSES", "1"))
|
|
|
|
def test_can_initialize():
    """Launch ``basic_function`` via ``notebook_launcher`` with ``NUM_PROCESSES`` processes."""
    no_args = ()
    notebook_launcher(basic_function, no_args, num_processes=NUM_PROCESSES)
|
|
|
|
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test static rendezvous backends")
def test_static_rdzv_backend():
    """Launch ``basic_function`` with ``rdzv_backend="static"``; skipped below 2 processes."""
    no_args = ()
    notebook_launcher(
        basic_function,
        no_args,
        num_processes=NUM_PROCESSES,
        rdzv_backend="static",
    )
|
|
|
|
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test c10d rendezvous backends")
def test_c10d_rdzv_backend():
    """Launch ``basic_function`` with ``rdzv_backend="c10d"``; skipped below 2 processes."""
    no_args = ()
    notebook_launcher(
        basic_function,
        no_args,
        num_processes=NUM_PROCESSES,
        rdzv_backend="c10d",
    )
|
|
|
|
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test fault tolerance")
def test_fault_tolerant(max_restarts: int = 3):
    """Launch ``tough_nut_function`` with a queue seeded with ``max_restarts``.

    Args:
        max_restarts: Initial counter placed on the queue; the same value is
            forwarded to ``notebook_launcher`` as ``max_restarts``.
    """
    # Imported locally so the queue is created from torch's multiprocessing module.
    import torch.multiprocessing as mp

    # Pick the start method: "spawn" when an XPU is available, "fork" otherwise.
    start_method = "spawn" if is_xpu_available() else "fork"
    queue = mp.get_context(start_method).Queue()
    queue.put(max_restarts)

    notebook_launcher(
        tough_nut_function,
        (queue,),
        num_processes=NUM_PROCESSES,
        max_restarts=max_restarts,
    )
|
|
|
|
@mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test monitoring")
def test_monitoring(monitor_interval: float = 0.01, sleep_sec: int = 100):
    """Check that a failing process ends the launch before the sleepers finish.

    Args:
        monitor_interval: Value forwarded to ``notebook_launcher`` as ``monitor_interval``.
        sleep_sec: Seconds the odd-indexed processes sleep; also the upper bound
            on how long the whole launch may take.
    """
    launch_kwargs = {
        "num_processes": NUM_PROCESSES,
        "monitor_interval": monitor_interval,
    }
    started_at = time.time()
    with raises(ChildFailedError, match="I'm an even process. I don't like to sleep."):
        notebook_launcher(bipolar_sleep_function, (sleep_sec,), **launch_kwargs)

    # The launch must have been torn down before any sleeper could finish.
    elapsed = time.time() - started_at
    assert elapsed < sleep_sec, "Monitoring did not stop the process in time."
|
|
|
|
@require_bnb
def test_problematic_imports():
    """Expect a ``RuntimeError`` matching "Please keep these imports" when launching after importing bitsandbytes."""
    no_args = ()
    with raises(RuntimeError, match="Please keep these imports"):
        # The import is deliberately done right before the launch, inside the
        # `raises` block; the module itself is not used afterwards.
        import bitsandbytes as bnb  # noqa: F401

        notebook_launcher(basic_function, no_args, num_processes=NUM_PROCESSES)
|
|
|
|
def main():
    """Run every test in this file in order, announcing each one first.

    The bitsandbytes test runs only when ``is_bnb_available()`` is true, and the
    process group is destroyed at the end when more than one process is in use.
    """
    always_run = (
        ("Test basic notebook can be ran", test_can_initialize),
        ("Test static rendezvous backend", test_static_rdzv_backend),
        ("Test c10d rendezvous backend", test_c10d_rdzv_backend),
        ("Test fault tolerant", test_fault_tolerant),
        ("Test monitoring", test_monitoring),
    )
    for banner, run_test in always_run:
        print(banner)
        run_test()

    if is_bnb_available():
        print("Test problematic imports (bnb)")
        test_problematic_imports()

    if NUM_PROCESSES > 1:
        PartialState().destroy_process_group()


if __name__ == "__main__":
    main()
|
|