File size: 2,283 Bytes
a853d77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25a1345
a853d77
 
 
 
 
 
 
 
 
 
 
 
25a1345
a853d77
 
 
 
 
 
 
 
 
25a1345
a853d77
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from warehouse_env import *
from stable_baselines3 import SAC


def train_func(alg_name='PPO', total_timesteps=None, model_name=None):
    """Train a PPO or SAC agent on ``WarehouseEnv`` and save it to disk.

    Args:
        alg_name: Either ``'PPO'`` or ``'SAC'``.
        total_timesteps: Number of environment steps passed to
            ``model.learn``. ``None`` keeps the per-algorithm defaults
            (500000 for PPO, 700000 for SAC).
        model_name: Path passed to ``model.save``. ``None`` keeps the
            defaults ``"ppo_warehouse"`` / ``"sac_warehouse"``.

    Raises:
        RuntimeError: If ``alg_name`` is neither ``'PPO'`` nor ``'SAC'``.
            Raised before any environment is constructed.
    """
    # Validate first so an unknown algorithm does not construct an env it
    # would then abandon without closing.
    if alg_name not in ('PPO', 'SAC'):
        raise RuntimeError(f'no model: unknown algorithm {alg_name!r}')

    # render_mode='' -> no on-screen rendering while training.
    env = WarehouseEnv(render_mode='')
    try:
        if alg_name == 'PPO':
            # Import explicitly instead of relying on names leaking in through
            # ``from warehouse_env import *``.
            import torch
            from stable_baselines3 import PPO

            # Separate 64x64 hidden layers for the policy (pi) and value (vf) nets.
            policy_kwargs = dict(activation_fn=torch.nn.ReLU,
                                 net_arch=dict(pi=[64, 64], vf=[64, 64]))
            model = PPO("MlpPolicy", env,
                        verbose=1,
                        policy_kwargs=policy_kwargs,
                        tensorboard_log="./ppo_tensorboard/",
                        )
            steps = 500000 if total_timesteps is None else total_timesteps
            model.learn(total_timesteps=steps, tb_log_name="WarehouseEnv")
            model.save("ppo_warehouse" if model_name is None else model_name)
        else:  # alg_name == 'SAC'
            # A plain list gives the actor and the critics the same [32, 32]
            # hidden sizes; SB3 off-policy algorithms keep them as separate
            # networks (layers are not shared).
            policy_kwargs = dict(net_arch=[32, 32])
            # NOTE: logs to the same directory as PPO; runs are told apart by
            # tb_log_name.
            model = SAC("MlpPolicy", env, verbose=1,
                        tensorboard_log="./ppo_tensorboard/",
                        policy_kwargs=policy_kwargs,
                        )
            steps = 700000 if total_timesteps is None else total_timesteps
            model.learn(total_timesteps=steps, log_interval=4,
                        tb_log_name="sac_WarehouseEnv")
            model.save("sac_warehouse" if model_name is None else model_name)
    finally:
        # Release environment resources even if training is interrupted.
        env.close()




def exec_func(alg_name='SAC', model_name=None, deterministic=False,
              max_episodes=None):
    """Load a saved agent and run it in a human-rendered ``WarehouseEnv``.

    Args:
        alg_name: ``'PPO'`` or ``'SAC'``; selects which class loads the model.
        model_name: Path of the saved model. ``None`` uses
            ``"ppo_warehouse"`` / ``"sac_warehouse"``.
        deterministic: Forwarded to ``model.predict``. The default ``False``
            (sampled actions) matches the previous behavior.
        max_episodes: Stop after this many finished episodes. ``None`` (the
            default) loops forever, as before.

    Raises:
        RuntimeError: If ``alg_name`` is neither ``'PPO'`` nor ``'SAC'``.
            Raised before any environment or render window is created.
    """
    # Load (and thereby validate) the model before creating the env, so a bad
    # algorithm name or a missing model file does not leave a window open.
    if alg_name == 'PPO':
        # Explicit import rather than relying on ``from warehouse_env import *``.
        from stable_baselines3 import PPO
        model = PPO.load("ppo_warehouse" if model_name is None else model_name)
    elif alg_name == 'SAC':
        model = SAC.load("sac_warehouse" if model_name is None else model_name)
    else:
        raise RuntimeError(f'no model: unknown algorithm {alg_name!r}')

    env = WarehouseEnv(render_mode='human')
    try:
        obs, _info = env.reset()
        episodes = 0
        while max_episodes is None or episodes < max_episodes:
            action, _state = model.predict(obs, deterministic=deterministic)
            obs, _reward, terminated, truncated, _info = env.step(action)
            env.render()
            if terminated or truncated:
                episodes += 1
                obs, _info = env.reset()
    finally:
        # Also runs on KeyboardInterrupt, so the render window is released.
        env.close()


def main():
    """Script entry point: replay a saved SAC policy with on-screen rendering.

    To use PPO instead, set ``algorithm`` to ``'PPO'``. To train a fresh
    policy rather than replay one, call ``train_func(algorithm)``.
    """
    algorithm = 'SAC'
    checkpoint = 'agent_policies/sac_warehouse_r_10_working_v1.zip'
    exec_func(alg_name=algorithm, model_name=checkpoint)


if __name__ == '__main__':
    # Run main() only when this file is executed as a script, not on import.
    main()