# Agent_Control_with_Language / train_agent.py
# Author: ArseniyPerchik (revision 25a1345)
from warehouse_env import *
from stable_baselines3 import SAC
def train_func(alg_name='PPO'):
    """Train a Stable-Baselines3 agent on ``WarehouseEnv`` and save it to disk.

    The PPO branch trains for 500000 timesteps and saves to ``ppo_warehouse``;
    the SAC branch trains for 700000 timesteps and saves to ``sac_warehouse``.
    Both branches write TensorBoard logs under ``./ppo_tensorboard/`` and are
    told apart there by their ``tb_log_name``.

    Args:
        alg_name: Algorithm to train, either ``'PPO'`` or ``'SAC'``.

    Raises:
        RuntimeError: If ``alg_name`` is neither ``'PPO'`` nor ``'SAC'``. The
            check runs before the environment is constructed.
    """
    # Fail fast: reject an unknown algorithm before building the environment.
    if alg_name not in ('PPO', 'SAC'):
        raise RuntimeError(
            f"no model: unsupported alg_name {alg_name!r} "
            "(expected 'PPO' or 'SAC')"
        )

    env = WarehouseEnv(render_mode='')
    try:
        if alg_name == 'PPO':
            # Imported explicitly: this file only imports SAC by name, so PPO
            # and torch would otherwise have to come from the wildcard import.
            import torch
            from stable_baselines3 import PPO

            # Separate 64x64 networks for the policy (pi) and value (vf) heads.
            policy_kwargs = dict(
                activation_fn=torch.nn.ReLU,
                net_arch=dict(pi=[64, 64], vf=[64, 64]),
            )
            # learning_rate and clip_range are left at the library defaults.
            model = PPO(
                "MlpPolicy",
                env,
                verbose=1,
                policy_kwargs=policy_kwargs,
                tensorboard_log="./ppo_tensorboard/",
            )
            model.learn(total_timesteps=500000, tb_log_name="WarehouseEnv")
            model.save("ppo_warehouse")
        else:
            # Two hidden layers of 32 units each.
            policy_kwargs = dict(net_arch=[32, 32])
            # learning_rate is left at the library default.
            model = SAC(
                "MlpPolicy",
                env,
                verbose=1,
                tensorboard_log="./ppo_tensorboard/",
                policy_kwargs=policy_kwargs,
            )
            model.learn(
                total_timesteps=700000,
                log_interval=4,
                tb_log_name="sac_WarehouseEnv",
            )
            model.save("sac_warehouse")
    finally:
        # Release the environment even if training fails or is interrupted.
        env.close()
def exec_func(alg_name='SAC', model_name=None):
    """Load a saved agent and run it in a rendered ``WarehouseEnv`` forever.

    The loop has no exit condition: it keeps stepping the environment and
    resets it whenever an episode terminates or is truncated. It ends only
    when the process is interrupted or an exception is raised, at which point
    the environment is closed.

    Args:
        alg_name: Algorithm the saved model belongs to, ``'PPO'`` or ``'SAC'``.
        model_name: Path of the saved model. When ``None``, defaults to
            ``"ppo_warehouse"`` for PPO and ``"sac_warehouse"`` for SAC.

    Raises:
        RuntimeError: If ``alg_name`` is neither ``'PPO'`` nor ``'SAC'``. The
            check runs before the environment is constructed.
    """
    # Resolve the model class first so a bad alg_name never opens a window.
    if alg_name == 'PPO':
        # Imported explicitly: this file only imports SAC by name.
        from stable_baselines3 import PPO

        model_cls, default_name = PPO, "ppo_warehouse"
    elif alg_name == 'SAC':
        model_cls, default_name = SAC, "sac_warehouse"
    else:
        raise RuntimeError(
            f"no model: unsupported alg_name {alg_name!r} "
            "(expected 'PPO' or 'SAC')"
        )

    env = WarehouseEnv(render_mode='human')
    try:
        model = model_cls.load(
            default_name if model_name is None else model_name
        )
        obs, info = env.reset()
        while True:
            action, _ = model.predict(obs)
            obs, _reward, done, trunc, info = env.step(action)
            env.render()
            if done or trunc:
                # Episode over (terminated or truncated): start a new one.
                obs, info = env.reset()
    finally:
        # Reached on Ctrl-C or any error, since the loop never ends on its own.
        env.close()
def main():
    """Script entry point: replay the stored SAC policy in the environment.

    Training is currently disabled; to train instead, call
    ``train_func('SAC')`` or ``train_func('PPO')`` here.
    """
    exec_func(
        alg_name='SAC',
        model_name='agent_policies/sac_warehouse_r_10_working_v1.zip',
    )


# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()