import copy
import os
from collections import namedtuple
from typing import List

import pytest
import torch
from easydict import EasyDict
from tensorboardX import SummaryWriter

from ding.model.wrapper.model_wrappers import ArgmaxSampleWrapper, EpsGreedySampleWrapper, TargetNetworkWrapper
from ding.policy.cql import CQLPolicy, DiscreteCQLPolicy
from ding.utils import deep_merge_dicts
from ding.utils.data import offline_data_save_type
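
# The four configs below exercise both CQL flavors: the continuous-action
# CQLPolicy (with and without automatic alpha tuning / log-space alpha)
# and DiscreteCQLPolicy (with and without automatic alpha tuning).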
obs_space = 5
action_space = 3

cfg1 = EasyDict(CQLPolicy.default_config())
cfg1.model.obs_shape = obs_space
cfg1.model.action_shape = action_space

cfg2 = copy.deepcopy(cfg1)
cfg2.learn.auto_alpha = False
cfg2.learn.log_space = False

cfg3 = EasyDict(DiscreteCQLPolicy.default_config())
cfg3.model = {}
cfg3.model.obs_shape = obs_space
cfg3.model.action_shape = action_space

cfg4 = copy.deepcopy(cfg3)
cfg4.learn.auto_alpha = False


def get_batch(size=8):
    # Dummy collect/eval input: a dict mapping env_id to a zero observation.
    data = {}
    for i in range(size):
        obs = torch.zeros(obs_space)
        data[i] = obs
    return data


def get_transition(size=20):
    # Dummy offline dataset for the continuous-action policy: a list of
    # single-step transitions with zero obs/actions and unit reward.
    data = []
    for i in range(size):
        sample = {}
        sample['obs'] = torch.zeros(obs_space)
        sample['action'] = torch.zeros(action_space)
        sample['done'] = False
        sample['next_obs'] = torch.zeros(obs_space)
        sample['reward'] = torch.Tensor([1.])
        data.append(sample)
    return data


def get_transition_batch(bs=1):
    # Dummy batched input with only the fields needed by the helpers under test.
    sample = {}
    sample['obs'] = torch.zeros(bs, obs_space)
    sample['action'] = torch.zeros(bs, action_space)
    return sample


@pytest.mark.parametrize('cfg', [cfg1, cfg2])
@pytest.mark.unittest
def test_cql_continuous(cfg):
    policy = CQLPolicy(cfg, enable_field=['collect', 'eval', 'learn'])
    assert type(policy._target_model) == TargetNetworkWrapper
    q_value = policy._get_q_value(get_transition_batch(cfg.learn.num_actions))
    assert q_value[0].shape[-1] == 1 and q_value[0].shape[-2] == cfg.learn.num_actions
    act, log_prob = policy._get_policy_actions(get_transition_batch(cfg.learn.num_actions))
    assert list(act.shape) == [cfg.learn.num_actions * 10, action_space]
    sample = get_transition(size=20)
    out = policy._forward_learn(sample)


def get_transition_discrete(size=20):
    # Same as get_transition, but with integer actions for the discrete policy.
    data = []
    for i in range(size):
        sample = {}
        sample['obs'] = torch.zeros(obs_space)
        sample['action'] = torch.tensor(i % action_space)
        sample['done'] = False
        sample['next_obs'] = torch.zeros(obs_space)
        sample['reward'] = torch.Tensor([1.])
        data.append(sample)
    return data


@pytest.mark.parametrize('cfg', [cfg3, cfg4])
@pytest.mark.unittest
def test_cql_discrete(cfg):
    policy = DiscreteCQLPolicy(cfg, enable_field=['collect', 'eval', 'learn'])
    assert type(policy._learn_model) == ArgmaxSampleWrapper
    assert type(policy._target_model) == TargetNetworkWrapper
    assert type(policy._collect_model) == EpsGreedySampleWrapper
    sample = get_transition_batch(bs=20)
    samples = policy._get_train_sample(sample)
    assert len(samples['obs']) == 20
    state = policy._state_dict_learn()
    policy._load_state_dict_learn(state)
    sample = get_transition_discrete(size=1)
    out = policy._forward_learn(sample)
    out = policy._forward_collect(get_batch(size=8), eps=0.1)
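

# A minimal manual entry point; this is an added convenience, not part of the
# original test file. pytest discovers and runs these tests on its own.
if __name__ == "__main__":
    pytest.main(['-sv', os.path.abspath(__file__)])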