jangwon-kim-cocel committed on
Commit
1eefeba
·
verified ·
1 Parent(s): 4949b66

Upload 14 files

Files changed (15)
  1. .gitattributes +2 -0
  2. LICENSE +21 -0
  3. README.md +67 -3
  4. bpql.py +99 -0
  5. figures/neurips_logo.png +3 -0
  6. figures/plot.png +3 -0
  7. log/temp.md +1 -0
  8. main.py +56 -0
  9. network.py +119 -0
  10. replay_memory.py +44 -0
  11. run.sh +15 -0
  12. temporary_buffer.py +34 -0
  13. trainer.py +165 -0
  14. utils.py +82 -0
  15. wrapper.py +65 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ figures/neurips_logo.png filter=lfs diff=lfs merge=lfs -text
+ figures/plot.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 jangwonkim-cocel
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,67 @@
- ---
- license: mit
- ---
+ <div align="center">
+ <h1>Belief Projection-Based Q-Learning</h1>
+ <a href="https://www.python.org/">
+ <img src="https://img.shields.io/badge/Python-3.8-blue?logo=python&style=flat-square" alt="Python Badge"/>
+ </a>
+ &nbsp;&nbsp;
+ <a href="https://pytorch.org/">
+ <img src="https://img.shields.io/badge/PyTorch-2.1.0-%23EE4C2C?logo=pytorch&style=flat-square" alt="PyTorch Badge"/>
+ </a>
+ &nbsp;&nbsp;
+ <a href="https://proceedings.neurips.cc/paper_files/paper/2023/hash/0252a434b18962c94910c07cd9a7fecc-Abstract-Conference.html">
+ <img src="https://img.shields.io/badge/NeurIPS%202023-Paper-%23007ACC?style=flat-square" alt="NeurIPS 2023 Badge"/>
+ </a>
+ <br/><br/>
+ <img src="./figures/neurips_logo.png" width="200px" style="margin: 0 10px;"/>
+ </div>
+
+ ## [NeurIPS 2023] Official Implementation of Belief Projection-Based Q-Learning (BPQL)
+ This repository contains the official PyTorch implementation of **BPQL**, introduced in the paper **_Belief Projection-Based Reinforcement Learning for Environments with Delayed Feedback_** by Jangwon Kim et al., presented at Advances in Neural Information Processing Systems (NeurIPS), 2023.
+
+
+ ## 📄 Paper Link
+ >The paper is available here: https://proceedings.neurips.cc/paper_files/paper/2023/hash/0252a434b18962c94910c07cd9a7fecc-Abstract-Conference.html
+
+
+ ## 🚀 Achieves S.O.T.A. Performance, Yet Remains Simple to Implement
+
+ * **Supports observation delay, action delay, and their combination**
+ * **Performance Plot ⬇️**
+ <p align="center">
+ <img src="./figures/plot.png" alt="BPQL Performance Plot" width="600"/>
+ </p>
+
+
+ ## ▶️ How to Run
+ ### Option 1: Run the script file
+ ```
+ chmod +x run.sh
+ ./run.sh
+ ```
+
+ ### Option 2: Run main.py with arguments
+ ```
+ python main.py --env-name HalfCheetah-v3 --random-seed 2023 --obs-delayed-steps 5 --act-delayed-steps 4 --max-step 1000000
+ ```
+ ---
+
+ ## ✅ Test Environment
+ ```
+ python == 3.8.10
+ gym == 0.26.2
+ mujoco_py == 2.1.2.14
+ pytorch == 2.1.0
+ numpy == 1.24.3
+ ```
+
+ ## 📚 Citation
+ ```
+ @inproceedings{kim2023cocel,
+ author = {Kim, Jangwon and Kim, Hangyeol and Kang, Jiwook and Baek, Jongchan and Han, Soohee},
+ booktitle = {Advances in Neural Information Processing Systems},
+ pages = {678--696},
+ title = {Belief Projection-Based Reinforcement Learning for Environments with Delayed Feedback},
+ volume = {36},
+ year = {2023}
+ }
+ ```
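
> Editor's note: the README bullet above mentions support for observation delay, action delay, and their combination. The sketch below is a minimal, self-contained illustration (not part of the commit) of how an action delay can be simulated with a FIFO queue, in the same spirit as the `DelayedEnv` wrapper added in `wrapper.py` later in this commit; the `act_delay` value and the toy `delayed_step` helper are hypothetical.

```python
# Minimal sketch: simulating an action delay of `act_delay` steps with a FIFO queue.
from collections import deque
import numpy as np

act_delay = 3                                        # hypothetical delay of 3 steps
action_queue = deque(np.zeros((act_delay, 1)), maxlen=act_delay)  # pre-filled with 'no-op' actions

def delayed_step(action):
    """Return the action that would actually reach the environment at this step."""
    executed = action_queue.popleft()                # action chosen `act_delay` steps ago
    action_queue.append(action)                      # the current action takes effect later
    return executed

print(delayed_step(np.array([1.0])))                 # -> [0.], a queued 'no-op' placeholder
```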
bpql.py ADDED
@@ -0,0 +1,99 @@
+ import torch
+ import torch.nn.functional as F
+ from replay_memory import ReplayMemory
+ from network import Twin_Q_net, GaussianPolicy
+ from temporary_buffer import TemporaryBuffer
+ from utils import hard_update, soft_update
+
+
+ class BPQLAgent:  # SAC for the base learning algorithm
+     def __init__(self, args, state_dim, action_dim, action_bound, action_space, device):
+         self.args = args
+
+         self.state_dim = state_dim
+         self.action_dim = action_dim
+         self.action_bound = action_bound
+
+         self.device = device
+         self.replay_memory = ReplayMemory(args.obs_delayed_steps + args.act_delayed_steps, state_dim, action_dim, device, args.buffer_size)
+         self.temporary_buffer = TemporaryBuffer(args.obs_delayed_steps + args.act_delayed_steps)
+         self.eval_temporary_buffer = TemporaryBuffer(args.obs_delayed_steps + args.act_delayed_steps)
+         self.batch_size = args.batch_size
+
+         self.gamma = args.gamma
+         self.tau = args.tau
+
+         self.actor = GaussianPolicy(args, args.obs_delayed_steps + args.act_delayed_steps, state_dim, action_dim, action_bound, args.hidden_dims, F.relu, device).to(device)
+         self.critic = Twin_Q_net(state_dim, action_dim, device, args.hidden_dims).to(device)  # Network for the beta Q-values.
+         self.target_critic = Twin_Q_net(state_dim, action_dim, device, args.hidden_dims).to(device)
+
+         self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=args.actor_lr)
+         self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=args.critic_lr)
+
+         # Automated Entropy Adjustment for Maximum Entropy RL
+         if args.automating_temperature is True:
+             self.target_entropy = -torch.prod(torch.Tensor(action_space.shape)).to(device)
+             self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
+             self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=args.temperature_lr)
+         else:
+             self.log_alpha = torch.log(torch.tensor(args.temperature, device=device, dtype=torch.float32))
+
+         hard_update(self.critic, self.target_critic)
+
+     def get_action(self, state, evaluation=True):
+         with torch.no_grad():
+             if evaluation:
+                 _, _, action = self.actor.sample(state)
+             else:
+                 action, _, _ = self.actor.sample(state)
+         return action.cpu().numpy()[0]
+
+     def train_actor(self, augmented_states, states, train_alpha=True):
+         self.actor_optimizer.zero_grad()
+         actions, log_pis, _ = self.actor.sample(augmented_states)
+         q_values_A, q_values_B = self.critic(states, actions)
+         q_values = torch.min(q_values_A, q_values_B)
+
+         actor_loss = (self.log_alpha.exp().detach() * log_pis - q_values).mean()
+         actor_loss.backward()
+         self.actor_optimizer.step()
+
+         if train_alpha:
+             self.alpha_optimizer.zero_grad()
+             alpha_loss = -(self.log_alpha.exp() * (log_pis + self.target_entropy).detach()).mean()
+             alpha_loss.backward()
+             self.alpha_optimizer.step()
+         else:
+             alpha_loss = torch.tensor(0.)
+
+         return actor_loss.item(), alpha_loss.item()
+
+     def train_critic(self, actions, rewards, next_augmented_states, dones, states, next_states):
+         self.critic_optimizer.zero_grad()
+         with torch.no_grad():
+             next_actions, next_log_pis, _ = self.actor.sample(next_augmented_states)
+             next_q_values_A, next_q_values_B = self.target_critic(next_states, next_actions)
+             next_q_values = torch.min(next_q_values_A, next_q_values_B) - self.log_alpha.exp() * next_log_pis
+             target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
+
+         q_values_A, q_values_B = self.critic(states, actions)
+         critic_loss = ((q_values_A - target_q_values)**2).mean() + ((q_values_B - target_q_values)**2).mean()
+
+         critic_loss.backward()
+         self.critic_optimizer.step()
+
+         return critic_loss.item()  # 2 * Squared-Loss = (2*|TD-error|^2)
+
+     def train(self):
+         augmented_states, actions, rewards, next_augmented_states, dones, states, next_states = self.replay_memory.sample(self.batch_size)
+
+         critic_loss = self.train_critic(actions, rewards, next_augmented_states, dones, states, next_states)
+         if self.args.automating_temperature is True:
+             actor_loss, log_alpha_loss = self.train_actor(augmented_states, states, train_alpha=True)
+         else:
+             actor_loss, log_alpha_loss = self.train_actor(augmented_states, states, train_alpha=False)
+
+         soft_update(self.critic, self.target_critic, self.tau)
+
+         return critic_loss, actor_loss, log_alpha_loss
+
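
> Editor's note: for reference, the critic update in `train_critic` above corresponds to the soft Bellman target below, with the key BPQL asymmetry that the bootstrap action is sampled from the policy conditioned on the *augmented* state while the (beta) Q-functions are evaluated on the corresponding *true* state. The notation here is ours, not from the commit.

```latex
% x_{t+1}: augmented state, s_{t+1}: true state, \bar{Q}_A, \bar{Q}_B: target critics
y_t = r_t + \gamma \, (1 - d_t)\Big[\min_{i \in \{A,B\}} \bar{Q}_i\big(s_{t+1}, a_{t+1}\big)
      - \alpha \log \pi\big(a_{t+1} \mid x_{t+1}\big)\Big],
\qquad a_{t+1} \sim \pi(\cdot \mid x_{t+1}).
```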
figures/neurips_logo.png ADDED

Git LFS Details

  • SHA256: 97ea0856f60827c33d8bf8b128b963313309939e929535f9b278d6470506b6e9
  • Pointer size: 131 Bytes
  • Size of remote file: 111 kB
figures/plot.png ADDED

Git LFS Details

  • SHA256: cc85738caa7a406778595a149bcd6d3bd21368d71115e97f79e3413a1a1a8605
  • Pointer size: 131 Bytes
  • Size of remote file: 261 kB
log/temp.md ADDED
@@ -0,0 +1 @@
+ temp.
main.py ADDED
@@ -0,0 +1,56 @@
+ import argparse
+ import torch
+ from bpql import BPQLAgent
+ from trainer import Trainer
+ from utils import set_seed, make_delayed_env
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument('--env-name', default='HalfCheetah-v3', type=str)
+
+     parser.add_argument('--obs-delayed-steps', default=4, type=int)  # Delayed timesteps (Observation, Reward)
+     parser.add_argument('--act-delayed-steps', default=5, type=int)  # Delayed timesteps (Action)
+
+     parser.add_argument('--random-seed', default=-1, type=int)
+     parser.add_argument('--eval_flag', default=True, type=bool)
+     parser.add_argument('--eval-freq', default=5000, type=int)
+     parser.add_argument('--eval-episode', default=5, type=int)
+     parser.add_argument('--automating-temperature', default=True, type=bool)
+     parser.add_argument('--temperature', default=0.2, type=float)
+     parser.add_argument('--start-step', default=10000, type=int)
+     parser.add_argument('--max-step', default=1000000, type=int)
+     parser.add_argument('--update_after', default=1000, type=int)
+     parser.add_argument('--hidden-dims', default=(256, 256))
+     parser.add_argument('--batch-size', default=256, type=int)
+     parser.add_argument('--buffer-size', default=1000000, type=int)
+     parser.add_argument('--update-every', default=50, type=int)
+     parser.add_argument('--log_std_bound', default=[-20, 2])
+     parser.add_argument('--gamma', default=0.99, type=float)
+     parser.add_argument('--actor-lr', default=3e-4, type=float)
+     parser.add_argument('--critic-lr', default=3e-4, type=float)
+     parser.add_argument('--temperature-lr', default=3e-4, type=float)
+     parser.add_argument('--tau', default=0.005, type=float)
+     parser.add_argument('--show-loss', default=False, type=bool)
+     args = parser.parse_args()
+
+     # Set Device
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     # Set Seed
+     random_seed = set_seed(args.random_seed)
+
+     # Create Delayed Environment
+     env, eval_env = make_delayed_env(args, random_seed, obs_delayed_steps=args.obs_delayed_steps, act_delayed_steps=args.act_delayed_steps)
+
+     state_dim = env.observation_space.shape[0]
+     action_dim = env.action_space.shape[0]
+     action_bound = [env.action_space.low[0], env.action_space.high[0]]
+
+     print(f"Environment: {args.env_name}, Obs. Delayed Steps: {args.obs_delayed_steps}, Act. Delayed Steps: {args.act_delayed_steps}, Random Seed: {args.random_seed}", "\n")
+
+     # Create Agent
+     agent = BPQLAgent(args, state_dim, action_dim, action_bound, env.action_space, device)
+
+     # Create Trainer & Train
+     trainer = Trainer(env, eval_env, agent, args)
+     trainer.train()
network.py ADDED
@@ -0,0 +1,119 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.distributions import Normal
+ from utils import weight_init
+
+
+ class Twin_Q_net(nn.Module):
+     def __init__(self, state_dim, action_dim, device, hidden_dims=(256, 256), activation_fc=F.relu):
+         super(Twin_Q_net, self).__init__()
+         self.device = device
+
+         self.activation_fc = activation_fc
+
+         self.input_layer_A = nn.Linear(state_dim + action_dim, hidden_dims[0])
+         self.hidden_layers_A = nn.ModuleList()
+         for i in range(len(hidden_dims)-1):
+             hidden_layer_A = nn.Linear(hidden_dims[i], hidden_dims[i+1])
+             self.hidden_layers_A.append(hidden_layer_A)
+         self.output_layer_A = nn.Linear(hidden_dims[-1], 1)
+
+         self.input_layer_B = nn.Linear(state_dim + action_dim, hidden_dims[0])
+         self.hidden_layers_B = nn.ModuleList()
+         for i in range(len(hidden_dims)-1):
+             hidden_layer_B = nn.Linear(hidden_dims[i], hidden_dims[i+1])
+             self.hidden_layers_B.append(hidden_layer_B)
+         self.output_layer_B = nn.Linear(hidden_dims[-1], 1)
+         self.apply(weight_init)
+
+     def _format(self, state, action):
+         x, u = state, action
+
+         if not isinstance(x, torch.Tensor):
+             x = torch.tensor(x, device=self.device, dtype=torch.float32)
+             x = x.unsqueeze(0)
+
+         if not isinstance(u, torch.Tensor):
+             u = torch.tensor(u, device=self.device, dtype=torch.float32)
+             u = u.unsqueeze(0)
+
+         return x, u
+
+     def forward(self, state, action):
+         x, u = self._format(state, action)
+         x = torch.cat([x, u], dim=1)
+
+         x_A = self.activation_fc(self.input_layer_A(x))
+         for i, hidden_layer_A in enumerate(self.hidden_layers_A):
+             x_A = self.activation_fc(hidden_layer_A(x_A))
+         x_A = self.output_layer_A(x_A)
+
+         x_B = self.activation_fc(self.input_layer_B(x))
+         for i, hidden_layer_B in enumerate(self.hidden_layers_B):
+             x_B = self.activation_fc(hidden_layer_B(x_B))
+         x_B = self.output_layer_B(x_B)
+
+         return x_A, x_B
+
+
+ class GaussianPolicy(nn.Module):
+     def __init__(self, args, delayed_steps, state_dim, action_dim, action_bound,
+                  hidden_dims=(256, 256), activation_fc=F.relu, device='cuda'):
+         super(GaussianPolicy, self).__init__()
+         self.device = device
+
+         self.log_std_min = args.log_std_bound[0]
+         self.log_std_max = args.log_std_bound[1]
+
+         self.activation_fc = activation_fc
+
+         self.input_layer = nn.Linear(state_dim + delayed_steps * action_dim, hidden_dims[0])
+         self.hidden_layers = nn.ModuleList()
+         for i in range(len(hidden_dims)-1):
+             hidden_layer = nn.Linear(hidden_dims[i], hidden_dims[i+1])
+             self.hidden_layers.append(hidden_layer)
+
+         self.mean_layer = nn.Linear(hidden_dims[-1], action_dim)
+         self.log_std_layer = nn.Linear(hidden_dims[-1], action_dim)
+
+         self.action_rescale = torch.as_tensor((action_bound[1] - action_bound[0]) / 2., dtype=torch.float32)
+         self.action_rescale_bias = torch.as_tensor((action_bound[1] + action_bound[0]) / 2., dtype=torch.float32)
+
+         self.apply(weight_init)
+
+     def _format(self, state):
+         x = state
+         if not isinstance(x, torch.Tensor):
+             x = torch.tensor(x, device=self.device, dtype=torch.float32)
+             x = x.unsqueeze(0)
+         return x
+
+     def forward(self, state):
+         x = self._format(state)
+         x = self.activation_fc(self.input_layer(x))
+         for i, hidden_layer in enumerate(self.hidden_layers):
+             x = self.activation_fc(hidden_layer(x))
+         mean = self.mean_layer(x)
+         log_std = self.log_std_layer(x)
+         log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
+         return mean, log_std
+
+     def sample(self, state):
+         mean, log_std = self.forward(state)
+         distribution = Normal(mean, log_std.exp())
+
+         unbounded_action = distribution.rsample()
+
+         bounded_action = torch.tanh(unbounded_action)
+         action = bounded_action * self.action_rescale + self.action_rescale_bias
+
+         log_prob = distribution.log_prob(unbounded_action) - torch.log(self.action_rescale *
+                                                                        (1 - bounded_action.pow(2).clamp(0, 1)) + 1e-6)
+         log_prob = log_prob.sum(dim=1, keepdim=True)
+         mean = torch.tanh(mean) * self.action_rescale + self.action_rescale_bias
+         return action, log_prob, mean
+
+
+
+
replay_memory.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+ import numpy as np
+
+
+ class ReplayMemory:
+     def __init__(self, delayed_steps, state_dim, action_dim, device, capacity=1e6):
+         self.device = device
+         self.capacity = int(capacity)
+         self.size = 0
+         self.position = 0
+
+         self.augmented_state_buffer = np.empty(shape=(self.capacity, state_dim + action_dim * delayed_steps), dtype=np.float32)
+         self.action_buffer = np.empty(shape=(self.capacity, action_dim), dtype=np.float32)
+         self.reward_buffer = np.empty(shape=(self.capacity, 1), dtype=np.float32)
+         self.next_augmented_state_buffer = np.empty(shape=(self.capacity, state_dim + action_dim * delayed_steps), dtype=np.float32)
+         self.done_buffer = np.empty(shape=(self.capacity, 1), dtype=np.float32)
+         self.state_buffer = np.empty(shape=(self.capacity, state_dim), dtype=np.float32)
+         self.next_state_buffer = np.empty(shape=(self.capacity, state_dim), dtype=np.float32)
+
+     def push(self, augmented_state, state, action, reward, next_augmented_state, next_state, done):
+         self.size = min(self.size + 1, self.capacity)
+
+         self.augmented_state_buffer[self.position] = augmented_state
+         self.action_buffer[self.position] = action
+         self.reward_buffer[self.position] = reward
+         self.next_augmented_state_buffer[self.position] = next_augmented_state
+         self.done_buffer[self.position] = done
+         self.state_buffer[self.position] = state
+         self.next_state_buffer[self.position] = next_state
+
+         self.position = (self.position + 1) % self.capacity
+
+     def sample(self, batch_size):
+         idxs = np.random.randint(0, self.size, size=batch_size)
+
+         augmented_states = torch.FloatTensor(self.augmented_state_buffer[idxs]).to(self.device)
+         actions = torch.FloatTensor(self.action_buffer[idxs]).to(self.device)
+         rewards = torch.FloatTensor(self.reward_buffer[idxs]).to(self.device)
+         next_augmented_states = torch.FloatTensor(self.next_augmented_state_buffer[idxs]).to(self.device)
+         dones = torch.FloatTensor(self.done_buffer[idxs]).to(self.device)
+         states = torch.FloatTensor(self.state_buffer[idxs]).to(self.device)
+         next_states = torch.FloatTensor(self.next_state_buffer[idxs]).to(self.device)
+
+         return augmented_states, actions, rewards, next_augmented_states, dones, states, next_states
run.sh ADDED
@@ -0,0 +1,15 @@
+ #!/bin/bash
+ # ----Test Environment----
+ # python == 3.8.10
+ # gym == 0.26.2
+ # mujoco_py == 2.1.2.14
+ # pytorch == 2.1.0
+ # numpy == 1.24.3
+ # ------------------------
+
+ python main.py \
+     --env-name "HalfCheetah-v3" \
+     --random-seed 2023 \
+     --obs-delayed-steps 5 \
+     --act-delayed-steps 4 \
+     --max-step 1000000
temporary_buffer.py ADDED
@@ -0,0 +1,34 @@
+ import numpy as np
+ from collections import deque
+
+
+ class TemporaryBuffer:
+     def __init__(self, delayed_steps):
+         self.d = delayed_steps
+         self.states = deque(maxlen=delayed_steps + 2)
+         self.actions = deque(maxlen=2 * delayed_steps + 1)
+
+     def clear(self):
+         self.states.clear()
+         self.actions.clear()
+
+     def get_augmented_state(self, last_observed_state, first_action_idx):
+         aug_state = np.concatenate([last_observed_state, self.actions[first_action_idx]])
+         for i in range(first_action_idx + 1, first_action_idx + self.d):
+             aug_state = np.concatenate([aug_state, self.actions[i]])
+         return aug_state
+
+     def get_tuple(self):
+         assert len(self.states) == self.d + 2 and len(self.actions) == 2 * self.d + 1
+
+         aug_s = self.get_augmented_state(self.states[0], 0)
+         s = self.states[-2]
+         a = self.actions[self.d]
+
+         next_aug_s = self.get_augmented_state(self.states[1], 1)
+         next_s = self.states[-1]
+
+         self.states.popleft()
+         self.actions.popleft()
+         return aug_s, s, a, next_aug_s, next_s
+
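
> Editor's note: as a self-contained illustration of the layout produced by `get_augmented_state` above, the snippet below mimics its logic with made-up dimensions and values (it does not import the repository code): the augmented state is the last observed state followed by the `d` most recent actions.

```python
# Toy illustration of the augmented-state layout: [state, a_1, ..., a_d].
import numpy as np
from collections import deque

d = 2                                                # hypothetical total delay (obs + act steps)
states = deque([np.array([0.1, 0.2])], maxlen=d + 2)
actions = deque([np.array([1.0]), np.array([2.0])], maxlen=2 * d + 1)

# Same construction as TemporaryBuffer.get_augmented_state(last_observed_state, 0):
aug_state = np.concatenate([states[-1]] + [actions[i] for i in range(d)])
print(aug_state)                                     # -> [0.1 0.2 1.  2. ]
```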
trainer.py ADDED
@@ -0,0 +1,165 @@
+ import numpy as np
+
+ from utils import log_to_txt
+
+
+ class Trainer:
+     def __init__(self, env, eval_env, agent, args):
+         self.args = args
+         self.agent = agent
+
+         self.delayed_env = env
+         self.eval_delayed_env = eval_env
+
+         self.start_step = args.start_step
+         self.update_after = args.update_after
+         self.max_step = args.max_step
+         self.batch_size = args.batch_size
+         self.update_every = args.update_every
+
+         self.eval_flag = args.eval_flag
+         self.eval_episode = args.eval_episode
+         self.eval_freq = args.eval_freq
+
+         self.episode = 0
+         self.total_step = 0
+         self.local_step = 0
+         self.eval_local_step = 0
+         self.eval_num = 0
+         self.finish_flag = False
+
+         self.total_delayed_steps = args.obs_delayed_steps + self.args.act_delayed_steps
+
+     def train(self):
+         # The training process starts here.
+         while not self.finish_flag:
+             self.episode += 1
+             self.local_step = 0
+
+             # Initialize the delayed environment & the temporary buffer.
+             self.delayed_env.reset()
+             self.agent.temporary_buffer.clear()
+             done = False
+
+             # Episode starts here.
+             while not done:
+                 self.local_step += 1
+                 self.total_step += 1
+
+                 if self.local_step < self.total_delayed_steps:  # if t < d
+                     action = np.zeros_like(self.delayed_env.action_space.sample())  # Select the 'no-op' action
+                     _, _, _, _ = self.delayed_env.step(action)
+
+                     self.agent.temporary_buffer.actions.append(action)
+                 elif self.local_step == self.total_delayed_steps:  # if t == d
+                     if self.total_step < self.start_step:
+                         action = self.delayed_env.action_space.sample()
+                     else:
+                         action = np.zeros_like(self.delayed_env.action_space.sample())  # Select the 'no-op' action
+
+                     next_observed_state, _, _, _ = self.delayed_env.step(action)
+                     # s(1) <- Env: a(d)
+                     self.agent.temporary_buffer.actions.append(action)  # Put a(d) into the temporary buffer
+                     self.agent.temporary_buffer.states.append(next_observed_state)  # Put s(1) into the temporary buffer
+                 else:  # if t > d
+                     last_observed_state = self.agent.temporary_buffer.states[-1]
+                     first_action_idx = len(self.agent.temporary_buffer.actions) - self.total_delayed_steps
+
+                     # Get the augmented state(t)
+                     augmented_state = self.agent.temporary_buffer.get_augmented_state(last_observed_state, first_action_idx)
+
+                     if self.total_step < self.start_step:
+                         action = self.delayed_env.action_space.sample()
+                     else:
+                         action = self.agent.get_action(augmented_state, evaluation=False)
+                     # a(t) <- policy: augmented_state(t)
+                     next_observed_state, reward, done, info = self.delayed_env.step(action)
+                     # s(t+1-d), r(t-d) <- Env: a(t)
+                     true_done = 0.0 if self.local_step == self.delayed_env._max_episode_steps + self.args.obs_delayed_steps else float(done)
+
+                     self.agent.temporary_buffer.actions.append(action)  # Put a(t) into the temporary buffer
+                     self.agent.temporary_buffer.states.append(next_observed_state)  # Put s(t+1-d) into the temporary buffer
+
+                     if self.local_step > 2 * self.total_delayed_steps:  # if t > 2d
+                         augmented_s, s, a, next_augmented_s, next_s = self.agent.temporary_buffer.get_tuple()
+                         # aug_s(t-d), s(t-d), a(t-d), aug_s(t+1-d), s(t+1-d) <- Temporary Buffer
+                         self.agent.replay_memory.push(augmented_s, s, a, reward, next_augmented_s, next_s, true_done)
+                         # Store (aug_s(t-d), s(t-d), a(t-d), r(t-d), aug_s(t+1-d), s(t+1-d)) in the replay memory.
+
+                 # Update parameters
+                 if self.agent.replay_memory.size >= self.batch_size and self.total_step >= self.update_after and \
+                         self.total_step % self.update_every == 0:
+                     total_actor_loss = 0
+                     total_critic_loss = 0
+                     total_log_alpha_loss = 0
+                     for i in range(self.update_every):
+                         # Train the policy and the beta Q-network (critic).
+                         critic_loss, actor_loss, log_alpha_loss = self.agent.train()
+                         total_critic_loss += critic_loss
+                         total_actor_loss += actor_loss
+                         total_log_alpha_loss += log_alpha_loss
+
+                     # Print the losses.
+                     if self.args.show_loss:
+                         print("Loss | Actor loss {:.3f} | Critic loss {:.3f} | Log-alpha loss {:.3f}"
+                               .format(total_actor_loss / self.update_every, total_critic_loss / self.update_every,
+                                       total_log_alpha_loss / self.update_every))
+
+                 # Evaluate.
+                 if self.eval_flag and self.total_step % self.eval_freq == 0:
+                     self.evaluate()
+
+                 # Raise the finish flag.
+                 if self.total_step == self.max_step:
+                     self.finish_flag = True
+
+     def evaluate(self):
+         # Evaluation process
+         self.eval_num += 1
+         reward_list = []
+
+         for epi in range(self.eval_episode):
+             episode_reward = 0
+             self.eval_delayed_env.reset()
+             self.agent.eval_temporary_buffer.clear()
+             done = False
+             self.eval_local_step = 0
+
+             while not done:
+                 self.eval_local_step += 1
+                 if self.eval_local_step < self.total_delayed_steps:
+                     action = np.zeros_like(self.delayed_env.action_space.sample())
+                     _, _, _, _ = self.eval_delayed_env.step(action)
+                     self.agent.eval_temporary_buffer.actions.append(action)
+                 elif self.eval_local_step == self.total_delayed_steps:
+                     action = np.zeros_like(self.eval_delayed_env.action_space.sample())
+                     next_observed_state, _, _, _ = self.eval_delayed_env.step(action)
+                     self.agent.eval_temporary_buffer.actions.append(action)
+                     self.agent.eval_temporary_buffer.states.append(next_observed_state)
+                 else:
+                     last_observed_state = self.agent.eval_temporary_buffer.states[-1]
+                     first_action_idx = len(self.agent.eval_temporary_buffer.actions) - self.total_delayed_steps
+                     augmented_state = self.agent.eval_temporary_buffer.get_augmented_state(last_observed_state,
+                                                                                            first_action_idx)
+                     action = self.agent.get_action(augmented_state, evaluation=True)
+                     next_observed_state, reward, done, _ = self.eval_delayed_env.step(action)
+                     self.agent.eval_temporary_buffer.actions.append(action)
+                     self.agent.eval_temporary_buffer.states.append(next_observed_state)
+                     episode_reward += reward
+
+             reward_list.append(episode_reward)
+
+         log_to_txt(self.args.env_name, self.args.random_seed, self.total_step, sum(reward_list) / len(reward_list))
+         print("Eval | Total Steps {} | Episodes {} | Average Reward {:.2f} | Max reward {:.2f} | "
+               "Min reward {:.2f}".format(self.total_step, self.episode, sum(reward_list) / len(reward_list),
+                                          max(reward_list), min(reward_list)))
+
utils.py ADDED
@@ -0,0 +1,82 @@
+ import numpy as np
+ import random
+ import torch
+ import torch.nn as nn
+ from wrapper import DelayedEnv
+
+
+ def weight_init(m):
+     """Custom weight init for Conv2D and Linear layers.
+     Reference: https://github.com/MishaLaskin/rad/blob/master/curl_sac.py"""
+
+     if isinstance(m, nn.Linear):
+         nn.init.orthogonal_(m.weight.data)
+         m.bias.data.fill_(0.0)
+     elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+         # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
+         assert m.weight.size(2) == m.weight.size(3)
+         m.weight.data.fill_(0.0)
+         m.bias.data.fill_(0.0)
+         mid = m.weight.size(2) // 2
+         gain = nn.init.calculate_gain('relu')
+         nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
+
+
+ def hard_update(network, target_network):
+     with torch.no_grad():
+         for param, target_param in zip(network.parameters(), target_network.parameters()):
+             target_param.data.copy_(param.data)
+
+
+ def soft_update(network, target_network, tau):
+     with torch.no_grad():
+         for param, target_param in zip(network.parameters(), target_network.parameters()):
+             target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
+
+
+ def set_seed(random_seed):
+     if random_seed <= 0:
+         random_seed = np.random.randint(1, 9999)
+     else:
+         random_seed = random_seed
+
+     torch.manual_seed(random_seed)
+     np.random.seed(random_seed)
+     random.seed(random_seed)
+
+     return random_seed
+
+
+ def make_env(env_name, random_seed):
+     import gym
+     # openai gym
+     env = gym.make(env_name)
+     env.seed(random_seed)
+     env.action_space.seed(random_seed)
+
+     eval_env = gym.make(env_name)
+     eval_env.seed(random_seed)
+     eval_env.action_space.seed(random_seed)
+
+     return env, eval_env
+
+
+ def make_delayed_env(args, random_seed, obs_delayed_steps, act_delayed_steps):
+     import gym
+     # openai gym
+     env_name = args.env_name
+
+     env = gym.make(env_name)
+     delayed_env = DelayedEnv(env, seed=random_seed, obs_delayed_steps=obs_delayed_steps, act_delayed_steps=act_delayed_steps)
+
+     eval_env = gym.make(env_name)
+     eval_delayed_env = DelayedEnv(eval_env, seed=random_seed, obs_delayed_steps=obs_delayed_steps, act_delayed_steps=act_delayed_steps)
+
+     return delayed_env, eval_delayed_env
+
+ def log_to_txt(env_name, random_seed, total_step, result):
+     seed = '(' + str(random_seed) + ')'
+     f = open('./log/' + env_name + '_seed' + seed + '.txt', 'a')
+     log = str(total_step) + ' ' + str(result) + '\n'
+     f.write(log)
+     f.close()
wrapper.py ADDED
@@ -0,0 +1,65 @@
+ from collections import deque
+ import gym
+ import numpy as np
+
+
+ class DelayedEnv(gym.Wrapper):
+     def __init__(self, env, seed, obs_delayed_steps, act_delayed_steps):
+         super(DelayedEnv, self).__init__(env)
+         assert obs_delayed_steps + act_delayed_steps > 0
+         self.env.action_space.seed(seed)
+
+         self.observation_space = self.env.observation_space
+         self.action_space = self.env.action_space
+
+         self._max_episode_steps = self.env._max_episode_steps
+
+         self.obs_buffer = deque(maxlen=obs_delayed_steps)
+         self.reward_buffer = deque(maxlen=obs_delayed_steps)
+         self.done_buffer = deque(maxlen=obs_delayed_steps)
+
+         self.action_buffer = deque(maxlen=act_delayed_steps)
+
+         self.obs_delayed_steps = obs_delayed_steps
+         self.act_delayed_steps = act_delayed_steps
+
+     def reset(self):
+         for _ in range(self.act_delayed_steps):
+             self.action_buffer.append(np.zeros_like(self.env.action_space.sample()))
+
+         init_state, _ = self.env.reset()
+         for _ in range(self.obs_delayed_steps):
+             self.obs_buffer.append(init_state)
+             self.reward_buffer.append(0)
+             self.done_buffer.append(False)
+         return init_state
+
+     def step(self, action):
+         if self.act_delayed_steps > 0:
+             delayed_action = self.action_buffer.popleft()
+             self.action_buffer.append(action)
+         else:
+             delayed_action = action
+
+         current_obs, current_reward, current_terminated, current_truncated, _ = self.env.step(delayed_action)
+         current_done = current_terminated or current_truncated
+
+         if self.obs_delayed_steps > 0:
+             delayed_obs = self.obs_buffer.popleft()
+             delayed_reward = self.reward_buffer.popleft()
+             delayed_done = self.done_buffer.popleft()
+
+             self.obs_buffer.append(current_obs)
+             self.reward_buffer.append(current_reward)
+             self.done_buffer.append(current_done)
+         else:
+             delayed_obs = current_obs
+             delayed_reward = current_reward
+             delayed_done = current_done
+
+         return delayed_obs, delayed_reward, delayed_done, {'current_obs': current_obs, 'current_reward': current_reward,
+                                                            'current_done': current_done}
+
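
> Editor's note: a minimal usage sketch of the `DelayedEnv` wrapper added above, assuming the dependencies listed in run.sh (gym==0.26.2 with a MuJoCo install); the delay values mirror the run.sh example and the loop is illustrative only.

```python
# Hedged usage sketch: wrap a gym environment with observation and action delays.
import gym
from wrapper import DelayedEnv

env = DelayedEnv(gym.make('HalfCheetah-v3'), seed=2023,
                 obs_delayed_steps=5, act_delayed_steps=4)

obs = env.reset()                       # returns the undelayed initial state
for _ in range(10):
    action = env.action_space.sample()
    # Observations/rewards arrive 5 steps late; actions take effect 4 steps later.
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()
```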