Spaces:
Sleeping
Sleeping
| import random | |
| import time | |
| import socket | |
| import pytest | |
| import multiprocessing as mp | |
| from ditk import logging | |
| from ding.framework import task | |
| from ding.framework.parallel import Parallel | |
| from ding.framework.context import OnlineRLContext | |
| from ding.framework.middleware.barrier import Barrier | |
| PORTS_LIST = ["1235", "1236", "1237"] | |
| class EnvStepMiddleware: | |
| def __call__(self, ctx): | |
| yield | |
| ctx.env_step += 1 | |
| class SleepMiddleware: | |
| def __init__(self, node_id): | |
| self.node_id = node_id | |
| def random_sleep(self, diection, step): | |
| random.seed(self.node_id + step) | |
| sleep_second = random.randint(1, 5) | |
| logging.info("Node:[{}] env_step:[{}]-{} will sleep:{}s".format(self.node_id, step, diection, sleep_second)) | |
| for i in range(sleep_second): | |
| time.sleep(1) | |
| print("Node:[{}] sleepping...".format(self.node_id)) | |
| logging.info("Node:[{}] env_step:[{}]-{} wake up!".format(self.node_id, step, diection)) | |
| def __call__(self, ctx): | |
| self.random_sleep("forward", ctx.env_step) | |
| yield | |
| self.random_sleep("backward", ctx.env_step) | |
| def star_barrier(): | |
| with task.start(ctx=OnlineRLContext()): | |
| node_id = task.router.node_id | |
| if node_id == 0: | |
| attch_from_nums = 3 | |
| else: | |
| attch_from_nums = 0 | |
| barrier = Barrier(attch_from_nums) | |
| task.use(barrier, lock=False) | |
| task.use(SleepMiddleware(node_id), lock=False) | |
| task.use(barrier, lock=False) | |
| task.use(EnvStepMiddleware(), lock=False) | |
| try: | |
| task.run(2) | |
| except Exception as e: | |
| logging.error(e) | |
| assert False | |
| def mesh_barrier(): | |
| with task.start(ctx=OnlineRLContext()): | |
| node_id = task.router.node_id | |
| attch_from_nums = 3 - task.router.node_id | |
| barrier = Barrier(attch_from_nums) | |
| task.use(barrier, lock=False) | |
| task.use(SleepMiddleware(node_id), lock=False) | |
| task.use(barrier, lock=False) | |
| task.use(EnvStepMiddleware(), lock=False) | |
| try: | |
| task.run(2) | |
| except Exception as e: | |
| logging.error(e) | |
| assert False | |
| def unmatch_barrier(): | |
| with task.start(ctx=OnlineRLContext()): | |
| node_id = task.router.node_id | |
| attch_from_nums = 3 - task.router.node_id | |
| task.use(Barrier(attch_from_nums, 5), lock=False) | |
| if node_id != 2: | |
| task.use(Barrier(attch_from_nums, 5), lock=False) | |
| try: | |
| task.run(2) | |
| except TimeoutError as e: | |
| assert node_id != 2 | |
| logging.info("Node:[{}] timeout with barrier".format(node_id)) | |
| else: | |
| time.sleep(5) | |
| assert node_id == 2 | |
| logging.info("Node:[{}] finish barrier".format(node_id)) | |
| def launch_barrier(args): | |
| i, topo, fn, test_id = args | |
| address = socket.gethostbyname(socket.gethostname()) | |
| topology = "alone" | |
| attach_to = [] | |
| port_base = PORTS_LIST[test_id] | |
| port = port_base + str(i) | |
| if topo == 'star': | |
| if i != 0: | |
| attach_to = ['tcp://{}:{}{}'.format(address, port_base, 0)] | |
| elif topo == 'mesh': | |
| for j in range(i): | |
| attach_to.append('tcp://{}:{}{}'.format(address, port_base, j)) | |
| Parallel.runner( | |
| node_ids=i, | |
| ports=int(port), | |
| attach_to=attach_to, | |
| topology=topology, | |
| protocol="tcp", | |
| n_parallel_workers=1, | |
| startup_interval=0 | |
| )(fn) | |
| def test_star_topology_barrier(): | |
| ctx = mp.get_context("spawn") | |
| with ctx.Pool(processes=4) as pool: | |
| pool.map(launch_barrier, [[i, 'star', star_barrier, 0] for i in range(4)]) | |
| pool.close() | |
| pool.join() | |
| def test_mesh_topology_barrier(): | |
| ctx = mp.get_context("spawn") | |
| with ctx.Pool(processes=4) as pool: | |
| pool.map(launch_barrier, [[i, 'mesh', mesh_barrier, 1] for i in range(4)]) | |
| pool.close() | |
| pool.join() | |
| def test_unmatch_barrier(): | |
| ctx = mp.get_context("spawn") | |
| with ctx.Pool(processes=4) as pool: | |
| pool.map(launch_barrier, [[i, 'mesh', unmatch_barrier, 2] for i in range(4)]) | |
| pool.close() | |
| pool.join() | |