Spaces:
Running
Running
| from os import path | |
| import os | |
| import copy | |
| from easydict import EasyDict | |
| from collections import deque | |
| import pytest | |
| import shutil | |
| import wandb | |
| import h5py | |
| import torch.nn as nn | |
| from unittest.mock import MagicMock | |
| from unittest.mock import Mock, patch | |
| from ding.utils import DistributedWriter | |
| from ding.framework.middleware.tests import MockPolicy, CONFIG | |
| from ding.framework import OnlineRLContext, OfflineRLContext | |
| from ding.framework.middleware.functional import online_logger, offline_logger, wandb_online_logger, \ | |
| wandb_offline_logger | |
| test_folder = "test_exp" | |
| test_path = path.join(os.getcwd(), test_folder) | |
| cfg = EasyDict({"exp_name": "test_exp"}) | |
| def get_online_ctx(): | |
| ctx = OnlineRLContext() | |
| ctx.eval_value = -10000 | |
| ctx.train_iter = 34 | |
| ctx.env_step = 78 | |
| ctx.train_output = {'priority': [107], '[histogram]test_histogram': [1, 2, 3, 4, 5, 6], 'td_error': 15} | |
| return ctx | |
| def online_ctx_output_dict(): | |
| ctx = get_online_ctx() | |
| return ctx | |
| def online_ctx_output_deque(): | |
| ctx = get_online_ctx() | |
| ctx.train_output = deque([ctx.train_output]) | |
| return ctx | |
| def online_ctx_output_list(): | |
| ctx = get_online_ctx() | |
| ctx.train_output = [ctx.train_output] | |
| return ctx | |
| def online_scalar_ctx(): | |
| ctx = get_online_ctx() | |
| ctx.train_output = {'[scalars]': 1} | |
| return ctx | |
| class MockOnlineWriter: | |
| def __init__(self): | |
| self.ctx = get_online_ctx() | |
| def add_scalar(self, tag, scalar_value, global_step): | |
| if tag in ['basic/eval_episode_return_mean-env_step', 'basic/eval_episode_return_mean']: | |
| assert scalar_value == self.ctx.eval_value | |
| assert global_step == self.ctx.env_step | |
| elif tag == 'basic/eval_episode_return_mean-train_iter': | |
| assert scalar_value == self.ctx.eval_value | |
| assert global_step == self.ctx.train_iter | |
| elif tag in ['basic/train_td_error-env_step', 'basic/train_td_error']: | |
| assert scalar_value == self.ctx.train_output['td_error'] | |
| assert global_step == self.ctx.env_step | |
| elif tag == 'basic/train_td_error-train_iter': | |
| assert scalar_value == self.ctx.train_output['td_error'] | |
| assert global_step == self.ctx.train_iter | |
| else: | |
| raise NotImplementedError('tag should be in the tags defined') | |
| def add_histogram(self, tag, values, global_step): | |
| assert tag == 'test_histogram' | |
| assert values == [1, 2, 3, 4, 5, 6] | |
| assert global_step in [self.ctx.train_iter, self.ctx.env_step] | |
| def close(self): | |
| pass | |
| def mock_get_online_instance(): | |
| return MockOnlineWriter() | |
| class TestOnlineLogger: | |
| def test_online_logger_output_dict(self, online_ctx_output_dict): | |
| with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): | |
| online_logger()(online_ctx_output_dict) | |
| def test_online_logger_record_output_dict(self, online_ctx_output_dict): | |
| with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): | |
| online_logger(record_train_iter=True)(online_ctx_output_dict) | |
| def test_online_logger_record_output_deque(self, online_ctx_output_deque): | |
| with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance): | |
| online_logger()(online_ctx_output_deque) | |
| def get_offline_ctx(): | |
| ctx = OfflineRLContext() | |
| ctx.eval_value = -10000000000 | |
| ctx.train_iter = 3333 | |
| ctx.train_output = {'priority': [107], '[histogram]test_histogram': [1, 2, 3, 4, 5, 6], 'td_error': 15} | |
| return ctx | |
| def offline_ctx_output_dict(): | |
| ctx = get_offline_ctx() | |
| return ctx | |
| def offline_scalar_ctx(): | |
| ctx = get_offline_ctx() | |
| ctx.train_output = {'[scalars]': 1} | |
| return ctx | |
| class MockOfflineWriter: | |
| def __init__(self): | |
| self.ctx = get_offline_ctx() | |
| def add_scalar(self, tag, scalar_value, global_step): | |
| assert global_step == self.ctx.train_iter | |
| if tag == 'basic/eval_episode_return_mean-train_iter': | |
| assert scalar_value == self.ctx.eval_value | |
| elif tag == 'basic/train_td_error-train_iter': | |
| assert scalar_value == self.ctx.train_output['td_error'] | |
| else: | |
| raise NotImplementedError('tag should be in the tags defined') | |
| def add_histogram(self, tag, values, global_step): | |
| assert tag == 'test_histogram' | |
| assert values == [1, 2, 3, 4, 5, 6] | |
| assert global_step == self.ctx.train_iter | |
| def close(self): | |
| pass | |
| def mock_get_offline_instance(): | |
| return MockOfflineWriter() | |
| class TestOfflineLogger: | |
| def test_offline_logger_no_scalars(self, offline_ctx_output_dict): | |
| with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance): | |
| offline_logger()(offline_ctx_output_dict) | |
| def test_offline_logger_scalars(self, offline_scalar_ctx): | |
| with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance): | |
| with pytest.raises(NotImplementedError) as exc_info: | |
| offline_logger()(offline_scalar_ctx) | |
| class TheModelClass(nn.Module): | |
| def state_dict(self): | |
| return 'fake_state_dict' | |
| class TheEnvClass(Mock): | |
| def enable_save_replay(self, replay_path): | |
| return | |
| class TheObsDataClass(Mock): | |
| def __getitem__(self, index): | |
| return [[1, 1, 1]] * 50 | |
| class The1DDataClass(Mock): | |
| def __getitem__(self, index): | |
| return [[1]] * 50 | |
| def test_wandb_online_logger(): | |
| record_path = './video_qbert_dqn' | |
| cfg = EasyDict( | |
| dict( | |
| gradient_logger=True, | |
| plot_logger=True, | |
| action_logger=True, | |
| return_logger=True, | |
| video_logger=True, | |
| ) | |
| ) | |
| env = TheEnvClass() | |
| ctx = OnlineRLContext() | |
| ctx.train_output = [{'reward': 1, 'q_value': [1.0]}] | |
| model = TheModelClass() | |
| wandb.init(config=cfg, anonymous="must") | |
| def mock_metric_logger(data, step): | |
| metric_list = [ | |
| "q_value", | |
| "target q_value", | |
| "loss", | |
| "lr", | |
| "entropy", | |
| "reward", | |
| "q value", | |
| "video", | |
| "q value distribution", | |
| "train iter", | |
| "episode return mean", | |
| "env step", | |
| "action", | |
| "actions_of_trajectory_0", | |
| "actions_of_trajectory_1", | |
| "actions_of_trajectory_2", | |
| "actions_of_trajectory_3", | |
| "return distribution", | |
| ] | |
| assert set(data.keys()) <= set(metric_list) | |
| def mock_gradient_logger(input_model, log, log_freq, log_graph): | |
| assert input_model == model | |
| def test_wandb_online_logger_metric(): | |
| with patch.object(wandb, 'log', new=mock_metric_logger): | |
| wandb_online_logger(record_path, cfg, env=env, model=model, anonymous=True)(ctx) | |
| def test_wandb_online_logger_gradient(): | |
| with patch.object(wandb, 'watch', new=mock_gradient_logger): | |
| wandb_online_logger(record_path, cfg, env=env, model=model, anonymous=True)(ctx) | |
| test_wandb_online_logger_metric() | |
| test_wandb_online_logger_gradient() | |
| def test_wandb_offline_logger(): | |
| record_path = './video_pendulum_cql' | |
| cfg = EasyDict(dict(gradient_logger=True, plot_logger=True, action_logger=True, vis_dataset=True)) | |
| env = TheEnvClass() | |
| ctx = OfflineRLContext() | |
| ctx.train_output = [{'reward': 1, 'q_value': [1.0]}] | |
| model = TheModelClass() | |
| wandb.init(config=cfg, anonymous="must") | |
| exp_config = EasyDict(dict(dataset_path='dataset.h5')) | |
| def mock_metric_logger(data, step=None): | |
| metric_list = [ | |
| "q_value", "target q_value", "loss", "lr", "entropy", "reward", "q value", "video", "q value distribution", | |
| "train iter", 'dataset' | |
| ] | |
| assert set(data.keys()) < set(metric_list) | |
| def mock_gradient_logger(input_model, log, log_freq, log_graph): | |
| assert input_model == model | |
| def mock_image_logger(imagepath): | |
| assert os.path.splitext(imagepath)[-1] == '.png' | |
| def test_wandb_offline_logger_gradient(): | |
| cfg.vis_dataset = False | |
| print(cfg) | |
| with patch.object(wandb, 'watch', new=mock_gradient_logger): | |
| wandb_offline_logger( | |
| record_path=record_path, cfg=cfg, exp_config=exp_config, env=env, model=model, anonymous=True | |
| )(ctx) | |
| def test_wandb_offline_logger_dataset(): | |
| cfg.vis_dataset = True | |
| m = MagicMock() | |
| m.__enter__.return_value = {'obs': TheObsDataClass(), 'action': The1DDataClass(), 'reward': The1DDataClass()} | |
| with patch.object(wandb, 'log', new=mock_metric_logger): | |
| with patch.object(wandb, 'Image', new=mock_image_logger): | |
| with patch('h5py.File', return_value=m): | |
| wandb_offline_logger( | |
| record_path=record_path, cfg=cfg, exp_config=exp_config, env=env, model=model, anonymous=True | |
| )(ctx) | |
| test_wandb_offline_logger_gradient() | |
| test_wandb_offline_logger_dataset() | |