Spaces:
Sleeping
Sleeping
| import pytest | |
| import torch | |
| import copy | |
| from unittest.mock import patch | |
| from ding.framework import OnlineRLContext, task | |
| from ding.framework.middleware import TransitionList, inferencer, rolloutor | |
| from ding.framework.middleware import StepCollector, EpisodeCollector | |
| from ding.framework.middleware.tests import MockPolicy, MockEnv, CONFIG | |
def test_inferencer():
    """Run the ``inferencer`` middleware once and check ``ctx.inference_output``.

    The middleware is built from ``MockPolicy`` / ``MockEnv`` and invoked on a
    fresh ``OnlineRLContext``; the result is expected to be a dict holding one
    action entry for each of the keys 0 and 1.
    """
    ctx = OnlineRLContext()
    with patch("ding.policy.Policy", MockPolicy):
        with patch("ding.envs.BaseEnvManagerV2", MockEnv):
            mock_policy, mock_env = MockPolicy(), MockEnv()
            run_inference = inferencer(0, mock_policy, mock_env)
            run_inference(ctx)

    output = ctx.inference_output
    assert isinstance(output, dict)

    # Expected action value per key:
    #   0 -> sum of zeros([2, 2])
    #   1 -> sum of ones([2, 2])
    expected_action = ((0, 0.), (1, 4.))
    for env_id, value in expected_action:
        assert output[env_id] == {'action': torch.Tensor([value])}
def test_rolloutor():
    """Alternate ``inferencer`` and ``rolloutor`` ten times and check the counters.

    Both middlewares are rebuilt on every iteration, sharing a single
    ``TransitionList(2)`` and a single context; afterwards the context's
    ``env_episode`` and ``env_step`` counters must both equal 20.
    """
    ctx = OnlineRLContext()
    transition_buffer = TransitionList(2)
    rounds = 10
    with patch("ding.policy.Policy", MockPolicy):
        with patch("ding.envs.BaseEnvManagerV2", MockEnv):
            mock_policy, mock_env = MockPolicy(), MockEnv()
            for _round in range(rounds):
                # Construct each middleware right before it is called, once per round.
                infer_step = inferencer(0, mock_policy, mock_env)
                infer_step(ctx)
                rollout_step = rolloutor(mock_policy, mock_env, transition_buffer)
                rollout_step(ctx)

    # 10 * env_num
    assert ctx.env_episode == 20
    assert ctx.env_step == 20
def test_step_collector():
    """Check ``StepCollector`` output with and without ``random_collect_size``.

    Both scenarios share one deep-copied ``CONFIG`` and one context, and each
    must leave 16 trajectories on the context with end indices ``[7, 15]``.
    """
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    scenarios = (
        {},  # test no random_collect_size
        {"random_collect_size": 8},  # test with random_collect_size
    )
    for extra_kwargs in scenarios:
        with patch("ding.policy.Policy", MockPolicy):
            with patch("ding.envs.BaseEnvManagerV2", MockEnv):
                with task.start():
                    mock_policy, mock_env = MockPolicy(), MockEnv()
                    step_collector = StepCollector(cfg, mock_policy, mock_env, **extra_kwargs)
                    step_collector(ctx)

        assert len(ctx.trajectories) == 16
        assert ctx.trajectory_end_idx == [7, 15]
def test_episode_collector():
    """Check ``EpisodeCollector`` output with and without ``random_collect_size``.

    Both scenarios share one deep-copied ``CONFIG`` and one context, and each
    must leave 16 episodes on the context.
    """
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    scenarios = (
        {},  # test no random_collect_size
        {"random_collect_size": 8},  # test with random_collect_size
    )
    for extra_kwargs in scenarios:
        with patch("ding.policy.Policy", MockPolicy):
            with patch("ding.envs.BaseEnvManagerV2", MockEnv):
                with task.start():
                    mock_policy, mock_env = MockPolicy(), MockEnv()
                    episode_collector = EpisodeCollector(cfg, mock_policy, mock_env, **extra_kwargs)
                    episode_collector(ctx)

        assert len(ctx.episodes) == 16