"""Train or replay Stable-Baselines3 agents (PPO / SAC) on ``WarehouseEnv``.

Run as a script to replay a saved policy; edit ``main`` to call
``train_func`` instead to train a new one.
"""
from warehouse_env import *  # provides WarehouseEnv

# Imported explicitly (and after the star import, so these bindings win)
# instead of relying on ``warehouse_env`` to re-export them: ``PPO`` and
# ``torch`` are used below but were previously never imported in this file.
import torch
from stable_baselines3 import PPO, SAC

# Algorithm name -> SB3 class, shared by training and replay.
_ALGORITHMS = {
    'PPO': PPO,
    'SAC': SAC,
}

# Algorithm name -> path the model is saved to by ``train_func`` and loaded
# from by default in ``exec_func``; one table keeps the two in sync.
_DEFAULT_MODEL_NAMES = {
    'PPO': "ppo_warehouse",
    'SAC': "sac_warehouse",
}


def _check_alg_name(alg_name):
    """Raise ``RuntimeError`` if *alg_name* is not a supported algorithm."""
    if alg_name not in _ALGORITHMS:
        raise RuntimeError(
            f'no model for algorithm {alg_name!r} '
            f'(expected one of {sorted(_ALGORITHMS)})'
        )


def train_func(alg_name='PPO'):
    """Train a new agent on a non-rendering ``WarehouseEnv`` and save it.

    The trained model is saved under ``_DEFAULT_MODEL_NAMES[alg_name]``.

    Args:
        alg_name: ``'PPO'`` or ``'SAC'``.

    Raises:
        RuntimeError: if ``alg_name`` is not a supported algorithm.
    """
    # Validate before building the environment so a typo fails fast and
    # does not leave an environment open.
    _check_alg_name(alg_name)

    env = WarehouseEnv(render_mode='')
    try:
        if alg_name == 'PPO':
            policy_kwargs = dict(
                activation_fn=torch.nn.ReLU,
                net_arch=dict(pi=[64, 64], vf=[64, 64]),
            )
            # Left at SB3 defaults; previously tried learning_rate=0.0003
            # and clip_range=0.1.
            model = PPO(
                "MlpPolicy",
                env,
                verbose=1,
                policy_kwargs=policy_kwargs,
                tensorboard_log="./ppo_tensorboard/",
            )
            model.learn(total_timesteps=500000, tb_log_name="WarehouseEnv")
        else:  # 'SAC'
            # Two shared hidden layers. Earlier experiments used
            # net_arch=[512, 512] and net_arch=dict(pi=[256, 256], qf=[400, 300]).
            policy_kwargs = dict(net_arch=[32, 32])
            # NOTE(review): SAC runs also log under ./ppo_tensorboard/ and
            # are told apart only by tb_log_name; kept so existing
            # TensorBoard runs stay side by side.
            model = SAC(
                "MlpPolicy",
                env,
                verbose=1,
                tensorboard_log="./ppo_tensorboard/",
                policy_kwargs=policy_kwargs,
            )
            model.learn(
                total_timesteps=700000,
                log_interval=4,
                tb_log_name="sac_WarehouseEnv",
            )
        model.save(_DEFAULT_MODEL_NAMES[alg_name])
    finally:
        env.close()


def exec_func(alg_name='SAC', model_name=None, deterministic=False,
              n_episodes=None):
    """Load a saved agent and run it in a human-rendered ``WarehouseEnv``.

    Args:
        alg_name: ``'PPO'`` or ``'SAC'``.
        model_name: path of the saved model. Defaults to the path
            ``train_func`` saves to for ``alg_name``.
        deterministic: forwarded to ``model.predict``. The default ``False``
            samples actions from the policy, which matches SB3's default.
        n_episodes: stop after this many finished episodes. ``None`` runs
            until interrupted (e.g. Ctrl-C).

    Raises:
        RuntimeError: if ``alg_name`` is not a supported algorithm.
    """
    _check_alg_name(alg_name)
    if model_name is None:
        model_name = _DEFAULT_MODEL_NAMES[alg_name]

    # Load before opening the render window so a bad path fails without
    # leaving a window behind.
    model = _ALGORITHMS[alg_name].load(model_name)

    env = WarehouseEnv(render_mode='human')
    try:
        episodes_done = 0
        obs, info = env.reset()
        while n_episodes is None or episodes_done < n_episodes:
            action, _ = model.predict(obs, deterministic=deterministic)
            obs, reward, terminated, truncated, info = env.step(action)
            env.render()
            if terminated or truncated:
                episodes_done += 1
                obs, info = env.reset()
    finally:
        # Runs on Ctrl-C too, so the render window is always released.
        env.close()


def main():
    """Replay a saved SAC policy (call ``train_func`` instead to train)."""
    alg_name = 'SAC'  # or 'PPO'
    model_name = 'agent_policies/sac_warehouse_r_10_working_v1.zip'
    # To train a fresh agent instead of replaying: train_func(alg_name)
    exec_func(alg_name=alg_name, model_name=model_name)


if __name__ == '__main__':
    main()