| import os |
| import unittest |
| import json |
| from enum import Enum |
| import time |
| from mlagents.trainers.training_status import ( |
| StatusType, |
| StatusMetaData, |
| GlobalTrainingStatus, |
| ) |
| from mlagents.trainers.policy.checkpoint_manager import ( |
| ModelCheckpointManager, |
| ModelCheckpoint, |
| ) |
|
|
|
|
def test_globaltrainingstatus(tmpdir):
    """Round-trip one parameter through save_state/load_state and probe misses.

    Stores ``LESSON_NUM = 3`` under ``"Category1"``, writes the state to a JSON
    file inside ``tmpdir``, inspects the file contents, reloads it, and finally
    checks that lookups for an unknown category or an unknown key give ``None``.
    """
    state_file = os.path.join(tmpdir, "test.json")
    lesson_key = StatusType.LESSON_NUM

    # Write a single value and persist it.
    GlobalTrainingStatus.set_parameter_state("Category1", lesson_key, 3)
    GlobalTrainingStatus.save_state(state_file)

    # Read the raw JSON back, independently of GlobalTrainingStatus.
    with open(state_file) as handle:
        on_disk = json.load(handle)

    # The file is keyed by category, then by the enum's value.
    assert "Category1" in on_disk
    category_payload = on_disk["Category1"]
    assert lesson_key.value in category_payload
    assert category_payload[lesson_key.value] == 3
    assert "metadata" in on_disk

    # Loading the file must give the stored value back through the API.
    GlobalTrainingStatus.load_state(state_file)
    reloaded = GlobalTrainingStatus.get_parameter_state("Category1", lesson_key)
    assert reloaded == 3

    # A category that was never written.
    missing_category = GlobalTrainingStatus.get_parameter_state(
        "Category3", lesson_key
    )

    # A key type that StatusType does not define.
    FakeStatusType = Enum("FakeStatusType", {"NOTAREALKEY": "notarealkey"})
    missing_key = GlobalTrainingStatus.get_parameter_state(
        "Category1", FakeStatusType.NOTAREALKEY
    )

    assert missing_category is None
    assert missing_key is None
|
|
|
|
def test_model_management(tmpdir):
    """Exercise checkpoint tracking through ModelCheckpointManager.

    Seeds three checkpoint records for ``"Mock_brain"``, adds a fourth and a
    fifth checkpoint, registers a final checkpoint, and asserts after each
    step that four checkpoints are reported. It then verifies that both the
    checkpoint list and the final checkpoint are recorded in
    ``GlobalTrainingStatus`` for that brain.

    :param tmpdir: pytest temporary-directory fixture; only used to build paths.
    """
    results_path = os.path.join(tmpdir, "results")
    brain_name = "Mock_brain"
    final_model_path = os.path.join(results_path, brain_name)

    # Three pre-existing checkpoint records (steps 1..3) with rising rewards.
    test_checkpoint_list = [
        {
            "steps": step,
            "file_path": os.path.join(final_model_path, f"{brain_name}-{step}.nn"),
            "reward": reward,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        }
        for step, reward in enumerate((1.312, 1.912, 2.312), start=1)
    ]
    GlobalTrainingStatus.set_parameter_state(
        brain_name, StatusType.CHECKPOINTS, test_checkpoint_list
    )

    # Adding a fourth checkpoint brings the tracked count to four.
    new_checkpoint_4 = ModelCheckpoint(
        4, os.path.join(final_model_path, f"{brain_name}-4.nn"), 2.678, time.time()
    )
    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    # Adding a fifth leaves the count at four.
    # NOTE(review): the third argument (4) looks like a retention limit - confirm.
    new_checkpoint_5 = ModelCheckpoint(
        5, os.path.join(final_model_path, f"{brain_name}-5.nn"), 3.122, time.time()
    )
    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    # Register the final model; the regular checkpoint count must not change.
    final_model_path = f"{final_model_path}.nn"
    final_model_time = time.time()
    current_step = 6
    final_model = ModelCheckpoint(
        current_step, final_model_path, 3.294, final_model_time
    )

    ModelCheckpointManager.track_final_checkpoint(brain_name, final_model)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    check_checkpoints = GlobalTrainingStatus.saved_state[brain_name][
        StatusType.CHECKPOINTS.value
    ]
    assert check_checkpoints is not None

    # Look the final checkpoint up under the brain's category, mirroring the
    # CHECKPOINTS lookup above. Indexing saved_state directly by the key value
    # skips the category level and would not inspect this brain's entry.
    # get_parameter_state gives None for missing entries, so the assertion
    # fails if nothing was recorded.
    # NOTE(review): assumes track_final_checkpoint stores its record under the
    # brain_name category - confirm against ModelCheckpointManager.
    tracked_final = GlobalTrainingStatus.get_parameter_state(
        brain_name, StatusType.FINAL_CHECKPOINT
    )
    assert tracked_final is not None
|
|
|
|
class StatsMetaDataTest(unittest.TestCase):
    """Checks the log output of ``StatusMetaData.check_compatibility``."""

    def test_metadata_compare(self):
        """Compare default metadata with two altered copies; expect two records.

        Records on the ``"mlagents.trainers"`` logger at WARNING level or above
        are captured while a default ``StatusMetaData`` is compared first with
        one whose ``mlagents_version`` is ``"test"`` and then with one whose
        ``torch_version`` is ``"test"``.
        """
        overrides = (
            {"mlagents_version": "test"},
            {"torch_version": "test"},
        )

        with self.assertLogs("mlagents.trainers", level="WARNING") as captured:
            baseline = StatusMetaData()
            for changed_field in overrides:
                baseline.check_compatibility(StatusMetaData(**changed_field))

        # Exactly two records were captured across the two comparisons.
        assert len(captured.output) == 2
|
|