Spaces:
Running
Running
| import shutil | |
| from time import sleep | |
| import pytest | |
| import numpy as np | |
| import tempfile | |
| import torch | |
| from ding.data.model_loader import FileModelLoader | |
| from ding.data.storage_loader import FileStorageLoader | |
| from ding.framework import task | |
| from ding.framework.context import OnlineRLContext | |
| from ding.framework.middleware.distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger | |
| from ding.framework.parallel import Parallel | |
| from ding.utils.default_helper import set_pkg_seed | |
| from os import path | |
def context_exchanger_main():
    """Per-worker body of the plain ``ContextExchanger`` test.

    Node 0 takes the learner role and node 1 the collector role. Every worker
    installs ``ContextExchanger(skip_n_iter=1)`` followed by a role-specific
    middleware, then runs the task for three steps.
    """
    with task.start(ctx=OnlineRLContext()):
        node_id = task.router.node_id
        if node_id == 0:
            task.add_role(task.role.LEARNER)
        elif node_id == 1:
            task.add_role(task.role.COLLECTOR)

        exchanger = ContextExchanger(skip_n_iter=1)
        task.use(exchanger)

        if task.has_role(task.role.LEARNER):

            def learner_context(ctx: OnlineRLContext):
                # The expected sizes mirror what collector_context writes after its yield.
                expected_sizes = (("trajectories", 2), ("trajectory_end_idx", 4), ("episodes", 8))
                for field, size in expected_sizes:
                    assert len(getattr(ctx, field)) == size
                assert ctx.env_step > 0
                assert ctx.env_episode > 0
                yield
                ctx.train_iter += 1

            task.use(learner_context)
        elif task.has_role(task.role.COLLECTOR):

            def collector_context(ctx: OnlineRLContext):
                # From the second step on, the learner's train_iter must be visible here.
                if ctx.total_step > 0:
                    assert ctx.train_iter > 0
                yield
                ctx.trajectories = [np.random.rand(10, 10) for _ in range(2)]
                ctx.trajectory_end_idx = [1] * 4
                ctx.episodes = [np.random.rand(10, 10) for _ in range(8)]
                ctx.env_step += 1
                ctx.env_episode += 1

            task.use(collector_context)

        task.run(max_step=3)
def test_context_exchanger():
    """Run ``context_exchanger_main`` on two parallel workers."""
    launch = Parallel.runner(n_parallel_workers=2)
    launch(context_exchanger_main)
def context_exchanger_with_storage_loader_main():
    """Per-worker body of the ``ContextExchanger`` test that uses a ``FileStorageLoader``.

    Node 0 takes the learner role and node 1 the collector role. Every worker
    builds a ``FileStorageLoader`` on the same directory under the system temp
    dir, hands it to ``ContextExchanger(skip_n_iter=1, ...)``, installs its
    role-specific middleware and runs the task for three steps. The loader is
    shut down and the directory removed on the way out, even if the run fails.
    """
    with task.start(ctx=OnlineRLContext()):
        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
        elif task.router.node_id == 1:
            task.add_role(task.role.COLLECTOR)

        tempdir = path.join(tempfile.gettempdir(), "test_storage_loader")
        storage_loader = FileStorageLoader(dirname=tempdir)
        try:
            task.use(ContextExchanger(skip_n_iter=1, storage_loader=storage_loader))

            if task.has_role(task.role.LEARNER):

                def learner_context(ctx: OnlineRLContext):
                    # The expected sizes mirror what collector_context writes after its yield.
                    assert len(ctx.trajectories) == 2
                    assert len(ctx.trajectory_end_idx) == 4
                    assert len(ctx.episodes) == 8
                    assert ctx.env_step > 0
                    assert ctx.env_episode > 0
                    yield
                    ctx.train_iter += 1

                task.use(learner_context)
            elif task.has_role(task.role.COLLECTOR):

                def collector_context(ctx: OnlineRLContext):
                    # From the second step on, the learner's train_iter must be visible here.
                    if ctx.total_step > 0:
                        assert ctx.train_iter > 0
                    yield
                    ctx.trajectories = [np.random.rand(10, 10) for _ in range(2)]
                    ctx.trajectory_end_idx = [1 for _ in range(4)]
                    ctx.episodes = [np.random.rand(10, 10) for _ in range(8)]
                    ctx.env_step += 1
                    ctx.env_episode += 1

                task.use(collector_context)

            task.run(max_step=3)
        finally:
            storage_loader.shutdown()
            # NOTE(review): presumably gives the peer worker time to finish with the
            # shared directory before it is deleted -- confirm.
            sleep(1)
            # Every worker uses the same `tempdir` and runs this cleanup, so a
            # check-then-delete (`path.exists` + `rmtree`) can race with the peer
            # removing the tree first. Best-effort removal tolerates that.
            shutil.rmtree(tempdir, ignore_errors=True)
def test_context_exchanger_with_storage_loader():
    """Run ``context_exchanger_with_storage_loader_main`` on two parallel workers."""
    launch = Parallel.runner(n_parallel_workers=2)
    launch(context_exchanger_with_storage_loader_main)
class MockPolicy:
    """Minimal policy stand-in: a small MLP with one-step training and inference.

    The network maps a 10-element input to a 10-element output. The loss and
    the optimizer are created once in ``__init__`` so that repeated ``train``
    calls share one Adam state instead of restarting it on every call.
    """

    def __init__(self) -> None:
        self._model = self._get_model(10, 10)
        self._loss_fn = torch.nn.MSELoss(reduction="mean")
        # Built once: re-creating Adam per step would discard its moment estimates.
        self._optimizer = torch.optim.Adam(self._model.parameters(), lr=0.01)

    def _get_model(self, X_shape, y_shape) -> torch.nn.Module:
        """Build a three-layer ReLU MLP from ``X_shape`` inputs to ``y_shape`` outputs."""
        return torch.nn.Sequential(
            torch.nn.Linear(X_shape, 24), torch.nn.ReLU(), torch.nn.Linear(24, 24), torch.nn.ReLU(),
            torch.nn.Linear(24, y_shape)
        )

    def train(self, X, y):
        """Run a single Adam step minimizing the mean squared error between ``model(X)`` and ``y``."""
        y_pred = self._model(X)
        loss = self._loss_fn(y_pred, y)
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()

    def predict(self, X):
        """Return ``model(X)`` computed with gradient tracking disabled."""
        with torch.no_grad():
            return self._model(X)
def model_exchanger_main():
    """Per-worker body of the ``ModelExchanger`` test.

    Node 0 is the learner and trains the mock policy each step; every other
    node is a collector that checks, from its second step on, that its
    prediction differs from the one taken before the run started.
    """
    with task.start(ctx=OnlineRLContext()):
        set_pkg_seed(0, use_cuda=False)
        policy = MockPolicy()
        inputs, targets = torch.rand(10), torch.rand(10)

        on_learner_node = task.router.node_id == 0
        task.add_role(task.role.LEARNER if on_learner_node else task.role.COLLECTOR)

        task.use(ModelExchanger(policy._model))

        if not task.has_role(task.role.LEARNER):
            # Baseline prediction taken before any step has run.
            y_pred1 = policy.predict(inputs)

            def pred(ctx):
                if ctx.total_step > 0:
                    y_pred2 = policy.predict(inputs)
                    # Ensure model is upgraded
                    assert any(y_pred1 != y_pred2)
                sleep(0.3)

            task.use(pred)
        else:

            def train(ctx):
                policy.train(inputs, targets)
                sleep(0.3)

            task.use(train)

        task.run(2)
def test_model_exchanger():
    """Run ``model_exchanger_main`` on two parallel workers with no startup interval."""
    launch = Parallel.runner(n_parallel_workers=2, startup_interval=0)
    launch(model_exchanger_main)
def model_exchanger_main_with_model_loader():
    """Per-worker body of the ``ModelExchanger`` test that uses a ``FileModelLoader``.

    Node 0 is the learner and trains the mock policy each step; every other
    node is a collector that checks, from its second step on, that its
    prediction differs from the one taken before the run started. Every worker
    builds a ``FileModelLoader`` on the same directory under the system temp
    dir; the loader is shut down and the directory removed on the way out,
    even if installing the middleware or running the task fails.
    """
    with task.start(ctx=OnlineRLContext()):
        set_pkg_seed(0, use_cuda=False)
        policy = MockPolicy()
        X = torch.rand(10)
        y = torch.rand(10)

        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
        else:
            task.add_role(task.role.COLLECTOR)

        tempdir = path.join(tempfile.gettempdir(), "test_model_loader")
        model_loader = FileModelLoader(policy._model, dirname=tempdir)
        try:
            # Inside the try so the loader is still shut down if this call raises.
            task.use(ModelExchanger(policy._model, model_loader=model_loader))

            if task.has_role(task.role.LEARNER):

                def train(ctx):
                    policy.train(X, y)
                    sleep(0.3)

                task.use(train)
            else:
                # Baseline prediction taken before any step has run.
                y_pred1 = policy.predict(X)

                def pred(ctx):
                    if ctx.total_step > 0:
                        y_pred2 = policy.predict(X)
                        # Ensure model is upgraded
                        assert any(y_pred1 != y_pred2)
                    sleep(0.3)

                task.use(pred)

            task.run(2)
        finally:
            model_loader.shutdown()
            # NOTE(review): presumably gives the peer worker time to finish with the
            # shared directory before it is deleted -- confirm.
            sleep(0.3)
            # Every worker uses the same `tempdir` and runs this cleanup, so a
            # check-then-delete (`path.exists` + `rmtree`) can race with the peer
            # removing the tree first. Best-effort removal tolerates that.
            shutil.rmtree(tempdir, ignore_errors=True)
def test_model_exchanger_with_model_loader():
    """Run ``model_exchanger_main_with_model_loader`` on two parallel workers with no startup interval."""
    launch = Parallel.runner(n_parallel_workers=2, startup_interval=0)
    launch(model_exchanger_main_with_model_loader)
def periodical_model_exchanger_main():
    """Per-worker body of the ``PeriodicalModelExchanger`` test.

    Node 0 is the learner: it installs the exchanger in ``"send"`` mode with
    ``period=3`` and trains the mock policy each step. Every other node is a
    collector: it installs the exchanger in ``"receive"`` mode with
    ``period=1, stale_toleration=3`` and tracks a staleness counter against
    its predictions. The task runs for eight steps.
    """
    with task.start(ctx=OnlineRLContext()):
        set_pkg_seed(0, use_cuda=False)
        policy = MockPolicy()
        inputs, targets = torch.rand(10), torch.rand(10)

        if task.router.node_id == 0:
            role = task.role.LEARNER
            exchanger_kwargs = dict(mode="send", period=3)
        else:
            role = task.role.COLLECTOR
            exchanger_kwargs = dict(mode="receive", period=1, stale_toleration=3)
        task.add_role(role)
        task.use(PeriodicalModelExchanger(policy._model, **exchanger_kwargs))

        if not task.has_role(task.role.LEARNER):
            # Baseline prediction taken before any step has run.
            y_pred1 = policy.predict(inputs)
            print("y_pred1: ", y_pred1)
            stale = 1

            def pred(ctx):
                nonlocal stale
                y_pred2 = policy.predict(inputs)
                print("y_pred2: ", y_pred2)
                stale += 1
                assert stale <= 3 or all(y_pred1 == y_pred2)
                # NOTE(review): y_pred1 is never reassigned, so once the model has changed
                # this comparison stays true and the counter resets on every later step --
                # confirm that is the intended staleness check.
                if any(y_pred1 != y_pred2):
                    stale = 1
                sleep(0.3)

            task.use(pred)
        else:

            def train(ctx):
                policy.train(inputs, targets)
                sleep(0.3)

            task.use(train)

        task.run(8)
def test_periodical_model_exchanger():
    """Run ``periodical_model_exchanger_main`` on two parallel workers with no startup interval."""
    launch = Parallel.runner(n_parallel_workers=2, startup_interval=0)
    launch(periodical_model_exchanger_main)