Spaces:
Running
Running
| import multiprocessing as mp | |
| import pytest | |
| from threading import Lock | |
| from time import sleep, time | |
| import random | |
| import dataclasses | |
| from ding.framework import task, Context, Parallel | |
| class TestContext(Context): | |
| pipeline: list = dataclasses.field(default_factory=list) | |
| def test_serial_pipeline(): | |
| def step0(ctx): | |
| ctx.pipeline.append(0) | |
| def step1(ctx): | |
| ctx.pipeline.append(1) | |
| # Execute step1, step2 twice | |
| with task.start(ctx=TestContext()): | |
| for _ in range(2): | |
| task.forward(step0) | |
| task.forward(step1) | |
| assert task.ctx.pipeline == [0, 1, 0, 1] | |
| # Renew and execute step1, step2 | |
| task.renew() | |
| assert task.ctx.total_step == 1 | |
| task.forward(step0) | |
| task.forward(step1) | |
| assert task.ctx.pipeline == [0, 1] | |
| # Test context inheritance | |
| task.renew() | |
| def test_serial_yield_pipeline(): | |
| def step0(ctx): | |
| ctx.pipeline.append(0) | |
| yield | |
| ctx.pipeline.append(0) | |
| def step1(ctx): | |
| ctx.pipeline.append(1) | |
| with task.start(ctx=TestContext()): | |
| task.forward(step0) | |
| task.forward(step1) | |
| task.backward() | |
| assert task.ctx.pipeline == [0, 1, 0] | |
| assert len(task._backward_stack) == 0 | |
| def test_async_pipeline(): | |
| def step0(ctx): | |
| ctx.pipeline.append(0) | |
| def step1(ctx): | |
| ctx.pipeline.append(1) | |
| # Execute step1, step2 twice | |
| with task.start(async_mode=True, ctx=TestContext()): | |
| for _ in range(2): | |
| task.forward(step0) | |
| sleep(0.1) | |
| task.forward(step1) | |
| sleep(0.1) | |
| task.backward() | |
| assert task.ctx.pipeline == [0, 1, 0, 1] | |
| task.renew() | |
| assert task.ctx.total_step == 1 | |
| def test_async_yield_pipeline(): | |
| def step0(ctx): | |
| sleep(0.1) | |
| ctx.pipeline.append(0) | |
| yield | |
| ctx.pipeline.append(0) | |
| def step1(ctx): | |
| sleep(0.2) | |
| ctx.pipeline.append(1) | |
| with task.start(async_mode=True, ctx=TestContext()): | |
| task.forward(step0) | |
| task.forward(step1) | |
| sleep(0.3) | |
| task.backward().sync() | |
| assert task.ctx.pipeline == [0, 1, 0] | |
| assert len(task._backward_stack) == 0 | |
| def parallel_main(): | |
| sync_count = 0 | |
| def on_count(): | |
| nonlocal sync_count | |
| sync_count += 1 | |
| def counter(task): | |
| def _counter(ctx): | |
| sleep(0.2 + random.random() / 10) | |
| task.emit("count", only_remote=True) | |
| return _counter | |
| with task.start(): | |
| task.on("count", on_count) | |
| task.use(counter(task)) | |
| task.run(max_step=10) | |
| assert sync_count > 0 | |
| def test_parallel_pipeline(): | |
| Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main) | |
| def test_emit(): | |
| with task.start(): | |
| greets = [] | |
| task.on("Greeting", lambda msg: greets.append(msg)) | |
| def step1(ctx): | |
| task.emit("Greeting", "Hi") | |
| task.use(step1) | |
| task.run(max_step=10) | |
| sleep(0.1) | |
| assert len(greets) == 10 | |
| def emit_remote_main(): | |
| with task.start(): | |
| greets = [] | |
| if task.router.node_id == 0: | |
| task.on("Greeting", lambda msg: greets.append(msg)) | |
| for _ in range(20): | |
| if greets: | |
| break | |
| sleep(0.1) | |
| assert len(greets) > 0 | |
| else: | |
| for _ in range(20): | |
| task.emit("Greeting", "Hi", only_remote=True) | |
| sleep(0.1) | |
| assert len(greets) == 0 | |
| def test_emit_remote(): | |
| Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(emit_remote_main) | |
| def test_wait_for(): | |
| # Wait for will only work in async or parallel mode | |
| with task.start(async_mode=True, n_async_workers=2): | |
| greets = [] | |
| def step1(_): | |
| hi = task.wait_for("Greeting")[0][0] | |
| if hi: | |
| greets.append(hi) | |
| def step2(_): | |
| task.emit("Greeting", "Hi") | |
| task.use(step1) | |
| task.use(step2) | |
| task.run(max_step=10) | |
| assert len(greets) == 10 | |
| assert all(map(lambda hi: hi == "Hi", greets)) | |
| # Test timeout exception | |
| with task.start(async_mode=True, n_async_workers=2): | |
| def step1(_): | |
| task.wait_for("Greeting", timeout=0.3, ignore_timeout_exception=False) | |
| task.use(step1) | |
| with pytest.raises(TimeoutError): | |
| task.run(max_step=1) | |
| def test_async_exception(): | |
| with task.start(async_mode=True, n_async_workers=2): | |
| def step1(_): | |
| task.wait_for("any_event") # Never end | |
| def step2(_): | |
| sleep(0.3) | |
| raise Exception("Oh") | |
| task.use(step1) | |
| task.use(step2) | |
| with pytest.raises(Exception): | |
| task.run(max_step=2) | |
| assert task.ctx.total_step == 0 | |
| def early_stop_main(): | |
| with task.start(): | |
| task.use(lambda _: sleep(0.5)) | |
| if task.match_labels("node.0"): | |
| task.run(max_step=10) | |
| else: | |
| task.run(max_step=2) | |
| assert task.ctx.total_step < 7 | |
| def test_early_stop(): | |
| Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(early_stop_main) | |
| def test_parallel_in_sequencial(): | |
| result = [] | |
| def fast(_): | |
| result.append("fast") | |
| def slow(_): | |
| sleep(0.1) | |
| result.append("slow") | |
| with task.start(): | |
| task.use(lambda _: result.append("begin")) | |
| task.use(task.parallel(slow, fast)) | |
| task.run(max_step=1) | |
| assert result == ["begin", "fast", "slow"] | |
| def test_serial_in_parallel(): | |
| result = [] | |
| def fast(_): | |
| result.append("fast") | |
| def slow(_): | |
| sleep(0.1) | |
| result.append("slow") | |
| with task.start(async_mode=True): | |
| task.use(lambda _: result.append("begin")) | |
| task.use(task.serial(slow, fast)) | |
| task.run(max_step=1) | |
| assert result == ["begin", "slow", "fast"] | |
| def test_nested_middleware(): | |
| """ | |
| When there is a yield in the middleware, | |
| calling this middleware in another will lead to an unexpected result. | |
| Use task.forward or task.wrap can fix this problem. | |
| """ | |
| result = [] | |
| def child(): | |
| def _child(ctx: Context): | |
| result.append(3 * ctx.total_step) | |
| yield | |
| result.append(2 + 3 * ctx.total_step) | |
| return _child | |
| def mother(): | |
| _child = task.wrap(child()) | |
| def _mother(ctx: Context): | |
| child_back = _child(ctx) | |
| result.append(1 + 3 * ctx.total_step) | |
| child_back() | |
| return _mother | |
| with task.start(): | |
| task.use(mother()) | |
| task.run(2) | |
| assert result == [0, 1, 2, 3, 4, 5] | |
| def test_use_lock(): | |
| def slow(ctx): | |
| sleep(0.1) | |
| ctx.result = "slow" | |
| def fast(ctx): | |
| ctx.result = "fast" | |
| with task.start(async_mode=True): | |
| # The lock will turn async middleware into serial | |
| task.use(slow, lock=True) | |
| task.use(fast, lock=True) | |
| task.run(1) | |
| assert task.ctx.result == "fast" | |
| # With custom lock, it will not affect the inner lock of task | |
| lock = Lock() | |
| def slowest(ctx): | |
| sleep(0.3) | |
| ctx.result = "slowest" | |
| with task.start(async_mode=True): | |
| task.use(slow, lock=lock) | |
| # If it receives other locks, it will not be the last one to finish execution | |
| task.use(slowest, lock=True) | |
| task.use(fast, lock=lock) | |
| task.run(1) | |
| assert task.ctx.result == "slowest" | |
| def broadcast_finish_main(): | |
| with task.start(): | |
| def tick(ctx: Context): | |
| if task.router.node_id == 1 and ctx.total_step == 1: | |
| task.finish = True | |
| sleep(1) | |
| task.use(tick) | |
| task.run(20) | |
| def broadcast_main_target(): | |
| Parallel.runner( | |
| n_parallel_workers=1, protocol="tcp", address="127.0.0.1", topology="mesh", ports=50555, startup_interval=0.1 | |
| )(broadcast_finish_main) | |
| def broadcast_secondary_target(): | |
| "Start two standalone processes and connect to the main process." | |
| Parallel.runner( | |
| n_parallel_workers=2, | |
| protocol="tcp", | |
| address="127.0.0.1", | |
| topology="alone", | |
| ports=50556, | |
| attach_to=["tcp://127.0.0.1:50555"], | |
| node_ids=[1, 2], | |
| startup_interval=0.1 | |
| )(broadcast_finish_main) | |
| # gitlab ci and local test pass, github always fail | |
| def test_broadcast_finish(): | |
| start = time() | |
| ctx = mp.get_context("spawn") | |
| main_process = ctx.Process(target=broadcast_main_target) | |
| secondary_process = ctx.Process(target=broadcast_secondary_target) | |
| main_process.start() | |
| secondary_process.start() | |
| main_process.join() | |
| secondary_process.join() | |
| assert (time() - start) < 10 | |