import numpy as np

from utils import log_to_txt

class Trainer:
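    """Training and evaluation loops for an agent in an environment with
    delayed observations and delayed actions (augmented-state approach)."""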

    def __init__(self, env, eval_env, agent, args):
        self.args = args
        self.agent = agent

        self.delayed_env = env
        self.eval_delayed_env = eval_env

        self.start_step = args.start_step
        self.update_after = args.update_after
        self.max_step = args.max_step
        self.batch_size = args.batch_size
        self.update_every = args.update_every

        self.eval_flag = args.eval_flag
        self.eval_episode = args.eval_episode
        self.eval_freq = args.eval_freq

        self.episode = 0
        self.total_step = 0
        self.local_step = 0
        self.eval_local_step = 0
        self.eval_num = 0
        self.finish_flag = False
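
        # Steps at the start of each episode before the first delayed
        # observation becomes available.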
        self.total_delayed_steps = args.obs_delayed_steps + args.act_delayed_steps

    def train(self):
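        """Run episodes until max_step environment steps have been taken."""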
        while not self.finish_flag:
            self.episode += 1
            self.local_step = 0

            self.delayed_env.reset()
            self.agent.temporary_buffer.clear()
            done = False

            while not done:
                self.local_step += 1
                self.total_step += 1
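
                # Phase 1: no observation has arrived yet, so step with zero
                # (no-op) actions and record them for the augmented state.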
                if self.local_step < self.total_delayed_steps:
                    action = np.zeros_like(self.delayed_env.action_space.sample())
                    _, _, _, _ = self.delayed_env.step(action)
                    self.agent.temporary_buffer.actions.append(action)
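                # Phase 2: the first delayed observation arrives on this step.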
                elif self.local_step == self.total_delayed_steps:
                    if self.total_step < self.start_step:
                        action = self.delayed_env.action_space.sample()
                    else:
                        action = np.zeros_like(self.delayed_env.action_space.sample())
                    next_observed_state, _, _, _ = self.delayed_env.step(action)
                    self.agent.temporary_buffer.actions.append(action)
                    self.agent.temporary_buffer.states.append(next_observed_state)
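                # Phase 3: observations are flowing; act on an augmented state built
                # from the last observed state and the actions taken during the delay.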
                else:
                    last_observed_state = self.agent.temporary_buffer.states[-1]
                    first_action_idx = len(self.agent.temporary_buffer.actions) - self.total_delayed_steps
                    augmented_state = self.agent.temporary_buffer.get_augmented_state(last_observed_state, first_action_idx)

                    if self.total_step < self.start_step:
                        action = self.delayed_env.action_space.sample()
                    else:
                        action = self.agent.get_action(augmented_state, evaluation=False)

                    next_observed_state, reward, done, info = self.delayed_env.step(action)
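
                    # Mask the done flag at the time limit so hitting it is not
                    # treated as a true environment terminal when bootstrapping.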
                    true_done = 0.0 if self.local_step == self.delayed_env._max_episode_steps + self.args.obs_delayed_steps else float(done)

                    self.agent.temporary_buffer.actions.append(action)
                    self.agent.temporary_buffer.states.append(next_observed_state)
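
                    # The temporary buffer only holds a complete transition pair
                    # once 2 * total_delayed_steps steps have elapsed.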
                    if self.local_step > 2 * self.total_delayed_steps:
                        augmented_s, s, a, next_augmented_s, next_s = self.agent.temporary_buffer.get_tuple()
                        self.agent.replay_memory.push(augmented_s, s, a, reward, next_augmented_s, next_s, true_done)
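
                # Run update_every gradient steps every update_every environment
                # steps, once warm-up has passed and a batch can be filled.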
                if self.agent.replay_memory.size >= self.batch_size and self.total_step >= self.update_after and \
                        self.total_step % self.update_every == 0:
                    total_actor_loss = 0
                    total_critic_loss = 0
                    total_log_alpha_loss = 0
                    for _ in range(self.update_every):
                        critic_loss, actor_loss, log_alpha_loss = self.agent.train()
                        total_critic_loss += critic_loss
                        total_actor_loss += actor_loss
                        total_log_alpha_loss += log_alpha_loss

                    if self.args.show_loss:
                        print("Loss | Actor loss {:.3f} | Critic loss {:.3f} | Log-alpha loss {:.3f}"
                              .format(total_actor_loss / self.update_every, total_critic_loss / self.update_every,
                                      total_log_alpha_loss / self.update_every))

                if self.eval_flag and self.total_step % self.eval_freq == 0:
                    self.evaluate()

                if self.total_step == self.max_step:
                    self.finish_flag = True

    def evaluate(self):
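        """Run eval_episode evaluation episodes and log the average return."""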
        self.eval_num += 1
        reward_list = []

        for _ in range(self.eval_episode):
            episode_reward = 0
            self.eval_delayed_env.reset()
            self.agent.eval_temporary_buffer.clear()
            done = False
            self.eval_local_step = 0
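
            # Mirror the delay handling from train(), but act in evaluation mode
            # and never push transitions to the replay memory.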
            while not done:
                self.eval_local_step += 1
                if self.eval_local_step < self.total_delayed_steps:
                    action = np.zeros_like(self.eval_delayed_env.action_space.sample())
                    _, _, _, _ = self.eval_delayed_env.step(action)
                    self.agent.eval_temporary_buffer.actions.append(action)
                elif self.eval_local_step == self.total_delayed_steps:
                    action = np.zeros_like(self.eval_delayed_env.action_space.sample())
                    next_observed_state, _, _, _ = self.eval_delayed_env.step(action)
                    self.agent.eval_temporary_buffer.actions.append(action)
                    self.agent.eval_temporary_buffer.states.append(next_observed_state)
                else:
                    last_observed_state = self.agent.eval_temporary_buffer.states[-1]
                    first_action_idx = len(self.agent.eval_temporary_buffer.actions) - self.total_delayed_steps
                    augmented_state = self.agent.eval_temporary_buffer.get_augmented_state(last_observed_state, first_action_idx)
                    action = self.agent.get_action(augmented_state, evaluation=True)
                    next_observed_state, reward, done, _ = self.eval_delayed_env.step(action)
                    self.agent.eval_temporary_buffer.actions.append(action)
                    self.agent.eval_temporary_buffer.states.append(next_observed_state)
                    episode_reward += reward

            reward_list.append(episode_reward)

        log_to_txt(self.args.env_name, self.args.random_seed, self.total_step, sum(reward_list) / len(reward_list))
        print("Eval | Total Steps {} | Episodes {} | Average Reward {:.2f} | Max reward {:.2f} | "
              "Min reward {:.2f}".format(self.total_step, self.episode, sum(reward_list) / len(reward_list),
                                         max(reward_list), min(reward_list)))
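

# Minimal usage sketch, assuming Gym-style environments. `make_delayed_env` and
# `Agent` below are placeholders for this repo's actual constructors, and `args`
# must carry every field read above.
#
#     from types import SimpleNamespace
#
#     args = SimpleNamespace(
#         env_name="HalfCheetah-v2", random_seed=0,
#         obs_delayed_steps=2, act_delayed_steps=2,
#         start_step=1000, update_after=1000, max_step=1000000,
#         batch_size=256, update_every=50, show_loss=False,
#         eval_flag=True, eval_episode=10, eval_freq=5000)
#
#     env = make_delayed_env(args)         # placeholder factory
#     eval_env = make_delayed_env(args)
#     agent = Agent(env, args)             # placeholder agent class
#     Trainer(env, eval_env, agent, args).train()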