import numpy as np

from utils import log_to_txt


class Trainer:
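    """Training loop for an agent in an environment with delayed observations and actions.

    With total delay d = obs_delayed_steps + act_delayed_steps, each episode runs through
    three phases: t < d (no observation has arrived, so 'no-op' actions are issued),
    t == d (the first observation s(1) arrives), and t > d (the policy acts on an
    augmented state built from the last observed state and the d actions still in flight).
    """
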
    def __init__(self, env, eval_env, agent, args):
        self.args = args
        self.agent = agent

        self.delayed_env = env
        self.eval_delayed_env = eval_env

        self.start_step = args.start_step
        self.update_after = args.update_after
        self.max_step = args.max_step
        self.batch_size = args.batch_size
        self.update_every = args.update_every

        self.eval_flag = args.eval_flag
        self.eval_episode = args.eval_episode
        self.eval_freq = args.eval_freq

        self.episode = 0
        self.total_step = 0
        self.local_step = 0
        self.eval_local_step = 0
        self.eval_num = 0
        self.finish_flag = False

        # Total delay d: the observation delay plus the action delay.
        self.total_delayed_steps = args.obs_delayed_steps + args.act_delayed_steps

    def train(self):
        # Main training loop: run until max_step total environment steps are collected.
        while not self.finish_flag:
            self.episode += 1
            self.local_step = 0

            # Reset the delayed environment and clear the temporary buffer.
            self.delayed_env.reset()
            self.agent.temporary_buffer.clear()
            done = False

            # Episode starts here.
            while not done:
                self.local_step += 1
                self.total_step += 1

                if self.local_step < self.total_delayed_steps:  # if t < d
                    # No observation has reached the agent yet: issue the 'no-op' action
                    # and record it so the action history stays aligned with the delay.
                    action = np.zeros_like(self.delayed_env.action_space.sample())
                    _, _, _, _ = self.delayed_env.step(action)

                    self.agent.temporary_buffer.actions.append(action)
                elif self.local_step == self.total_delayed_steps:  # if t == d
                    # The first observation s(1) arrives only after this step, so the policy
                    # cannot act yet: explore randomly before start_step, otherwise no-op.
                    if self.total_step < self.start_step:
                        action = self.delayed_env.action_space.sample()
                    else:
                        action = np.zeros_like(self.delayed_env.action_space.sample())  # Select the 'no-op' action

                    next_observed_state, _, _, _ = self.delayed_env.step(action)
                    #                s(1)       <-     Env: a(d)
                    self.agent.temporary_buffer.actions.append(action)  # Put a(d) to the temporary buffer
                    self.agent.temporary_buffer.states.append(next_observed_state)  # Put s(1) to the temporary buffer
                else:  # if t > d
                    last_observed_state = self.agent.temporary_buffer.states[-1]
                    first_action_idx = len(self.agent.temporary_buffer.actions) - self.total_delayed_steps

                    # Build the augmented state(t): the last observed state s(t-d) together
                    # with the d most recent actions a(t-d), ..., a(t-1) still in flight.
                    augmented_state = self.agent.temporary_buffer.get_augmented_state(last_observed_state, first_action_idx)
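                    # (A minimal sketch of what get_augmented_state presumably returns,
                    # assuming the buffer stores flat NumPy arrays; the exact layout is
                    # defined by the temporary buffer, not here:
                    #   np.concatenate([last_observed_state,
                    #                   *self.agent.temporary_buffer.actions[first_action_idx:]]))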

                    if self.total_step < self.start_step:
                        action = self.delayed_env.action_space.sample()
                    else:
                        action = self.agent.get_action(augmented_state, evaluation=False)
                        # a(t) <- policy: augmented_state(t)
                    next_observed_state, reward, done, info = self.delayed_env.step(action)
                    #          s(t+1-d),  r(t-d)      <-      Env: a(t)
                    # Mask terminations caused purely by the (delay-extended) time limit,
                    # so they are not stored as true environment terminal states.
                    true_done = 0.0 if self.local_step == self.delayed_env._max_episode_steps + self.args.obs_delayed_steps \
                        else float(done)

                    self.agent.temporary_buffer.actions.append(action)  # Put a(t) to the temporary buffer
                    self.agent.temporary_buffer.states.append(next_observed_state)  # Put s(t+1-d) to the temporary buffer

                    if self.local_step > 2 * self.total_delayed_steps:  # if t > 2d
                        # Only after 2d steps does the temporary buffer hold a complete
                        # delayed transition, i.e. both aug_s(t-d) and aug_s(t+1-d).
                        augmented_s, s, a, next_augmented_s, next_s = self.agent.temporary_buffer.get_tuple()
                        #  aug_s(t-d),  s(t-d),  a(t-d),  aug_s(t+1-d),  s(t+1-d)  <- Temporary buffer
                        self.agent.replay_memory.push(augmented_s, s, a, reward, next_augmented_s, next_s, true_done)
                        #  Store (aug_s(t-d), s(t-d), a(t-d), r(t-d), aug_s(t+1-d), s(t+1-d), done(t-d)) in the replay memory.

                # Update parameters
                if self.agent.replay_memory.size >= self.batch_size and self.total_step >= self.update_after and \
                        self.total_step % self.update_every == 0:
                    total_actor_loss = 0
                    total_critic_loss = 0
                    total_log_alpha_loss = 0
                    for i in range(self.update_every):
                        # Train the policy and the beta Q-network (critic).
                        critic_loss, actor_loss, log_alpha_loss = self.agent.train()
                        total_critic_loss += critic_loss
                        total_actor_loss += actor_loss
                        total_log_alpha_loss += log_alpha_loss

                    # Print the losses averaged over the update_every training iterations.
                    if self.args.show_loss:
                        print("Loss  |  Actor loss {:.3f}  |  Critic loss {:.3f}  |  Log-alpha loss {:.3f}"
                              .format(total_actor_loss / self.update_every, total_critic_loss / self.update_every,
                                      total_log_alpha_loss / self.update_every))

                # Evaluate.
                if self.eval_flag and self.total_step % self.eval_freq == 0:
                    self.evaluate()

                # Raise finish flag.
                if self.total_step == self.max_step:
                    self.finish_flag = True

    def evaluate(self):
        # Evaluate the current policy for eval_episode episodes and log the average return.
        self.eval_num += 1
        reward_list = []

        for epi in range(self.eval_episode):
            episode_reward = 0
            self.eval_delayed_env.reset()
            self.agent.eval_temporary_buffer.clear()
            done = False
            self.eval_local_step = 0

            while not done:
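                # Mirror the delay handling from train(): no-ops until the first
                # observation arrives, then greedy actions (evaluation=True) on
                # augmented states.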
                self.eval_local_step += 1
                if self.eval_local_step < self.total_delayed_steps:
                    action = np.zeros_like(self.eval_delayed_env.action_space.sample())
                    _, _, _, _ = self.eval_delayed_env.step(action)
                    self.agent.eval_temporary_buffer.actions.append(action)
                elif self.eval_local_step == self.total_delayed_steps:
                    action = np.zeros_like(self.eval_delayed_env.action_space.sample())
                    next_observed_state, _, _, _ = self.eval_delayed_env.step(action)
                    self.agent.eval_temporary_buffer.actions.append(action)
                    self.agent.eval_temporary_buffer.states.append(next_observed_state)
                else:
                    last_observed_state = self.agent.eval_temporary_buffer.states[-1]
                    first_action_idx = len(self.agent.eval_temporary_buffer.actions) - self.total_delayed_steps
                    augmented_state = self.agent.eval_temporary_buffer.get_augmented_state(last_observed_state,
                                                                                          first_action_idx)
                    action = self.agent.get_action(augmented_state, evaluation=True)
                    next_observed_state, reward, done, _ = self.eval_delayed_env.step(action)
                    self.agent.eval_temporary_buffer.actions.append(action)
                    self.agent.eval_temporary_buffer.states.append(next_observed_state)
                    episode_reward += reward

            reward_list.append(episode_reward)

        avg_reward = sum(reward_list) / len(reward_list)
        log_to_txt(self.args.env_name, self.args.random_seed, self.total_step, avg_reward)
        print("Eval  |  Total Steps {}  |  Episodes {}  |  Average Reward {:.2f}  |  Max Reward {:.2f}  |  "
              "Min Reward {:.2f}".format(self.total_step, self.episode, avg_reward,
                                         max(reward_list), min(reward_list)))