Spaces:
Build error
Build error
| from warehouse_env import * | |
| from stable_baselines3 import SAC | |
def train_func(alg_name='PPO', *, total_timesteps=None):
    """Train a PPO or SAC agent on ``WarehouseEnv`` and save it to disk.

    Args:
        alg_name: ``'PPO'`` or ``'SAC'``. Selects the algorithm, its policy
            network layout, its default training budget and the save path
            (``"ppo_warehouse"`` / ``"sac_warehouse"``).
        total_timesteps: Optional override for the number of environment
            steps to train for. ``None`` keeps the original defaults
            (500000 for PPO, 700000 for SAC).

    Raises:
        RuntimeError: If ``alg_name`` is neither ``'PPO'`` nor ``'SAC'``.
    """
    # Validate first so a typo fails immediately, before an environment is
    # constructed (and left unclosed).
    if alg_name not in ('PPO', 'SAC'):
        raise RuntimeError(
            f"no model: unknown alg_name {alg_name!r} (expected 'PPO' or 'SAC')"
        )

    # The module-level imports only bring in SAC explicitly; PPO (and torch,
    # below) would otherwise have to leak in through
    # ``from warehouse_env import *``. Import them here so this function does
    # not depend on what that star import happens to re-export.
    from stable_baselines3 import PPO, SAC

    env = WarehouseEnv(render_mode='')
    try:
        if alg_name == 'PPO':
            import torch

            # Separate 64x64 ReLU MLPs for the policy (pi) and value
            # function (vf).
            policy_kwargs = dict(activation_fn=torch.nn.ReLU,
                                 net_arch=dict(pi=[64, 64], vf=[64, 64]))
            # Previously commented-out overrides: learning_rate=0.0003,
            # clip_range=0.1. SB3 defaults are used as-is.
            model = PPO("MlpPolicy", env,
                        verbose=1,
                        policy_kwargs=policy_kwargs,
                        tensorboard_log="./ppo_tensorboard/")
            steps = 500000 if total_timesteps is None else total_timesteps
            model.learn(total_timesteps=steps, tb_log_name="WarehouseEnv")
            model.save("ppo_warehouse")
        else:  # alg_name == 'SAC'
            # A plain list gives the actor and the critics the same [32, 32]
            # hidden-layer layout (not shared layers). Previously
            # commented-out alternatives: net_arch=[512, 512] and
            # net_arch=dict(pi=[256, 256], qf=[400, 300]).
            policy_kwargs = dict(net_arch=[32, 32])
            # NOTE: SAC runs are also logged under ./ppo_tensorboard/,
            # distinguished by the "sac_WarehouseEnv" run name.
            model = SAC("MlpPolicy", env, verbose=1,
                        tensorboard_log="./ppo_tensorboard/",
                        policy_kwargs=policy_kwargs)
            steps = 700000 if total_timesteps is None else total_timesteps
            model.learn(total_timesteps=steps, log_interval=4,
                        tb_log_name="sac_WarehouseEnv")
            model.save("sac_warehouse")
    finally:
        # Release the environment even if training is interrupted.
        env.close()
def exec_func(alg_name='SAC', model_name=None, *, deterministic=False,
              max_episodes=None):
    """Load a saved PPO/SAC policy and roll it out in a rendered WarehouseEnv.

    Args:
        alg_name: ``'PPO'`` or ``'SAC'``; picks the SB3 class used to load
            the model.
        model_name: Path of the saved model. Defaults to ``"ppo_warehouse"``
            or ``"sac_warehouse"`` (the names ``train_func`` saves to).
        deterministic: Forwarded to ``model.predict``. The default ``False``
            keeps the original behaviour (actions are sampled). Pass ``True``
            to use the policy's deterministic action.
        max_episodes: Stop after this many episodes have terminated or been
            truncated. ``None`` (default) runs forever, as before.

    Raises:
        RuntimeError: If ``alg_name`` is neither ``'PPO'`` nor ``'SAC'``.
    """
    # Resolve the algorithm and load the model *before* opening the rendered
    # environment, so a bad name or a missing model file does not leave a
    # window/env behind.
    if alg_name == 'PPO':
        from stable_baselines3 import PPO as algorithm_cls
        default_name = "ppo_warehouse"
    elif alg_name == 'SAC':
        from stable_baselines3 import SAC as algorithm_cls
        default_name = "sac_warehouse"
    else:
        raise RuntimeError(
            f"no model: unknown alg_name {alg_name!r} (expected 'PPO' or 'SAC')"
        )
    model = algorithm_cls.load(default_name if model_name is None else model_name)

    env = WarehouseEnv(render_mode='human')
    try:
        obs, _info = env.reset()
        episodes_done = 0
        while max_episodes is None or episodes_done < max_episodes:
            action, _state = model.predict(obs, deterministic=deterministic)
            obs, _reward, terminated, truncated, _info = env.step(action)
            env.render()
            if terminated or truncated:
                episodes_done += 1
                # Only start a new episode if another one will be played.
                if max_episodes is None or episodes_done < max_episodes:
                    obs, _info = env.reset()
    finally:
        # Close the env on normal exit, on exceptions, and on Ctrl-C.
        env.close()
def main():
    """Script entry point: replay a saved SAC warehouse policy with rendering.

    To switch to PPO, set ``algorithm = 'PPO'`` and point ``checkpoint`` at a
    PPO model. To train instead, re-enable the ``train_func`` call.
    """
    algorithm = 'SAC'  # alternative: 'PPO'
    checkpoint = 'agent_policies/sac_warehouse_r_10_working_v1.zip'
    # train_func(algorithm)  # enable to run training
    exec_func(model_name=checkpoint, alg_name=algorithm)


if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    main()