Spaces:
Running
Running
| import pytest | |
| from easydict import EasyDict | |
| from ding.framework import OnlineRLContext | |
| from ding.framework.middleware.ckpt_handler import CkptSaver | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import os | |
| import shutil | |
| from unittest.mock import Mock, patch | |
| from ding.framework import task | |
| from ding.policy.base_policy import Policy | |
class TheModelClass(nn.Module):
    """Minimal ``nn.Module`` double used as the policy's learn-mode model.

    It holds no parameters; ``state_dict`` is overridden to hand back a fixed
    placeholder string, so no real tensors are involved when a checkpoint is built.
    """

    def state_dict(self):
        """Return the placeholder ``'fake_state_dict'`` instead of real weights."""
        placeholder = 'fake_state_dict'
        return placeholder
class MockPolicy(Mock):
    """``unittest.mock.Mock``-based stand-in for a ding ``Policy``.

    ``learn_mode`` is the supplied model object itself, so
    ``policy.learn_mode.state_dict()`` returns whatever the model's
    ``state_dict`` returns. ``eval_mode`` yields an object whose
    ``state_dict()`` returns an empty dict.
    """

    def __init__(self, model, **kwargs) -> None:
        # ``model`` is forwarded as the Mock ``spec`` (Mock's first positional
        # parameter); extra keyword arguments are accepted so the mock can be
        # constructed like a policy, but they are ignored.
        super().__init__(spec=model)
        self.learn_mode = model

    # NOTE(review): ding's ``Policy`` exposes ``eval_mode`` as a property, so this is
    # a property too; otherwise ``policy.eval_mode.state_dict()`` would raise
    # AttributeError on the bound method.
    @property
    def eval_mode(self):
        """Return an object exposing ``state_dict()`` that yields ``{}``."""
        return EasyDict({"state_dict": lambda: {}})
# NOTE(review): ``pytest`` is imported but unused in this view; a test marker
# (e.g. ``@pytest.mark.unittest``) may have been dropped here — confirm.
def test_ckpt_saver():
    """Check which checkpoint path ``CkptSaver`` writes in each of three phases.

    ``ding.framework.middleware.ckpt_handler.save_file`` is replaced by a recorder
    in every phase, and each phase asserts the exact list of paths written, so a
    phase in which no save happens fails instead of passing silently:

    1. ``train_iter=1``, ``eval_value=9.4`` -> expects only ``eval.pth.tar``.
    2. ``train_iter=100``, ``eval_value=1`` -> expects only ``iteration_100.pth.tar``.
    3. ``task.finish = True`` -> expects only ``final.pth.tar``.

    The experiment directory is removed and ``task.finish`` is reset even when an
    assertion fails.
    """
    exp_name = 'test_ckpt_saver_exp'
    ctx = OnlineRLContext()
    train_freq = 100
    model = TheModelClass()
    prefix = '{}/ckpt'.format(exp_name)
    save_file_target = "ding.framework.middleware.ckpt_handler.save_file"

    def make_recorder(saved_paths):
        # Build a ``save_file`` stand-in that appends each target path to
        # ``saved_paths`` instead of writing anything to disk.
        def mock_save_file(path, data, fs_type=None, use_lock=False):
            saved_paths.append(path)

        return mock_save_file

    os.makedirs(exp_name, exist_ok=True)
    try:
        with patch("ding.policy.Policy", MockPolicy), task.start():
            policy = MockPolicy(model)

            # Phase 1: first evaluation value seen.
            saved = []
            with patch(save_file_target, make_recorder(saved)):
                ctx.train_iter = 1
                ctx.eval_value = 9.4
                ckpt_saver = CkptSaver(policy, exp_name, train_freq)
                ckpt_saver(ctx)
            assert saved == ["{}/eval.pth.tar".format(prefix)]

            # Phase 2: train_iter reaches train_freq with a lower eval value.
            saved = []
            with patch(save_file_target, make_recorder(saved)):
                ctx.train_iter = 100
                ctx.eval_value = 1
                ckpt_saver(ctx)
            assert saved == ["{}/iteration_{}.pth.tar".format(prefix, ctx.train_iter)]

            # Phase 3: the task is marked finished.
            saved = []
            with patch(save_file_target, make_recorder(saved)):
                task.finish = True
                try:
                    ckpt_saver(ctx)
                finally:
                    # ``task`` is a shared module-level object; don't leak
                    # finish=True into tests that run afterwards.
                    task.finish = False
            assert saved == ["{}/final.pth.tar".format(prefix)]
    finally:
        shutil.rmtree(exp_name, ignore_errors=True)