Spaces:
Running
Running
| import gym | |
| import torch | |
| from easydict import EasyDict | |
| from ding.config import compile_config | |
| from ding.envs import DingEnvWrapper | |
| from ding.model import DQN | |
| from ding.policy import DQNPolicy, single_env_forward_wrapper | |
| from dizoo.cliffwalking.config.cliffwalking_dqn_config import create_config, main_config | |
| from dizoo.cliffwalking.envs.cliffwalking_env import CliffWalkingEnv | |
def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str):
    """Roll out one episode of a checkpointed DQN policy on CliffWalking.

    The function compiles the config, builds the environment with replay
    saving enabled, restores the model weights from ``ckpt_path`` on CPU,
    and steps the environment with the policy's eval-mode forward until the
    environment reports ``done``. The accumulated reward is printed.

    Args:
        main_config: Main experiment config. Its ``exp_name`` field is
            overwritten in place with ``'cliffwalking_dqn_seed0_deploy'``.
        create_config: Creation config forwarded to ``compile_config``.
        ckpt_path: Path to a checkpoint readable by ``torch.load`` that
            contains the model weights under the ``'model'`` key.
    """
    main_config.exp_name = 'cliffwalking_dqn_seed0_deploy'
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)

    env = CliffWalkingEnv(cfg.env)
    env.enable_save_replay(replay_path=f'./{main_config.exp_name}/video')

    # Restore the trained weights on CPU so no GPU is needed for deployment.
    model = DQN(**cfg.policy.model)
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])

    policy = DQNPolicy(cfg.policy, model=model).eval_mode
    forward_fn = single_env_forward_wrapper(policy.forward)

    obs = env.reset()
    returns = 0.
    try:
        while True:
            action = forward_fn(obs)
            obs, rew, done, info = env.step(action)
            returns += rew
            if done:
                break
    finally:
        # Always release the environment, even if the rollout raises.
        env.close()
    print(f'Deploy is finished, final episode return is: {returns}')
if __name__ == "__main__":
    # Checkpoint to deploy: ckpt_best under the cliffwalking_dqn_seed0
    # experiment directory, relative to the current working directory.
    best_ckpt = './cliffwalking_dqn_seed0/ckpt/ckpt_best.pth.tar'
    main(main_config=main_config, create_config=create_config, ckpt_path=best_ckpt)