# Evaluation entry point for a PPO agent on MetaDrive (DI-engine / dizoo).
from functools import partial

import torch
from easydict import EasyDict
from tensorboardX import SummaryWriter

from ding.config import compile_config
from ding.envs import BaseEnvManager, SyncSubprocessEnvManager
from ding.model.template import VAC
from ding.policy import PPOPolicy
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
# Checkpoint to evaluate; ``None`` initializes the policy from scratch.
model_dir = None

metadrive_basic_config = dict(
    exp_name='metadrive_onppo_eval_seed0',
    env=dict(
        metadrive=dict(
            use_render=True,
            traffic_density=0.10,  # Fraction of the road occupied by traffic vehicles, in [0, 1]
            map='XSOS',  # Map layout; int or string shorthand for map_config
            horizon=4000,  # Maximum number of steps per episode
            driving_reward=1.0,  # Reward for moving forward along the route
            speed_reward=0.10,  # Reward for driving at high speed
            use_lateral_reward=False,  # Whether to add a lane-keeping reward term
            out_of_road_penalty=40.0,  # Penalty for leaving the drivable area
            crash_vehicle_penalty=40.0,  # Penalty for colliding with another vehicle
            decision_repeat=20,  # Simulation steps per agent decision (reciprocal of decision frequency)
            out_of_route_done=True,  # Terminate the episode when the agent drives off the route
            show_bird_view=False,  # Evaluation only: draw the five-channel bird-view image
        ),
        manager=dict(
            shared_memory=False,
            max_retry=2,
            context='spawn',
        ),
        n_evaluator_episode=16,
        stop_value=255,
        collector_env_num=1,
        evaluator_env_num=1,
    ),
    policy=dict(
        cuda=True,
        action_space='continuous',
        model=dict(
            obs_shape=[5, 84, 84],
            action_shape=2,
            action_space='continuous',
            bound_type='tanh',
            encoder_hidden_size_list=[128, 128, 64],
        ),
        learn=dict(
            epoch_per_collect=10,
            batch_size=64,
            learning_rate=3e-4,
            entropy_weight=0.001,
            value_weight=0.5,
            clip_ratio=0.02,
            adv_norm=False,
            value_norm=True,
            grad_clip_value=10,
        ),
        collect=dict(n_sample=1000, ),
        eval=dict(evaluator=dict(eval_freq=1000, ), ),
    ),
)
main_config = EasyDict(metadrive_basic_config)
def wrapped_env(env_cfg, wrapper_cfg=None):
    """Build a MetaDrive PPO environment and wrap it with DriveEnvWrapper."""
    base_env = MetaDrivePPOOriginEnv(env_cfg)
    return DriveEnvWrapper(base_env, wrapper_cfg)
def main(cfg):
    """Run a single evaluation pass of a PPO policy on MetaDrive.

    Builds the evaluator environment manager, instantiates the VAC model and
    PPO policy (optionally loading weights from ``model_dir``), evaluates for
    ``n_evaluator_episode`` episodes, and tears everything down.
    """
    cfg = compile_config(cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator)
    evaluator_env_num = cfg.env.evaluator_env_num
    show_bird_view = cfg.env.metadrive.show_bird_view
    wrapper_cfg = {'show_bird_view': show_bird_view}
    evaluator_env = BaseEnvManager(
        env_fn=[partial(wrapped_env, cfg.env.metadrive, wrapper_cfg) for _ in range(evaluator_env_num)],
        cfg=cfg.env.manager,
    )
    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    if model_dir is not None:
        # Load on CPU first; the policy moves parameters to its configured device.
        policy._load_state_dict_collect(torch.load(model_dir, map_location='cpu'))
    tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    try:
        stop, rate = evaluator.eval()
    finally:
        # Fix: guarantee cleanup even if eval() raises, and close the
        # SummaryWriter so buffered tensorboard events are flushed to disk
        # (the original leaked the writer).
        evaluator.close()
        tb_logger.close()
# Script entry point: evaluate with the module-level configuration.
if __name__ == '__main__':
    main(main_config)