import numpy as np
from utils import log_to_txt
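

# Trainer runs the training and evaluation loops for an agent acting in an
# environment whose observations and actions are delayed. Before the first
# delayed observation arrives it feeds 'no-op' actions; afterwards it builds
# augmented states (the last observed state plus the actions still in flight)
# and stores the resulting transitions in the agent's replay memory.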
class Trainer:
    def __init__(self, env, eval_env, agent, args):
        self.args = args
        self.agent = agent
        self.delayed_env = env
        self.eval_delayed_env = eval_env

        self.start_step = args.start_step
        self.update_after = args.update_after
        self.max_step = args.max_step
        self.batch_size = args.batch_size
        self.update_every = args.update_every

        self.eval_flag = args.eval_flag
        self.eval_episode = args.eval_episode
        self.eval_freq = args.eval_freq

        self.episode = 0
        self.total_step = 0
        self.local_step = 0
        self.eval_local_step = 0
        self.eval_num = 0
        self.finish_flag = False

        self.total_delayed_steps = args.obs_delayed_steps + self.args.act_delayed_steps
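
    # Note: `temporary_buffer` is defined on the agent elsewhere in this
    # repository. From its usage below it is assumed to expose `states` and
    # `actions` lists plus `get_augmented_state(...)` and `get_tuple()`. A
    # minimal sketch of the assumed augmentation, for reference only:
    #
    #     def get_augmented_state(self, last_observed_state, first_action_idx):
    #         # Concatenate the last observed state with the pending actions.
    #         return np.concatenate([last_observed_state,
    #                                *self.actions[first_action_idx:]])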
    def train(self):
        # The training process starts here.
        while not self.finish_flag:
            self.episode += 1
            self.local_step = 0

            # Initialize the delayed environment & the temporary buffer.
            self.delayed_env.reset()
            self.agent.temporary_buffer.clear()
            done = False

            # Episode starts here.
            while not done:
                self.local_step += 1
                self.total_step += 1

                if self.local_step < self.total_delayed_steps:  # if t < d
                    action = np.zeros_like(self.delayed_env.action_space.sample())  # Select the 'no-op' action.
                    _, _, _, _ = self.delayed_env.step(action)
                    self.agent.temporary_buffer.actions.append(action)
                elif self.local_step == self.total_delayed_steps:  # if t == d
                    if self.total_step < self.start_step:
                        action = self.delayed_env.action_space.sample()
                    else:
                        action = np.zeros_like(self.delayed_env.action_space.sample())  # Select the 'no-op' action.
                    next_observed_state, _, _, _ = self.delayed_env.step(action)  # s(1) <- Env: a(d)
                    self.agent.temporary_buffer.actions.append(action)  # Put a(d) into the temporary buffer.
                    self.agent.temporary_buffer.states.append(next_observed_state)  # Put s(1) into the temporary buffer.
                else:  # if t > d
                    last_observed_state = self.agent.temporary_buffer.states[-1]
                    first_action_idx = len(self.agent.temporary_buffer.actions) - self.total_delayed_steps
                    # Get the augmented state(t).
                    augmented_state = self.agent.temporary_buffer.get_augmented_state(last_observed_state, first_action_idx)

                    if self.total_step < self.start_step:
                        action = self.delayed_env.action_space.sample()
                    else:
                        action = self.agent.get_action(augmented_state, evaluation=False)  # a(t) <- policy: augmented_state(t)

                    next_observed_state, reward, done, info = self.delayed_env.step(action)  # s(t+1-d), r(t-d) <- Env: a(t)
                    # Mask the 'done' signal raised purely by the time limit, so it is not stored as a true terminal state.
                    true_done = 0.0 if self.local_step == self.delayed_env._max_episode_steps + self.args.obs_delayed_steps else float(done)

                    self.agent.temporary_buffer.actions.append(action)  # Put a(t) into the temporary buffer.
                    self.agent.temporary_buffer.states.append(next_observed_state)  # Put s(t+1-d) into the temporary buffer.

                    if self.local_step > 2 * self.total_delayed_steps:  # if t > 2d
                        # aug_s(t-d), s(t-d), a(t-d), aug_s(t+1-d), s(t+1-d) <- Temporary Buffer
                        augmented_s, s, a, next_augmented_s, next_s = self.agent.temporary_buffer.get_tuple()
                        # Store (aug_s(t-d), s(t-d), a(t-d), r(t-d), aug_s(t+1-d), s(t+1-d)) in the replay memory.
                        self.agent.replay_memory.push(augmented_s, s, a, reward, next_augmented_s, next_s, true_done)

                # Update the parameters.
                if self.agent.replay_memory.size >= self.batch_size and self.total_step >= self.update_after and \
                        self.total_step % self.update_every == 0:
                    total_actor_loss = 0
                    total_critic_loss = 0
                    total_log_alpha_loss = 0
                    for i in range(self.update_every):
                        # Train the policy (actor) and the beta Q-network (critic).
                        critic_loss, actor_loss, log_alpha_loss = self.agent.train()
                        total_critic_loss += critic_loss
                        total_actor_loss += actor_loss
                        total_log_alpha_loss += log_alpha_loss

                    # Print the averaged losses.
                    if self.args.show_loss:
                        print("Loss | Actor loss {:.3f} | Critic loss {:.3f} | Log-alpha loss {:.3f}"
                              .format(total_actor_loss / self.update_every, total_critic_loss / self.update_every,
                                      total_log_alpha_loss / self.update_every))

                # Evaluate the current policy.
                if self.eval_flag and self.total_step % self.eval_freq == 0:
                    self.evaluate()

                # Raise the finish flag when the step budget is exhausted.
                if self.total_step == self.max_step:
                    self.finish_flag = True
    def evaluate(self):
        # The evaluation process starts here.
        self.eval_num += 1
        reward_list = []

        for epi in range(self.eval_episode):
            episode_reward = 0
            self.eval_delayed_env.reset()
            self.agent.eval_temporary_buffer.clear()
            done = False
            self.eval_local_step = 0

            while not done:
                self.eval_local_step += 1
                if self.eval_local_step < self.total_delayed_steps:  # if t < d
                    action = np.zeros_like(self.eval_delayed_env.action_space.sample())  # Select the 'no-op' action.
                    _, _, _, _ = self.eval_delayed_env.step(action)
                    self.agent.eval_temporary_buffer.actions.append(action)
                elif self.eval_local_step == self.total_delayed_steps:  # if t == d
                    action = np.zeros_like(self.eval_delayed_env.action_space.sample())
                    next_observed_state, _, _, _ = self.eval_delayed_env.step(action)
                    self.agent.eval_temporary_buffer.actions.append(action)
                    self.agent.eval_temporary_buffer.states.append(next_observed_state)
                else:  # if t > d
                    last_observed_state = self.agent.eval_temporary_buffer.states[-1]
                    first_action_idx = len(self.agent.eval_temporary_buffer.actions) - self.total_delayed_steps
                    augmented_state = self.agent.eval_temporary_buffer.get_augmented_state(last_observed_state,
                                                                                           first_action_idx)
                    action = self.agent.get_action(augmented_state, evaluation=True)
                    next_observed_state, reward, done, _ = self.eval_delayed_env.step(action)
                    self.agent.eval_temporary_buffer.actions.append(action)
                    self.agent.eval_temporary_buffer.states.append(next_observed_state)
                    episode_reward += reward
            reward_list.append(episode_reward)

        log_to_txt(self.args.env_name, self.args.random_seed, self.total_step, sum(reward_list) / len(reward_list))
        print("Eval | Total Steps {} | Episodes {} | Average Reward {:.2f} | Max reward {:.2f} | "
              "Min reward {:.2f}".format(self.total_step, self.episode, sum(reward_list) / len(reward_list),
                                         max(reward_list), min(reward_list)))