UD7 / trainer.py
jangwon-kim-cocel's picture
Upload 10 files
0e2f05d verified
import numpy as np
class Trainer:
def __init__(self, env, eval_env, agent, args):
self.args = args
self.agent = agent
self.env_name = args.env_name
self.env = env
self.eval_env = eval_env
self.start_steps = args.start_steps
self.max_steps = args.max_steps
self.batch_size = args.batch_size
self.target_noise_scale = args.target_noise_scale
self.eval_flag = args.eval_flag
self.eval_episode = args.eval_episode
self.eval_freq = args.eval_freq
self.episode = 0
self.episode_reward = 0
self.total_steps = 0
self.eval_num = 0
self.finish_flag = False
self.target_noise_scale = args.target_noise_scale
self.policy_update_delay = args.policy_update_delay
def evaluate(self):
# Evaluate process
self.eval_num += 1
reward_list = []
for epi in range(self.eval_episode):
epi_reward = 0
state, _ = self.eval_env.reset()
done = False
while not done:
action = self.agent.get_action(state, use_checkpoint=self.args.use_checkpoints, add_noise=False)
next_state, reward, terminated, truncated, _ = self.eval_env.step(action)
done = terminated or truncated
epi_reward += reward
state = next_state
reward_list.append(epi_reward)
print("Eval | total_step {} | episode {} | Average Reward {:.2f} | Max reward: {:.2f} | "
"Min reward: {:.2f}".format(self.total_steps, self.episode, sum(reward_list)/len(reward_list),
max(reward_list), min(reward_list), np.std(reward_list)))
def run(self):
# Train-process start.
allow_train = False
while not self.finish_flag:
self.episode += 1
done = False
ep_total_reward, ep_timesteps = 0, 0
state, _ = self.env.reset()
# Episode start.
while not done:
self.total_steps += 1
ep_timesteps += 1
if allow_train:
action = self.agent.get_action(state, use_checkpoint=False, add_noise=True)
else:
action = self.env.action_space.sample()
next_state, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
ep_total_reward += reward
done_mask = 0.0 if ep_timesteps == self.env._max_episode_steps else float(done)
self.agent.buffer.push(state, action, reward, next_state, done_mask)
state = next_state
if allow_train and not self.args.use_checkpoints:
actor_loss, critic_loss, encoder_loss = self.agent.train()
# Print loss.
if self.args.show_loss:
print("Loss | Actor loss {:.3f} | Critic loss {:.3f} | Encoder loss {:.3f}"
.format(actor_loss, critic_loss, encoder_loss))
if done:
if allow_train and self.args.use_checkpoints:
self.agent.maybe_train_and_checkpoint(ep_timesteps, ep_total_reward)
if self.total_steps >= self.args.start_steps:
allow_train = True
# Evaluation.
if self.eval_flag and self.total_steps % self.eval_freq == 0:
self.evaluate()
# Raise finish_flag.
if self.total_steps == self.max_steps:
self.finish_flag = True