Spaces:
Build error
Build error
File size: 2,283 Bytes
a853d77 25a1345 a853d77 25a1345 a853d77 25a1345 a853d77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
from warehouse_env import *
from stable_baselines3 import SAC
def train_func(alg_name='PPO', total_timesteps=None, save_path=None):
    """Train a Stable-Baselines3 agent on ``WarehouseEnv`` and save it.

    Args:
        alg_name: ``'PPO'`` or ``'SAC'``.
        total_timesteps: Training budget passed to ``model.learn``. ``None``
            keeps the original per-algorithm budgets (500000 for PPO,
            700000 for SAC).
        save_path: Path passed to ``model.save``. ``None`` keeps the original
            defaults (``"ppo_warehouse"`` for PPO, ``"sac_warehouse"`` for SAC).

    Returns:
        The trained model.

    Raises:
        RuntimeError: If ``alg_name`` is not ``'PPO'`` or ``'SAC'``. This is
            checked before the environment is constructed.
    """
    # Validate first so a typo does not pay for building an environment.
    if alg_name not in ('PPO', 'SAC'):
        raise RuntimeError(
            f"no model: unsupported alg_name {alg_name!r} (expected 'PPO' or 'SAC')"
        )

    env = WarehouseEnv(render_mode='')
    try:
        if alg_name == 'PPO':
            # Imported locally so this branch does not rely on
            # ``from warehouse_env import *`` re-exporting these names.
            import torch
            from stable_baselines3 import PPO

            policy_kwargs = dict(activation_fn=torch.nn.ReLU,
                                 net_arch=dict(pi=[64, 64], vf=[64, 64]))
            model = PPO("MlpPolicy", env,
                        verbose=1,
                        policy_kwargs=policy_kwargs,
                        tensorboard_log="./ppo_tensorboard/")
            model.learn(
                total_timesteps=500000 if total_timesteps is None else total_timesteps,
                tb_log_name="WarehouseEnv",
            )
            model.save("ppo_warehouse" if save_path is None else save_path)
        else:
            # A plain list gives two hidden layers of 32 units; SB3 applies
            # it to both the actor and the critic networks.
            policy_kwargs = dict(net_arch=[32, 32])
            # NOTE(review): SAC runs also log under ./ppo_tensorboard/ (kept
            # as-is); they are told apart by the "sac_WarehouseEnv" run name.
            model = SAC("MlpPolicy", env, verbose=1,
                        tensorboard_log="./ppo_tensorboard/",
                        policy_kwargs=policy_kwargs)
            model.learn(
                total_timesteps=700000 if total_timesteps is None else total_timesteps,
                log_interval=4,
                tb_log_name="sac_WarehouseEnv",
            )
            model.save("sac_warehouse" if save_path is None else save_path)
    finally:
        # Release the environment even if training is interrupted.
        env.close()
    return model
def exec_func(alg_name='SAC', model_name=None, n_episodes=None, deterministic=False):
    """Load a saved PPO/SAC policy and roll it out in ``WarehouseEnv``.

    The environment is created with ``render_mode='human'`` and ``env.render()``
    is called after every step.

    Args:
        alg_name: ``'PPO'`` or ``'SAC'``; selects the class used to load the policy.
        model_name: Path passed to ``<Algorithm>.load``. ``None`` uses
            ``"ppo_warehouse"`` or ``"sac_warehouse"`` respectively.
        n_episodes: Stop after this many episodes have ended (terminated or
            truncated). ``None`` (the default) loops forever, as before.
        deterministic: Forwarded to ``model.predict``. Defaults to ``False``,
            which matches the previous call that did not pass it.

    Raises:
        RuntimeError: If ``alg_name`` is not ``'PPO'`` or ``'SAC'``. This is
            raised before the model is loaded or the environment is created.
    """
    if alg_name == 'PPO':
        # Imported locally so this branch does not rely on the star import.
        from stable_baselines3 import PPO
        model_cls, default_name = PPO, "ppo_warehouse"
    elif alg_name == 'SAC':
        model_cls, default_name = SAC, "sac_warehouse"
    else:
        raise RuntimeError(
            f"no model: unsupported alg_name {alg_name!r} (expected 'PPO' or 'SAC')"
        )

    model = model_cls.load(default_name if model_name is None else model_name)

    env = WarehouseEnv(render_mode='human')
    try:
        obs, info = env.reset()
        episodes_done = 0
        while n_episodes is None or episodes_done < n_episodes:
            action, _ = model.predict(obs, deterministic=deterministic)
            obs, _reward, terminated, truncated, info = env.step(action)
            env.render()
            if terminated or truncated:
                episodes_done += 1
                obs, info = env.reset()
    finally:
        # Runs on normal exit and on Ctrl+C, so the env is always closed.
        env.close()
def main(alg_name='SAC',
         model_name='agent_policies/sac_warehouse_r_10_working_v1.zip',
         train=False):
    """Script entry point: replay a saved policy, or train a new one.

    With no arguments this replays the saved SAC checkpoint, exactly as before.

    Args:
        alg_name: ``'PPO'`` or ``'SAC'``.
        model_name: Checkpoint to replay. It is ignored when ``train`` is
            true, because ``train_func`` saves to its own default path.
        train: If true, call ``train_func(alg_name)`` instead of replaying.
    """
    if train:
        train_func(alg_name)
    else:
        exec_func(alg_name=alg_name, model_name=model_name)
# Run the replay only when executed as a script, not on import.
if __name__ == '__main__':
    main()