File size: 3,877 Bytes
0e2f05d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import numpy as np

class Trainer:
    """Drive environment interaction, agent training and periodic evaluation.

    The trainer collects transitions from ``env`` into ``agent.buffer``,
    invokes the agent's update routines once training is allowed, and
    periodically evaluates the agent (without exploration noise) on
    ``eval_env``.

    Both environments are used with the Gymnasium API: ``reset()`` returns
    ``(obs, info)`` and ``step(action)`` returns
    ``(obs, reward, terminated, truncated, info)``.
    """

    def __init__(self, env, eval_env, agent, args):
        """Store the environments, agent and run configuration.

        Args:
            env: Environment used for data collection and training.
            eval_env: Environment used only by :meth:`evaluate`.
            agent: Object exposing ``get_action``, ``train``,
                ``maybe_train_and_checkpoint`` and a ``buffer`` with ``push``.
            args: Namespace of run options. The attributes read here are
                ``env_name``, ``start_steps``, ``max_steps``, ``batch_size``,
                ``target_noise_scale``, ``policy_update_delay``,
                ``eval_flag``, ``eval_episode`` and ``eval_freq``. The
                training methods also read ``use_checkpoints`` and
                ``show_loss``.
        """
        self.args = args

        self.agent = agent
        self.env_name = args.env_name
        self.env = env
        self.eval_env = eval_env

        self.start_steps = args.start_steps
        self.max_steps = args.max_steps
        self.batch_size = args.batch_size
        self.target_noise_scale = args.target_noise_scale
        self.policy_update_delay = args.policy_update_delay

        self.eval_flag = args.eval_flag
        self.eval_episode = args.eval_episode
        self.eval_freq = args.eval_freq

        # Progress counters / state.
        self.episode = 0
        self.episode_reward = 0
        self.total_steps = 0
        self.eval_num = 0
        self.finish_flag = False

    def evaluate(self):
        """Run ``eval_episode`` noise-free episodes on ``eval_env`` and print statistics.

        Prints the average, max, min and standard deviation of the per-episode
        returns. ``eval_num`` is incremented on every call. When
        ``eval_episode`` is not positive, nothing is run or printed.
        """
        self.eval_num += 1
        reward_list = []

        for _episode in range(self.eval_episode):
            epi_reward = 0
            state, _ = self.eval_env.reset()
            done = False

            while not done:
                action = self.agent.get_action(state, use_checkpoint=self.args.use_checkpoints, add_noise=False)
                state, reward, terminated, truncated, _ = self.eval_env.step(action)
                done = terminated or truncated
                epi_reward += reward
            reward_list.append(epi_reward)

        if not reward_list:
            # No episodes were requested; avoid dividing by zero below.
            return

        avg_reward = sum(reward_list) / len(reward_list)
        print("Eval  |  total_step {}  |  episode {}  |  Average Reward {:.2f}  |  Max reward: {:.2f}  |  "
              "Min reward: {:.2f}  |  Std reward: {:.2f}".format(self.total_steps, self.episode, avg_reward,
                                                                 max(reward_list), min(reward_list),
                                                                 np.std(reward_list)))

    def run(self):
        """Collect experience and train until ``max_steps`` environment steps have been taken.

        Actions are sampled from ``env.action_space`` until an episode ends
        with ``total_steps >= start_steps``. After that, the agent acts with
        exploration noise and training is enabled:

        * without checkpoints, ``agent.train()`` is called after every step;
        * with checkpoints, ``agent.maybe_train_and_checkpoint`` is called at
          the end of each episode.

        The loop stops immediately once the step budget is reached, even in
        the middle of an episode.
        """
        allow_train = False

        while not self.finish_flag:
            self.episode += 1
            done = False
            ep_total_reward, ep_timesteps = 0, 0

            state, _ = self.env.reset()
            # Episode start.
            while not done:
                self.total_steps += 1
                ep_timesteps += 1

                if allow_train:
                    action = self.agent.get_action(state, use_checkpoint=False, add_noise=True)
                else:
                    action = self.env.action_space.sample()
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated

                ep_total_reward += reward

                # Only a true terminal state should stop bootstrapping; a
                # truncated (e.g. time-limited) transition is stored as
                # non-terminal. Gymnasium reports this directly via
                # `terminated`, so no private `_max_episode_steps` is needed.
                done_mask = float(terminated)
                self.agent.buffer.push(state, action, reward, next_state, done_mask)

                state = next_state

                if allow_train and not self.args.use_checkpoints:
                    actor_loss, critic_loss, encoder_loss = self.agent.train()
                    # Print loss.
                    if self.args.show_loss:
                        print("Loss  |  Actor loss {:.3f}  |  Critic loss {:.3f}  |  Encoder loss {:.3f}"
                              .format(actor_loss, critic_loss, encoder_loss))

                if done:
                    if allow_train and self.args.use_checkpoints:
                        self.agent.maybe_train_and_checkpoint(ep_timesteps, ep_total_reward)

                    # Training is only switched on at an episode boundary.
                    if self.total_steps >= self.start_steps:
                        allow_train = True

                # Evaluation.
                if self.eval_flag and self.total_steps % self.eval_freq == 0:
                    self.evaluate()

                # Stop as soon as the step budget is exhausted rather than
                # finishing the current episode beyond max_steps.
                if self.total_steps >= self.max_steps:
                    self.finish_flag = True
                    break