jangwon-kim-cocel committed on
Commit
0e2f05d
·
verified ·
1 Parent(s): ad18363

Upload 10 files

Browse files
Files changed (11)
  1. .gitattributes +1 -0
  2. README.md +176 -0
  3. figures/performance.png +3 -0
  4. figures/runtime.png +0 -0
  5. figures/tmp.md +1 -0
  6. main.py +78 -0
  7. network.py +145 -0
  8. replay_memory.py +67 -0
  9. trainer.py +112 -0
  10. ud7.py +202 -0
  11. utils.py +63 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ figures/performance.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,176 @@
<div align="center">
  <h1>UD7</h1>
  <h3>Provable Generalization of Clipped Double Q-Learning for Variance Reduction and Sample Efficiency</h3>

  <a href="https://www.python.org/">
    <img src="https://img.shields.io/badge/Python-3.7+-blue?logo=python&style=flat-square" alt="Python Badge"/>
  </a>
  &nbsp;&nbsp;
  <a href="https://pytorch.org/">
    <img src="https://img.shields.io/badge/PyTorch-1.8+-EE4C2C?logo=pytorch&style=flat-square" alt="PyTorch Badge"/>
  </a>
  &nbsp;&nbsp;
  <a href="https://www.sciencedirect.com/journal/neurocomputing">
    <img src="https://img.shields.io/badge/Neurocomputing-Published-success?style=flat-square" alt="Neurocomputing Badge"/>
  </a>
  &nbsp;&nbsp;
  <a href="https://www.elsevier.com/">
    <img src="https://img.shields.io/badge/Elsevier-Journal-orange?style=flat-square" alt="Elsevier Badge"/>
  </a>
</div>

---

## Neurocomputing — PyTorch Implementation

This repository contains a PyTorch implementation of **UD7**, the algorithm proposed in the paper:

> **Provable Generalization of Clipped Double Q-Learning for Variance Reduction and Sample Efficiency**
> Jangwon Kim, Jiseok Jeong, Soohee Han
> *Neurocomputing*, Volume 673, 7 April 2026, 132772

### Paper Link
https://www.sciencedirect.com/science/article/abs/pii/S0925231226001694

---

**UD7** is an off-policy actor–critic algorithm that builds on a TD7-style training pipeline while replacing the critic target formulation with **UBOC** (Uncertainty-Based Overestimation Correction).

---

## 1) Background: Clipped Double Q-Learning (CDQ)

Clipped double Q-learning is a widely used bias-correction technique in actor–critic methods (e.g., TD3). It maintains **two critics** and uses the **minimum** of the two as the TD target:

$$
y_{\text{CDQ}}(s_t,a_t)=r_t+\gamma \min_{i\in\{1,2\}} \bar Q_i(s_{t+1}, a_{t+1})
$$

### Strengths (why CDQ is popular)
- **Effective overestimation control:** taking the minimum is conservative, which often prevents exploding Q-values.
- **Robust baseline behavior:** it works well across many continuous-control tasks.

### Limitations (what the paper highlights)
- **High variance:** when the critics are poorly learned early in training, the min operator can yield high-variance TD targets, destabilizing TD learning and reducing sample efficiency.

**UBOC is motivated by a concrete question:**
> Can we obtain **the same expected target value as CDQ**, but with **smaller variance**?

---

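In code, the CDQ target above is a one-liner; a minimal PyTorch sketch (the function name and tensor shapes are illustrative, not part of this repo):

```python
import torch

def cdq_target(rewards, q1, q2, dones, gamma=0.99):
    """Clipped double-Q TD target: r + gamma * (1 - done) * min(Q1, Q2)."""
    return rewards + (1.0 - dones) * gamma * torch.min(q1, q2)

# Toy batch of 3 transitions with target-critic outputs q1, q2 at (s', a').
r = torch.ones(3, 1)
q1 = torch.tensor([[1.0], [2.0], [3.0]])
q2 = torch.tensor([[2.0], [1.0], [3.0]])
y = cdq_target(r, q1, q2, torch.zeros(3, 1), gamma=0.5)
```

The element-wise minimum is what introduces both the downward bias and the extra variance discussed below.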
## 2) UBOC: Uncertainty-Based Overestimation Correction (Detailed)

UBOC views the critic outputs as a **distribution of Q estimates** (because function approximation is noisy).
Instead of using `min(Q1, Q2)`, UBOC uses **N critics** to estimate:
- a **mean** $m$,
- an **(unbiased) standard deviation** $\hat s$,

and then forms a corrected value:

$$
Q_{\text{corrected}} = m - x\cdot \hat s
$$

where $x>0$ controls the degree of conservativeness.

### 2.1 Expectation equivalence to clipped double-Q

Under the assumption that the critic estimates behave like i.i.d. samples from a normal distribution, one can derive:

$$
\mathbb{E}\left[\min(Q_A, Q_B)\right]=\mathbb{E}\left[m - \frac{\hat s}{\sqrt{\pi}}\right]
$$

This is the key insight:
- choosing $x=1/\sqrt{\pi}$ makes the corrected estimate **match CDQ in expectation**.

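The left-hand side of this identity has a well-known closed form: for two i.i.d. normal estimates, $\mathbb{E}[\min(Q_A,Q_B)] = \mu - \sigma/\sqrt{\pi}$, which is exactly what the $x=1/\sqrt{\pi}$ correction targets. A quick Monte Carlo sanity check (the normality assumption and the constants here are illustrative, not from the repo):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, n = 2.0, 0.5, 1_000_000

# Two i.i.d. critic estimates per target, as in the N = 2 (CDQ) case.
qa = rng.normal(mu, sigma, n)
qb = rng.normal(mu, sigma, n)

emp = np.minimum(qa, qb).mean()        # empirical E[min(Q_A, Q_B)]
closed = mu - sigma / np.sqrt(np.pi)   # closed form: mu - sigma / sqrt(pi)

print(f"empirical  : {emp:.4f}")
print(f"closed-form: {closed:.4f}")
```

The two numbers agree up to Monte Carlo error, confirming the downward shift of the min operator is $\sigma/\sqrt{\pi}$ in expectation.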
### 2.2 Variance reduction (provable)

The paper further proves that the estimator

$$
m - \frac{\hat s}{\sqrt{\pi}}
$$

has **strictly smaller variance** than the CDQ minimum-based target, and that the **variance gap is strictly positive for all $N\ge 2$**.

As $N\to\infty$, the maximum achievable variance reduction is upper-bounded by:

$$
\sigma^2\left(1-\frac{1}{\pi}\right)
$$

**This means that:**
- UBOC does not merely "bias-correct"; it also **reduces noise** in the TD targets.
- This is especially important early in training, where noisy targets can derail learning.

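The variance gap can also be observed empirically. A minimal Monte Carlo sketch for the $N=2$ case, comparing the min-based target against the mean-minus-scaled-deviation target (standard-normal estimates are an illustrative assumption):

```python
import numpy as np

rng = np.random.default_rng(0)
sigma, n = 1.0, 1_000_000

# N = 2 i.i.d. critic estimates per target.
q = rng.normal(0.0, sigma, size=(n, 2))

cdq = q.min(axis=1)                    # CDQ: element-wise minimum
m = q.mean(axis=1)                     # ensemble sample mean
s_hat = q.std(axis=1, ddof=1)          # sample standard deviation (ddof=1)
uboc = m - s_hat / np.sqrt(np.pi)      # corrected estimate: m - s_hat / sqrt(pi)

print(f"Var[min target]  = {cdq.var():.4f}")
print(f"Var[UBOC target] = {uboc.var():.4f}")
```

The min-based estimator's variance lands near the theoretical $\sigma^2(1 - 1/\pi)$, while the corrected estimator's variance comes out strictly smaller.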
### 2.3 UBOC TD target (what you implement)

Using the N target critics $Q_1,\dots, Q_N$, compute:

**Mean**

$$
m(s,a) = \frac{1}{N}\sum_{i=1}^N Q_i(s,a)
$$

**Unbiased sample variance (approximation)**

$$
\hat v(s,a)=\frac{1}{N-1}\sum_{i=1}^N \left( Q_i(s,a)-m(s,a)\right)^2
$$

Then the **UBOC target** is:

$$
y_{\text{UBOC}}(s_t,a_t)=r_t + \gamma\left(m(s_{t+1},a_{t+1}) - \sqrt{\frac{\hat v(s_{t+1},a_{t+1})}{\pi}}\right)
$$

where $a_{t+1}$ can be computed with target policy smoothing.

This gives a *dynamic* bias correction driven by critic uncertainty.

---

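This target reduces to a few tensor operations. A minimal standalone PyTorch sketch (`uboc_target` is a hypothetical helper; the repo computes the same quantity inline in `ud7.py`, using the constant `0.5641896` ≈ $1/\sqrt{\pi}$):

```python
import math
import torch

def uboc_target(rewards, q_targets, dones, gamma=0.99):
    """UBOC TD target from an ensemble of target-critic outputs.

    q_targets: tensor of shape (batch, N) holding Q_1..Q_N at (s', a').
    """
    m = q_targets.mean(dim=1, keepdim=True)                # sample mean m
    v = q_targets.var(dim=1, unbiased=True, keepdim=True)  # unbiased sample variance v_hat
    corrected = m - torch.sqrt(v / math.pi)                # m - sqrt(v_hat / pi)
    return rewards + (1.0 - dones) * gamma * corrected

# Toy batch of 4 transitions, N = 5 critics.
q = torch.randn(4, 5)
r = torch.ones(4, 1)
d = torch.zeros(4, 1)
y = uboc_target(r, q, d)
```

When the critics agree exactly, $\hat v = 0$ and the target collapses to the plain ensemble mean; disagreement dynamically increases the conservative correction.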
## 3) UD7: TD7 + UBOC Targets

**UD7** integrates UBOC into a TD7-style pipeline and emphasizes strong sample efficiency.

- UD7 inherits the TD7 machinery (state–action embeddings, LAP, policy checkpointing) for practical stability and efficiency.
- **The main difference from TD7 is the critic training target:** UD7 uses **UBOC targets** computed over a multi-critic ensemble (**N = 5** by default).

> If you already have a TD7 baseline, UD7 is best viewed as:
> **"swap the target rule, use N critics, and keep the rest of the training recipe."**

---

## 4) Performance

<div align="center">
  <img src="figures/performance.png" alt="Fig. 1 — Performance comparison on MuJoCo benchmarks" width="800"/>
</div>

---

## 5) Computational Overhead

Runtime comparison (measured on an RTX 3090 Ti with an Intel i7-12700):

<div align="center">
  <img src="figures/runtime.png" alt="Fig. 2 — Runtime comparison" width="300"/>
</div>

---

## Citation
```bibtex
@article{kim2026provable,
  title={Provable generalization of clipped double Q-learning for variance reduction and sample efficiency},
  author={Kim, Jangwon and Jeong, Jiseok and Han, Soohee},
  journal={Neurocomputing},
  volume={673},
  pages={132772},
  year={2026},
  publisher={Elsevier}
}
```
figures/performance.png ADDED

Git LFS Details

  • SHA256: 4c69f84acacf25e82bc66fc8cf37cdf1ea3d8b4b64425db865f28019fb6418e8
  • Pointer size: 131 Bytes
  • Size of remote file: 376 kB
figures/runtime.png ADDED
figures/tmp.md ADDED
@@ -0,0 +1 @@
+
main.py ADDED
@@ -0,0 +1,78 @@
import argparse
import torch
from ud7 import UD7
from trainer import Trainer
from utils import set_seed, make_env


def get_parameters():
    parser = argparse.ArgumentParser()

    # Environment setting
    parser.add_argument('--env-name', default='Humanoid-v4')
    parser.add_argument('--random-seed', default=-1, type=int)

    # UBOC
    parser.add_argument('--num_critics', default=5, type=int)

    # Checkpointing
    # Note: argparse's type=bool treats any non-empty string as True;
    # these flags are effectively controlled through their defaults.
    parser.add_argument('--use_checkpoints', default=True, type=bool)
    parser.add_argument('--max-eps-when-checkpointing', default=20, type=int)
    parser.add_argument('--steps-before-checkpointing', default=750000, type=int)
    parser.add_argument('--reset-weight', default=0.9, type=float)

    # LAP
    parser.add_argument('--alpha', default=0.4, type=float)
    parser.add_argument('--min_priority', default=1, type=float)

    # Generic
    parser.add_argument('--target-update-rate', default=250, type=int)
    parser.add_argument('--start-steps', default=25000, type=int)
    parser.add_argument('--max-steps', default=5000000, type=int)
    parser.add_argument('--zs-dim', default=256, type=int)
    parser.add_argument('--critic-hidden-dims', default=(256, 256))
    parser.add_argument('--policy-hidden-dims', default=(256, 256))
    parser.add_argument('--encoder-hidden-dims', default=(256, 256))
    parser.add_argument('--hidden-dims', default=(256, 256))
    parser.add_argument('--batch-size', default=256, type=int)
    parser.add_argument('--buffer-size', default=1000000, type=int)
    parser.add_argument('--policy-update-delay', default=2, type=int)
    parser.add_argument('--gamma', default=0.99, type=float)
    parser.add_argument('--actor-lr', default=0.0003, type=float)
    parser.add_argument('--critic-lr', default=0.0003, type=float)
    parser.add_argument('--encoder-lr', default=0.0003, type=float)

    # TD3
    parser.add_argument('--act-noise-scale', default=0.1, type=float)
    parser.add_argument('--target-noise-scale', default=0.2, type=float)
    parser.add_argument('--target-noise-clip', default=0.5, type=float)

    # Log & evaluation
    parser.add_argument('--show-loss', default=False, type=bool)
    parser.add_argument('--eval_flag', default=True, type=bool)
    parser.add_argument('--eval-freq', default=5000, type=int)
    parser.add_argument('--eval-episode', default=10, type=int)

    param = parser.parse_args()

    return param


def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    random_seed = set_seed(args.random_seed)
    env, eval_env = make_env(args.env_name, random_seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    action_bound = [env.action_space.low[0], env.action_space.high[0]]

    agent = UD7(state_dim, action_dim, action_bound, device, args)

    trainer = Trainer(env, eval_env, agent, args)
    trainer.run()


if __name__ == '__main__':
    args = get_parameters()
    main(args)
network.py ADDED
@@ -0,0 +1,145 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import weight_init, AvgL1Norm


class EnsembleQNet(nn.Module):
    def __init__(self, num_critics, state_dim, action_dim, device, zs_dim=256, hidden_dims=(256, 256), activation_fc=F.elu):
        super(EnsembleQNet, self).__init__()
        self.device = device
        self.activation_fc = activation_fc

        self.num_critics = num_critics

        self.q_nets = nn.ModuleList()
        for _ in range(self.num_critics):
            q_net = self._build_q_net(state_dim, action_dim, zs_dim, hidden_dims)
            self.q_nets.append(q_net)

        self.apply(weight_init)

    def _build_q_net(self, state_dim, action_dim, zs_dim, hidden_dims):
        q_net = nn.ModuleDict({
            's_input_layer': nn.Linear(state_dim + action_dim, hidden_dims[0]),
            'emb_input_layer': nn.Linear(2 * zs_dim + hidden_dims[0], hidden_dims[0]),
            'emb_hidden_layers': nn.ModuleList([
                nn.Linear(hidden_dims[i], hidden_dims[i + 1]) for i in range(len(hidden_dims) - 1)
            ]),
            'output_layer': nn.Linear(hidden_dims[-1], 1)
        })
        return q_net

    def _format(self, state, action):
        x, u = state, action
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)

        if not isinstance(u, torch.Tensor):
            u = torch.tensor(u, device=self.device, dtype=torch.float32)
            u = u.unsqueeze(0)

        return x, u

    def forward(self, state, action, zsa, zs):
        s, a = self._format(state, action)
        sa = torch.cat([s, a], dim=1)
        embeddings = torch.cat([zsa, zs], dim=1)

        # One forward pass per ensemble member; outputs are concatenated
        # along dim 1, giving a (batch, num_critics) tensor of Q-values.
        q_values = []
        for q_net in self.q_nets:
            q = AvgL1Norm(q_net['s_input_layer'](sa))
            q = torch.cat([q, embeddings], dim=1)
            q = self.activation_fc(q_net['emb_input_layer'](q))
            for hidden_layer in q_net['emb_hidden_layers']:
                q = self.activation_fc(hidden_layer(q))
            q = q_net['output_layer'](q)
            q_values.append(q)

        return torch.cat(q_values, dim=1)


class Policy(nn.Module):
    def __init__(self, state_dim, action_dim, device, zs_dim=256, hidden_dims=(256, 256), activation_fc=F.relu):
        super(Policy, self).__init__()
        self.device = device

        self.activation_fc = activation_fc

        self.s_input_layer = nn.Linear(state_dim, hidden_dims[0])
        self.zss_input_layer = nn.Linear(zs_dim + hidden_dims[0], hidden_dims[0])
        self.zss_hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims) - 1):
            hidden_layer = nn.Linear(hidden_dims[i], hidden_dims[i + 1])
            self.zss_hidden_layers.append(hidden_layer)
        self.zss_output_layer = nn.Linear(hidden_dims[-1], action_dim)

        # Apply the init after the layers exist; calling it before
        # (as previously written) would leave the weights at PyTorch defaults.
        self.apply(weight_init)

    def _format(self, state):
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)

        return x

    def forward(self, state, zs):
        state = self._format(state)

        state = AvgL1Norm(self.s_input_layer(state))
        zss = torch.cat([state, zs], 1)

        zss = self.activation_fc(self.zss_input_layer(zss))
        for hidden_layer in self.zss_hidden_layers:
            zss = self.activation_fc(hidden_layer(zss))
        zss = self.zss_output_layer(zss)
        action = torch.tanh(zss)
        return action


class Encoder(nn.Module):
    def __init__(self, state_dim, action_dim, device, zs_dim=256, hidden_dims=(256, 256), activation_fc=F.elu):
        super(Encoder, self).__init__()
        self.device = device
        self.activation_fc = activation_fc

        self.s_encoder_input_layer = nn.Linear(state_dim, hidden_dims[0])
        self.s_encoder_hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims) - 1):
            hidden_layer = nn.Linear(hidden_dims[i], hidden_dims[i + 1])
            self.s_encoder_hidden_layers.append(hidden_layer)
        self.s_encoder_output_layer = nn.Linear(hidden_dims[-1], zs_dim)

        self.zsa_encoder_input_layer = nn.Linear(zs_dim + action_dim, hidden_dims[0])
        self.zsa_encoder_hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims) - 1):
            hidden_layer = nn.Linear(hidden_dims[i], hidden_dims[i + 1])
            self.zsa_encoder_hidden_layers.append(hidden_layer)
        self.zsa_encoder_output_layer = nn.Linear(hidden_dims[-1], zs_dim)

    def _format(self, state):
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        return x

    def zs(self, state):
        state = self._format(state)

        zs = self.activation_fc(self.s_encoder_input_layer(state))
        for hidden_layer in self.s_encoder_hidden_layers:
            zs = self.activation_fc(hidden_layer(zs))
        zs = AvgL1Norm(self.s_encoder_output_layer(zs))
        return zs

    def zsa(self, zs, action):
        action = self._format(action)
        zsa = torch.cat([zs, action], 1)

        zsa = self.activation_fc(self.zsa_encoder_input_layer(zsa))
        for hidden_layer in self.zsa_encoder_hidden_layers:
            zsa = self.activation_fc(hidden_layer(zsa))
        zsa = self.zsa_encoder_output_layer(zsa)
        return zsa
replay_memory.py ADDED
@@ -0,0 +1,67 @@
import torch
import numpy as np


class LAP:
    def __init__(self, state_dim, action_dim, device, capacity=1e6, normalize_action=True, max_action=1, prioritized=True):
        # Set the device
        self.device = device

        # Set the replay buffer capacity
        self.capacity = int(capacity)
        self.size = 0
        self.position = 0

        # Set the action normalization factor
        self.do_normalize_action = normalize_action
        self.normalize_action = max_action if normalize_action else 1
        self.max_action = max_action

        # Set the prioritized flag
        self.prioritized = prioritized
        if prioritized:
            self.priority = torch.zeros(self.capacity, device=device)
            self.max_priority = 1

        # Initialize the replay buffer
        self.state_buffer = np.empty(shape=(self.capacity, state_dim), dtype=np.float32)
        self.action_buffer = np.empty(shape=(self.capacity, action_dim), dtype=np.float32)
        self.reward_buffer = np.empty(shape=(self.capacity, 1), dtype=np.float32)
        self.next_state_buffer = np.empty(shape=(self.capacity, state_dim), dtype=np.float32)
        self.done_buffer = np.empty(shape=(self.capacity, 1), dtype=np.float32)

    def push(self, state, action, reward, next_state, done):
        self.state_buffer[self.position] = state
        self.action_buffer[self.position] = action / self.normalize_action
        self.reward_buffer[self.position] = reward
        self.next_state_buffer[self.position] = next_state
        self.done_buffer[self.position] = done

        if self.prioritized:
            self.priority[self.position] = self.max_priority

        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        if self.prioritized:
            # Inverse-CDF sampling proportional to stored priorities
            csum = torch.cumsum(self.priority[:self.size], 0)
            val = torch.rand(size=(batch_size,), device=self.device) * csum[-1]
            self.ind = torch.searchsorted(csum, val).cpu().data.numpy()
        else:
            self.ind = np.random.randint(0, self.size, size=batch_size)

        states = torch.FloatTensor(self.state_buffer[self.ind]).to(self.device)
        actions = torch.FloatTensor(self.action_buffer[self.ind]).to(self.device)
        rewards = torch.FloatTensor(self.reward_buffer[self.ind]).to(self.device)
        next_states = torch.FloatTensor(self.next_state_buffer[self.ind]).to(self.device)
        dones = torch.FloatTensor(self.done_buffer[self.ind]).to(self.device)

        return states, actions, rewards, next_states, dones

    def update_priority(self, priority):
        self.priority[self.ind] = priority.reshape(-1).detach()
        self.max_priority = max(float(priority.max()), self.max_priority)

    def reset_max_priority(self):
        self.max_priority = float(self.priority[:self.size].max())
trainer.py ADDED
@@ -0,0 +1,112 @@
import numpy as np


class Trainer:
    def __init__(self, env, eval_env, agent, args):
        self.args = args

        self.agent = agent
        self.env_name = args.env_name
        self.env = env
        self.eval_env = eval_env

        self.start_steps = args.start_steps
        self.max_steps = args.max_steps
        self.batch_size = args.batch_size
        self.target_noise_scale = args.target_noise_scale

        self.eval_flag = args.eval_flag
        self.eval_episode = args.eval_episode
        self.eval_freq = args.eval_freq

        self.episode = 0
        self.episode_reward = 0
        self.total_steps = 0
        self.eval_num = 0
        self.finish_flag = False

        self.policy_update_delay = args.policy_update_delay

    def evaluate(self):
        # Evaluation process
        self.eval_num += 1
        reward_list = []

        for epi in range(self.eval_episode):
            epi_reward = 0
            state, _ = self.eval_env.reset()

            done = False

            while not done:
                action = self.agent.get_action(state, use_checkpoint=self.args.use_checkpoints, add_noise=False)
                next_state, reward, terminated, truncated, _ = self.eval_env.step(action)
                done = terminated or truncated
                epi_reward += reward
                state = next_state
            reward_list.append(epi_reward)

        print("Eval | total_step {} | episode {} | Average Reward {:.2f} | Max reward: {:.2f} | "
              "Min reward: {:.2f} | Std: {:.2f}".format(self.total_steps, self.episode,
                                                        sum(reward_list) / len(reward_list),
                                                        max(reward_list), min(reward_list),
                                                        np.std(reward_list)))

    def run(self):
        # Training process
        allow_train = False

        while not self.finish_flag:
            self.episode += 1
            done = False
            ep_total_reward, ep_timesteps = 0, 0

            state, _ = self.env.reset()
            # Episode start
            while not done:
                self.total_steps += 1
                ep_timesteps += 1

                if allow_train:
                    action = self.agent.get_action(state, use_checkpoint=False, add_noise=True)
                else:
                    action = self.env.action_space.sample()
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated

                ep_total_reward += reward

                # Do not treat time-limit truncation as a true terminal state.
                done_mask = 0.0 if ep_timesteps == self.env._max_episode_steps else float(done)
                self.agent.buffer.push(state, action, reward, next_state, done_mask)

                state = next_state

                if allow_train and not self.args.use_checkpoints:
                    actor_loss, critic_loss, encoder_loss = self.agent.train()
                    # Print losses
                    if self.args.show_loss:
                        print("Loss | Actor loss {:.3f} | Critic loss {:.3f} | Encoder loss {:.3f}"
                              .format(actor_loss, critic_loss, encoder_loss))

                if done:
                    if allow_train and self.args.use_checkpoints:
                        self.agent.maybe_train_and_checkpoint(ep_timesteps, ep_total_reward)

                    if self.total_steps >= self.args.start_steps:
                        allow_train = True

                # Evaluation
                if self.eval_flag and self.total_steps % self.eval_freq == 0:
                    self.evaluate()

                # Raise finish_flag
                if self.total_steps == self.max_steps:
                    self.finish_flag = True
ud7.py ADDED
@@ -0,0 +1,202 @@
import numpy as np
import torch
import torch.nn.functional as F
import copy
from replay_memory import LAP
from network import Policy, Encoder, EnsembleQNet
from utils import hard_update, LAP_huber


class UD7:
    def __init__(self, state_dim, action_dim, action_bound, device, args):
        self.args = args

        self.state_dim = state_dim
        self.action_dim = action_dim

        self.device = device
        self.buffer = LAP(self.state_dim, self.action_dim, device, args.buffer_size, normalize_action=True,
                          max_action=action_bound[1], prioritized=True)
        self.batch_size = args.batch_size

        self.gamma = args.gamma
        self.act_noise_scale = args.act_noise_scale

        self.num_critics = args.num_critics

        self.actor = Policy(self.state_dim, self.action_dim, self.device, args.zs_dim, args.policy_hidden_dims).to(self.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=args.actor_lr)
        self.target_actor = Policy(self.state_dim, self.action_dim, self.device, args.zs_dim, args.policy_hidden_dims).to(self.device)

        self.critic = EnsembleQNet(self.num_critics, self.state_dim, self.action_dim,
                                   self.device, args.zs_dim, args.critic_hidden_dims).to(self.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=args.critic_lr)
        self.target_critic = EnsembleQNet(self.num_critics, self.state_dim, self.action_dim,
                                          self.device, args.zs_dim, args.critic_hidden_dims).to(self.device)

        self.encoder = Encoder(state_dim, action_dim, self.device, args.zs_dim, args.encoder_hidden_dims).to(self.device)
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(), lr=args.encoder_lr)
        self.fixed_encoder = copy.deepcopy(self.encoder)
        self.fixed_encoder_target = copy.deepcopy(self.encoder)

        self.checkpoint_actor = copy.deepcopy(self.actor)
        self.checkpoint_encoder = copy.deepcopy(self.encoder)

        self.training_steps = 0

        self.max_action = action_bound[1]

        # Checkpointing tracked values
        self.eps_since_update = 0
        self.timesteps_since_update = 0
        self.max_eps_before_update = 1
        self.min_return = 1e8
        self.best_min_return = -1e8

        # Value clipping tracked values
        self.max = -1e8
        self.min = 1e8
        self.max_target = 0
        self.min_target = 0

        hard_update(self.actor, self.target_actor)
        hard_update(self.critic, self.target_critic)

    def get_action(self, state, use_checkpoint=False, add_noise=True):
        with torch.no_grad():
            if use_checkpoint:
                zs = self.checkpoint_encoder.zs(state)
                action = self.checkpoint_actor(state, zs)
            else:
                zs = self.fixed_encoder.zs(state)
                action = self.actor(state, zs)

            if add_noise:
                action = action + torch.randn_like(action) * self.act_noise_scale

            # tanh already bounds the raw action; clipping guards the added noise.
            action = np.clip(action.cpu().numpy()[0], -1, 1)
            action = action * self.max_action
        return action

    def train(self):
        self.training_steps += 1

        # Sample from LAP
        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # Update encoder
        with torch.no_grad():
            next_zs = self.encoder.zs(next_states)

        zs = self.encoder.zs(states)
        pred_zs = self.encoder.zsa(zs, actions)
        encoder_loss = F.mse_loss(pred_zs, next_zs)

        self.encoder_optimizer.zero_grad()
        encoder_loss.backward()
        self.encoder_optimizer.step()

        # Update critic
        with torch.no_grad():
            fixed_target_zs = self.fixed_encoder_target.zs(next_states)

            # Target policy smoothing (TD3-style clipped noise)
            target_act_noise = (torch.randn_like(actions) * self.args.target_noise_scale).clamp(
                -self.args.target_noise_clip, self.args.target_noise_clip).to(self.device)

            if self.buffer.do_normalize_action:
                next_target_actions = (self.target_actor(next_states, fixed_target_zs) + target_act_noise).clamp(-1, 1)
            else:
                next_target_actions = (self.target_actor(next_states, fixed_target_zs) + target_act_noise).clamp(-self.max_action, self.max_action)

            fixed_target_zsa = self.fixed_encoder_target.zsa(fixed_target_zs, next_target_actions)

            Q_target = self.target_critic(next_states, next_target_actions, fixed_target_zsa, fixed_target_zs)
            m = Q_target.mean(dim=1, keepdim=True)                # ensemble sample mean
            v = Q_target.var(dim=1, unbiased=True, keepdim=True)  # unbiased sample variance
            # UBOC bias-corrected target: m - sqrt(v / pi); 0.5641896 ~= 1 / sqrt(pi)
            bias_corrected_q_target = m - 0.5641896 * torch.sqrt(v)

            Q_target = rewards + (1 - dones) * self.gamma * bias_corrected_q_target.clamp(self.min_target, self.max_target)

            self.max = max(self.max, float(Q_target.max()))
            self.min = min(self.min, float(Q_target.min()))

            fixed_zs = self.fixed_encoder.zs(states)
            fixed_zsa = self.fixed_encoder.zsa(fixed_zs, actions)

        Q = self.critic(states, actions, fixed_zsa, fixed_zs)

        td_loss = (Q - Q_target).abs()
        critic_loss = LAP_huber(td_loss)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Update LAP priorities
        priority = td_loss.max(1)[0].clamp(min=self.args.min_priority).pow(self.args.alpha)
        self.buffer.update_priority(priority)

        # Update actor (delayed)
        if self.training_steps % self.args.policy_update_delay == 0:
            actor_actions = self.actor(states, fixed_zs)
            fixed_zsa = self.fixed_encoder.zsa(fixed_zs, actor_actions)
            Q = self.critic(states, actor_actions, fixed_zsa, fixed_zs)

            actor_loss = -Q.mean(dim=1, keepdim=True).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()
        else:
            actor_loss = torch.tensor(0.0)

        # Update target networks and tracked clipping range
        if self.training_steps % self.args.target_update_rate == 0:
            self.target_actor.load_state_dict(self.actor.state_dict())
            self.target_critic.load_state_dict(self.critic.state_dict())
            self.fixed_encoder_target.load_state_dict(self.fixed_encoder.state_dict())
            self.fixed_encoder.load_state_dict(self.encoder.state_dict())

            self.buffer.reset_max_priority()

            self.max_target = self.max
            self.min_target = self.min

        return actor_loss.item(), critic_loss.item(), encoder_loss.item()

    def maybe_train_and_checkpoint(self, ep_timesteps, ep_return):
        self.eps_since_update += 1
        self.timesteps_since_update += ep_timesteps

        self.min_return = min(self.min_return, ep_return)

        # End evaluation of current policy early
        if self.min_return < self.best_min_return:
            self.train_and_reset()

        # Update checkpoint
        elif self.eps_since_update == self.max_eps_before_update:
            self.best_min_return = self.min_return
            self.checkpoint_actor.load_state_dict(self.actor.state_dict())
            self.checkpoint_encoder.load_state_dict(self.fixed_encoder.state_dict())

            self.train_and_reset()

    # Batch training
    def train_and_reset(self):
        for _ in range(self.timesteps_since_update):
            if self.training_steps == self.args.steps_before_checkpointing:
                self.best_min_return *= self.args.reset_weight
                self.max_eps_before_update = self.args.max_eps_when_checkpointing

            self.train()

        self.eps_since_update = 0
        self.timesteps_since_update = 0
        self.min_return = 1e8
utils.py ADDED
@@ -0,0 +1,63 @@
import numpy as np
import random
import torch
import torch.nn as nn


def AvgL1Norm(x, eps=1e-8):
    return x / x.abs().mean(-1, keepdim=True).clamp(min=eps)


def LAP_huber(x, min_priority=1):
    return torch.where(x < min_priority, 0.5 * x.pow(2), min_priority * x).sum(1).mean()


def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.
    Reference: https://github.com/MishaLaskin/rad/blob/master/curl_sac.py"""

    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        m.bias.data.fill_(0.0)
    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)


def hard_update(network, target_network):
    for param, target_param in zip(network.parameters(), target_network.parameters()):
        target_param.data.copy_(param.data)


def soft_update(network, target_network, tau):
    for param, target_param in zip(network.parameters(), target_network.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)


def set_seed(random_seed):
    # A non-positive seed requests a randomly drawn one.
    if random_seed <= 0:
        random_seed = np.random.randint(1, 9999)

    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    return random_seed


def make_env(env_name, random_seed):
    import gymnasium as gym

    # Gymnasium environments (separate training and evaluation copies)
    env = gym.make(env_name)
    env.action_space.seed(random_seed)

    eval_env = gym.make(env_name)
    eval_env.action_space.seed(random_seed + 100)

    return env, eval_env