Spaces:
Sleeping
Sleeping
| import pytest | |
| import time | |
| import os | |
| import torch | |
| import subprocess | |
| from copy import deepcopy | |
| from ding.entry import serial_pipeline, serial_pipeline_offline, collect_demo_data, serial_pipeline_onpolicy | |
| from ding.entry.serial_entry_sqil import serial_pipeline_sqil | |
| from dizoo.classic_control.cartpole.config.cartpole_sql_config import cartpole_sql_config, cartpole_sql_create_config | |
| from dizoo.classic_control.cartpole.config.cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config | |
| from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config | |
| from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config | |
| from dizoo.classic_control.cartpole.config.cartpole_pg_config import cartpole_pg_config, cartpole_pg_create_config | |
| from dizoo.classic_control.cartpole.config.cartpole_a2c_config import cartpole_a2c_config, cartpole_a2c_create_config | |
| from dizoo.classic_control.cartpole.config.cartpole_impala_config import cartpole_impala_config, cartpole_impala_create_config # noqa | |
| from dizoo.classic_control.cartpole.config.cartpole_rainbow_config import cartpole_rainbow_config, cartpole_rainbow_create_config # noqa | |
| from dizoo.classic_control.cartpole.config.cartpole_iqn_config import cartpole_iqn_config, cartpole_iqn_create_config # noqa | |
| from dizoo.classic_control.cartpole.config.cartpole_c51_config import cartpole_c51_config, cartpole_c51_create_config # noqa | |
| from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import cartpole_qrdqn_config, cartpole_qrdqn_create_config # noqa | |
| from dizoo.classic_control.cartpole.config.cartpole_sqn_config import cartpole_sqn_config, cartpole_sqn_create_config # noqa | |
| from dizoo.classic_control.cartpole.config.cartpole_ppg_config import cartpole_ppg_config, cartpole_ppg_create_config # noqa | |
| from dizoo.classic_control.cartpole.config.cartpole_acer_config import cartpole_acer_config, cartpole_acer_create_config # noqa | |
| from dizoo.classic_control.cartpole.entry.cartpole_ppg_main import main as ppg_main | |
| from dizoo.classic_control.cartpole.entry.cartpole_ppo_main import main as ppo_main | |
| from dizoo.classic_control.cartpole.config.cartpole_r2d2_config import cartpole_r2d2_config, cartpole_r2d2_create_config # noqa | |
| from dizoo.classic_control.pendulum.config import pendulum_ddpg_config, pendulum_ddpg_create_config | |
| from dizoo.classic_control.pendulum.config import pendulum_td3_config, pendulum_td3_create_config | |
| from dizoo.classic_control.pendulum.config import pendulum_sac_config, pendulum_sac_create_config | |
| from dizoo.bitflip.config import bitflip_her_dqn_config, bitflip_her_dqn_create_config | |
| from dizoo.bitflip.entry.bitflip_dqn_main import main as bitflip_dqn_main | |
| from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config | |
| from dizoo.league_demo.selfplay_demo_ppo_main import main as selfplay_main | |
| from dizoo.league_demo.league_demo_ppo_main import main as league_main | |
| from dizoo.classic_control.pendulum.config.pendulum_sac_data_generation_config import pendulum_sac_data_genearation_config, pendulum_sac_data_genearation_create_config # noqa | |
| from dizoo.classic_control.pendulum.config.pendulum_cql_config import pendulum_cql_config, pendulum_cql_create_config # noqa | |
| from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import cartpole_qrdqn_generation_data_config, cartpole_qrdqn_generation_data_create_config # noqa | |
| from dizoo.classic_control.cartpole.config.cartpole_cql_config import cartpole_discrete_cql_config, cartpole_discrete_cql_create_config # noqa | |
| from dizoo.classic_control.pendulum.config.pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config # noqa | |
| from dizoo.classic_control.pendulum.config.pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config # noqa | |
| from dizoo.petting_zoo.config import ptz_simple_spread_atoc_config, ptz_simple_spread_atoc_create_config # noqa | |
| from dizoo.petting_zoo.config import ptz_simple_spread_atoc_config, ptz_simple_spread_collaq_config, ptz_simple_spread_collaq_create_config # noqa | |
| from dizoo.petting_zoo.config import ptz_simple_spread_coma_config, ptz_simple_spread_coma_create_config # noqa | |
| from dizoo.petting_zoo.config import ptz_simple_spread_qmix_config, ptz_simple_spread_qmix_create_config # noqa | |
| from dizoo.petting_zoo.config import ptz_simple_spread_qtran_config, ptz_simple_spread_qtran_create_config # noqa | |
| from dizoo.petting_zoo.config import ptz_simple_spread_vdn_config, ptz_simple_spread_vdn_create_config # noqa | |
| from dizoo.petting_zoo.config import ptz_simple_spread_wqmix_config, ptz_simple_spread_wqmix_create_config # noqa | |
| from dizoo.classic_control.cartpole.config import cartpole_mdqn_config, cartpole_mdqn_create_config | |
# Start a fresh record file at import time; each test below appends its own entry to it.
with open("./algo_record.log", "w+") as record:
    record.write("ALGO TEST STARTS\n")
def test_dqn():
    """Run ``serial_pipeline`` on deep copies of the cartpole DQN configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "1. dqn" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_dqn_config)
    create_cfg = deepcopy(cartpole_dqn_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("1. dqn\n")
def test_ddpg():
    """Run ``serial_pipeline`` on deep copies of the pendulum DDPG configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "2. ddpg" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(pendulum_ddpg_config)
    create_cfg = deepcopy(pendulum_ddpg_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("2. ddpg\n")
def test_td3():
    """Run ``serial_pipeline`` on deep copies of the pendulum TD3 configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "3. td3" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(pendulum_td3_config)
    create_cfg = deepcopy(pendulum_td3_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("3. td3\n")
def test_a2c():
    """Run ``serial_pipeline_onpolicy`` on deep copies of the cartpole A2C configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "4. a2c" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_a2c_config)
    create_cfg = deepcopy(cartpole_a2c_create_config)
    try:
        serial_pipeline_onpolicy([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("4. a2c\n")
def test_rainbow():
    """Run ``serial_pipeline`` on deep copies of the cartpole Rainbow configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "5. rainbow" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_rainbow_config)
    create_cfg = deepcopy(cartpole_rainbow_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("5. rainbow\n")
def test_ppo():
    """Run the cartpole PPO entry (``ppo_main``) on a deep copy of the PPO config.

    Any ``Exception`` from the entry fails the test with "pipeline fail";
    otherwise "6. ppo" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_ppo_config)
    # The create config is copied as in the other tests, but ppo_main only receives the main config.
    _create_cfg = deepcopy(cartpole_ppo_create_config)
    try:
        ppo_main(main_cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("6. ppo\n")
| # @pytest.mark.algotest | |
def test_collaq():
    """Run ``serial_pipeline`` on deep copies of the petting-zoo simple-spread CollaQ configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "7. collaq" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(ptz_simple_spread_collaq_config)
    create_cfg = deepcopy(ptz_simple_spread_collaq_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("7. collaq\n")
| # @pytest.mark.algotest | |
def test_coma():
    """Run ``serial_pipeline`` on deep copies of the petting-zoo simple-spread COMA configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "8. coma" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(ptz_simple_spread_coma_config)
    create_cfg = deepcopy(ptz_simple_spread_coma_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("8. coma\n")
def test_sac():
    """Run ``serial_pipeline`` on deep copies of the pendulum SAC configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "9. sac" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(pendulum_sac_config)
    create_cfg = deepcopy(pendulum_sac_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("9. sac\n")
def test_c51():
    """Run ``serial_pipeline`` on deep copies of the cartpole C51 configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "10. c51" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_c51_config)
    create_cfg = deepcopy(cartpole_c51_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("10. c51\n")
def test_r2d2():
    """Run ``serial_pipeline`` on deep copies of the cartpole R2D2 configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "11. r2d2" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_r2d2_config)
    create_cfg = deepcopy(cartpole_r2d2_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("11. r2d2\n")
def test_pg():
    """Run ``serial_pipeline_onpolicy`` on deep copies of the cartpole PG configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "12. pg" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_pg_config)
    create_cfg = deepcopy(cartpole_pg_create_config)
    try:
        serial_pipeline_onpolicy([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("12. pg\n")
| # @pytest.mark.algotest | |
def test_atoc():
    """Run ``serial_pipeline`` on deep copies of the petting-zoo simple-spread ATOC configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "13. atoc" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(ptz_simple_spread_atoc_config)
    create_cfg = deepcopy(ptz_simple_spread_atoc_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("13. atoc\n")
| # @pytest.mark.algotest | |
def test_vdn():
    """Run ``serial_pipeline`` on deep copies of the petting-zoo simple-spread VDN configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "14. vdn" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(ptz_simple_spread_vdn_config)
    create_cfg = deepcopy(ptz_simple_spread_vdn_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("14. vdn\n")
| # @pytest.mark.algotest | |
def test_qmix():
    """Run ``serial_pipeline`` on deep copies of the petting-zoo simple-spread QMIX configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "15. qmix" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(ptz_simple_spread_qmix_config)
    create_cfg = deepcopy(ptz_simple_spread_qmix_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("15. qmix\n")
def test_impala():
    """Run ``serial_pipeline`` on deep copies of the cartpole IMPALA configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "16. impala" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_impala_config)
    create_cfg = deepcopy(cartpole_impala_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("16. impala\n")
def test_iqn():
    """Run ``serial_pipeline`` on deep copies of the cartpole IQN configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "17. iqn" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_iqn_config)
    create_cfg = deepcopy(cartpole_iqn_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("17. iqn\n")
def test_her_dqn():
    """Run the bitflip HER-DQN entry (``bitflip_dqn_main``) with a 5-bit setup.

    The config is customised on a deep copy so the module-level
    ``bitflip_her_dqn_config`` is not mutated by this test. Any ``Exception``
    raised while preparing the config or running the entry fails the test with
    "pipeline fail"; otherwise "18. her dqn" is appended to ``./algo_record.log``.
    """
    try:
        # Work on a private copy instead of editing the imported config in place.
        config = deepcopy(bitflip_her_dqn_config)
        config.exp_name = 'bitflip5_dqn'
        config.env.n_bits = 5
        config.policy.model.obs_shape = 10
        config.policy.model.action_shape = 5
        bitflip_dqn_main(config, seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as f:
        f.write("18. her dqn\n")
def test_ppg():
    """Run the cartpole PPG entry (``ppg_main``) on a deep copy of the PPG config.

    The config is deep-copied, as ``test_ppo`` does, so the module-level
    ``cartpole_ppg_config`` is not handed to the entry directly. Any ``Exception``
    fails the test with "pipeline fail"; otherwise "19. ppg" is appended to
    ``./algo_record.log``.
    """
    try:
        ppg_main(deepcopy(cartpole_ppg_config), seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as f:
        f.write("19. ppg\n")
def test_sqn():
    """Run ``serial_pipeline`` on deep copies of the cartpole SQN configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "20. sqn" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_sqn_config)
    create_cfg = deepcopy(cartpole_sqn_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("20. sqn\n")
def test_qrdqn():
    """Run ``serial_pipeline`` on deep copies of the cartpole QRDQN configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "21. qrdqn" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_qrdqn_config)
    create_cfg = deepcopy(cartpole_qrdqn_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("21. qrdqn\n")
def test_acer():
    """Run ``serial_pipeline`` on deep copies of the cartpole ACER configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "22. acer" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_acer_config)
    create_cfg = deepcopy(cartpole_acer_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("22. acer\n")
def test_selfplay():
    """Run the self-play demo entry (``selfplay_main``) on a deep copy of the league demo PPO config.

    Any ``Exception`` (including one raised while copying the config) fails the
    test with "pipeline fail"; otherwise "23. selfplay" is appended to
    ``./algo_record.log``.
    """
    try:
        demo_cfg = deepcopy(league_demo_ppo_config)
        selfplay_main(demo_cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("23. selfplay\n")
def test_league():
    """Run the league demo entry (``league_main``) on a deep copy of the league demo PPO config.

    Any ``Exception`` (including one raised while copying the config) fails the
    test with "pipeline fail"; otherwise "24. league" is appended to
    ``./algo_record.log``.
    """
    try:
        demo_cfg = deepcopy(league_demo_ppo_config)
        league_main(demo_cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("24. league\n")
def test_sqil():
    """Train an SQL expert on cartpole, save its collect-mode weights, then run SQIL with it.

    Steps:
      1. ``serial_pipeline`` is run on copies of the cartpole SQL configs and the
         returned policy's ``collect_mode.state_dict()`` is saved to
         ``./expert_policy.pth``.
      2. ``serial_pipeline_sqil`` is run on copies of the SQIL configs, with
         ``policy.collect.model_path`` pointing at the saved file and copies of
         the SQL configs passed as the expert configs.

    The expert configs given to ``serial_pipeline_sqil`` are deep-copied so the
    module-level SQL config objects are not handed to the pipeline directly, in
    line with every other config used in this test. A failure in step 2 fails
    the test with "pipeline fail"; on success "25. sqil" is appended to
    ``./algo_record.log``.
    """
    expert_policy_state_dict_path = './expert_policy.pth'
    config = [deepcopy(cartpole_sql_config), deepcopy(cartpole_sql_create_config)]
    expert_policy = serial_pipeline(config, seed=0)
    torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
    config = [deepcopy(cartpole_sqil_config), deepcopy(cartpole_sqil_create_config)]
    config[0].policy.collect.model_path = expert_policy_state_dict_path
    expert_config = [deepcopy(cartpole_sql_config), deepcopy(cartpole_sql_create_config)]
    try:
        serial_pipeline_sqil(config, expert_config, seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as f:
        f.write("25. sqil\n")
def test_cql():
    """Three-stage CQL test on pendulum: train a SAC expert, collect its data, train CQL offline.

    Each stage fails the test with "pipeline fail" on any ``Exception``. On
    success "26. cql" is appended to ``./algo_record.log``.
    """
    import torch

    # Stage 1: train the expert under the experiment name 'sac'.
    expert_cfg = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
    expert_cfg[0].exp_name = 'sac'
    try:
        serial_pipeline(expert_cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"

    # Stage 2: load the best expert checkpoint and collect demonstration data with it.
    collect_cfg = [
        deepcopy(pendulum_sac_data_genearation_config),
        deepcopy(pendulum_sac_data_genearation_create_config),
    ]
    n_sample = collect_cfg[0].policy.collect.n_sample
    save_path = collect_cfg[0].policy.collect.save_path
    expert_state = torch.load('./sac/ckpt/ckpt_best.pth.tar', map_location='cpu')
    try:
        collect_demo_data(
            collect_cfg, seed=0, collect_count=n_sample, expert_data_path=save_path, state_dict=expert_state
        )
    except Exception:
        assert False, "pipeline fail"

    # Stage 3: offline CQL training.
    cql_cfg = [deepcopy(pendulum_cql_config), deepcopy(pendulum_cql_create_config)]
    try:
        serial_pipeline_offline(cql_cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"

    with open("./algo_record.log", "a+") as record:
        record.write("26. cql\n")
def test_discrete_cql():
    """Three-stage discrete CQL test on cartpole: train a QRDQN expert, collect data, train CQL offline.

    Each stage fails the test with "pipeline fail" on any ``Exception``. On
    success "27. discrete cql" is appended to ``./algo_record.log``.
    """
    import torch

    # Stage 1: train the expert under the experiment name 'cartpole'.
    expert_cfg = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
    expert_cfg[0].exp_name = 'cartpole'
    try:
        serial_pipeline(expert_cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"

    # Stage 2: load the best expert checkpoint and collect demonstration data with it.
    collect_cfg = [
        deepcopy(cartpole_qrdqn_generation_data_config),
        deepcopy(cartpole_qrdqn_generation_data_create_config),
    ]
    n_collect = collect_cfg[0].policy.collect.collect_count
    expert_state = torch.load('cartpole/ckpt/ckpt_best.pth.tar', map_location='cpu')
    try:
        collect_demo_data(collect_cfg, seed=0, collect_count=n_collect, state_dict=expert_state)
    except Exception:
        assert False, "pipeline fail"

    # Stage 3: offline discrete CQL training.
    cql_cfg = [deepcopy(cartpole_discrete_cql_config), deepcopy(cartpole_discrete_cql_create_config)]
    try:
        serial_pipeline_offline(cql_cfg, seed=0)
    except Exception:
        assert False, "pipeline fail"

    with open("./algo_record.log", "a+") as record:
        record.write("27. discrete cql\n")
| # @pytest.mark.algotest | |
def test_wqmix():
    """Run ``serial_pipeline`` on deep copies of the petting-zoo simple-spread WQMIX configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "28. wqmix" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(ptz_simple_spread_wqmix_config)
    create_cfg = deepcopy(ptz_simple_spread_wqmix_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("28. wqmix\n")
def test_mdqn():
    """Run ``serial_pipeline`` on deep copies of the cartpole MDQN configs.

    Any ``Exception`` from the pipeline fails the test with "pipeline fail";
    otherwise "29. mdqn" is appended to ``./algo_record.log``.
    """
    main_cfg = deepcopy(cartpole_mdqn_config)
    create_cfg = deepcopy(cartpole_mdqn_create_config)
    try:
        serial_pipeline([main_cfg, create_cfg], seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as record:
        record.write("29. mdqn\n")
| # @pytest.mark.algotest | |
def test_td3_bc():
    """Three-stage TD3-BC test on pendulum: train a TD3 expert, collect its data, train TD3-BC offline.

    Each stage fails the test with "pipeline fail" on any ``Exception``. On
    success "30. td3_bc" is appended to ``./algo_record.log`` (index 29 is
    already used by ``test_mdqn``, so this entry takes the next number).
    """
    # train expert
    config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
    config[0].exp_name = 'td3'
    try:
        serial_pipeline(config, seed=0)
    except Exception:
        assert False, "pipeline fail"
    # collect expert data
    import torch
    config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)]
    # The number of samples to collect is taken from the configured replay buffer size.
    collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
    expert_data_path = config[0].policy.collect.save_path
    # The expert checkpoint location comes from the generation config's learner load_path.
    state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu')
    try:
        collect_demo_data(
            config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
        )
    except Exception:
        assert False, "pipeline fail"
    # train td3 bc
    config = [deepcopy(pendulum_td3_bc_config), deepcopy(pendulum_td3_bc_create_config)]
    try:
        serial_pipeline_offline(config, seed=0)
    except Exception:
        assert False, "pipeline fail"
    with open("./algo_record.log", "a+") as f:
        f.write("30. td3_bc\n")
| # @pytest.mark.algotest | |
def test_running_on_orchestrator():
    """End-to-end test: run the cartpole DQN DIJob on a freshly created k8s cluster.

    Creates a cluster via ``K8sLauncher``, installs the orchestrator via
    ``OrchestratorLauncher``, applies the aggregator config and the DIJob
    manifest with ``kubectl``, waits up to 20 minutes for the DIJob to reach the
    'Succeeded' phase, prints the last 20 log lines of the coordinator pod and
    deletes the DIJob.

    The orchestrator and the cluster are torn down in ``finally`` blocks so they
    are removed even when an intermediate step raises (for example a timeout
    while waiting for the DIJob).
    """
    from kubernetes import config, client, dynamic
    from ding.utils import K8sLauncher, OrchestratorLauncher
    cluster_name = 'test-k8s-launcher'
    config_path = os.path.join(os.path.dirname(__file__), 'config', 'k8s-config.yaml')
    # create cluster
    launcher = K8sLauncher(config_path)
    launcher.name = cluster_name
    launcher.create_cluster()
    try:
        # create orchestrator
        olauncher = OrchestratorLauncher('v0.2.0-rc.0', cluster=launcher)
        olauncher.create_orchestrator()
        try:
            # create dijob
            namespace = 'default'
            name = 'cartpole-dqn'
            timeout = 20 * 60
            file_path = os.path.dirname(__file__)
            agconfig_path = os.path.join(file_path, 'config', 'agconfig.yaml')
            dijob_path = os.path.join(file_path, 'config', 'dijob-cartpole.yaml')
            create_object_from_config(agconfig_path, 'di-system')
            create_object_from_config(dijob_path, namespace)
            # watch for dijob to converge
            # load_kube_config() installs the default client configuration and returns None,
            # so it is called once and ApiClient is left to pick up that default.
            config.load_kube_config()
            dyclient = dynamic.DynamicClient(client.ApiClient())
            dijobapi = dyclient.resources.get(api_version='diengine.opendilab.org/v1alpha1', kind='DIJob')
            wait_for_dijob_condition(dijobapi, name, namespace, 'Succeeded', timeout)
            v1 = client.CoreV1Api()
            logs = v1.read_namespaced_pod_log(f'{name}-coordinator', namespace, tail_lines=20)
            print(f'\ncoordinator logs:\n {logs} \n')
            # delete dijob
            dijobapi.delete(name=name, namespace=namespace, body={})
        finally:
            # delete orchestrator
            olauncher.delete_orchestrator()
    finally:
        # delete k8s cluster
        launcher.delete_cluster()
def create_object_from_config(config_path: str, namespace: str = 'default'):
    """Apply a manifest with ``kubectl apply -n <namespace> -f <config_path>``.

    Arguments:
        - config_path (:obj:`str`): Path passed to ``kubectl -f``.
        - namespace (:obj:`str`): Namespace passed to ``kubectl -n``.

    Raises:
        - RuntimeError: If kubectl wrote something to stderr that contains
          neither 'WARN' nor 'already exists'.
    """
    command = ['kubectl', 'apply', '-n', namespace, '-f', config_path]
    process = subprocess.Popen(command, stderr=subprocess.PIPE)
    stderr_bytes = process.communicate()[1]
    message = stderr_bytes.decode('utf-8').strip()
    if message == '':
        return
    # Warnings and "already exists" reports are tolerated.
    if 'WARN' in message or 'already exists' in message:
        return
    raise RuntimeError(f'Failed to create object: {message}')
def delete_object_from_config(config_path: str, namespace: str = 'default'):
    """Delete a manifest's objects with ``kubectl delete -n <namespace> -f <config_path>``.

    Arguments:
        - config_path (:obj:`str`): Path passed to ``kubectl -f``.
        - namespace (:obj:`str`): Namespace passed to ``kubectl -n``.

    Raises:
        - RuntimeError: If kubectl wrote something to stderr that contains
          neither 'WARN' nor 'NotFound'.
    """
    command = ['kubectl', 'delete', '-n', namespace, '-f', config_path]
    process = subprocess.Popen(command, stderr=subprocess.PIPE)
    stderr_bytes = process.communicate()[1]
    message = stderr_bytes.decode('utf-8').strip()
    if message == '':
        return
    # Warnings and "NotFound" reports are tolerated.
    if 'WARN' in message or 'NotFound' in message:
        return
    raise RuntimeError(f'Failed to delete object: {message}')
def wait_for_dijob_condition(dijobapi, name: str, namespace: str, phase: str, timeout: int = 60, interval: int = 1):
    """Poll a DIJob until its ``status.phase`` equals ``phase`` or ``timeout`` seconds elapse.

    Arguments:
        - dijobapi: Object whose ``get(name=..., namespace=...)`` returns the job; the
          returned object must expose ``status`` (possibly ``None``) with a ``phase`` attribute.
        - name (:obj:`str`): Job name passed to ``dijobapi.get``.
        - namespace (:obj:`str`): Namespace passed to ``dijobapi.get``.
        - phase (:obj:`str`): Phase value to wait for.
        - timeout (:obj:`int`): Maximum number of seconds to keep polling.
        - interval (:obj:`int`): Seconds to sleep between polls.

    Returns:
        ``None`` once the job is observed in the requested phase.

    Raises:
        - TimeoutError: If the job is not in the requested phase when polling stops,
          including the case where the job still has no status at all.
    """
    start = time.time()
    dijob = dijobapi.get(name=name, namespace=namespace)
    while (dijob.status is None or dijob.status.phase != phase) and time.time() - start < timeout:
        time.sleep(interval)
        dijob = dijobapi.get(name=name, namespace=namespace)
    # The loop can also end by timeout while status is still None; guard before reading .phase
    # so that case raises TimeoutError rather than AttributeError.
    if dijob.status is not None and dijob.status.phase == phase:
        return
    raise TimeoutError(f'Timeout waiting for DIJob: {name} to be {phase}')