Upload 10 files

- .gitattributes +1 -0
- README.md +176 -0
- figures/performance.png +3 -0
- figures/runtime.png +0 -0
- figures/tmp.md +1 -0
- main.py +78 -0
- network.py +145 -0
- replay_memory.py +67 -0
- trainer.py +112 -0
- ud7.py +202 -0
.gitattributes
CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+figures/performance.png filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
<div align="center">
  <h1>UD7</h1>
  <h3>Provable Generalization of Clipped Double Q-Learning for Variance Reduction and Sample Efficiency</h3>

  <a href="https://www.python.org/">
    <img src="https://img.shields.io/badge/Python-3.7+-blue?logo=python&style=flat-square" alt="Python Badge"/>
  </a>
  <a href="https://pytorch.org/">
    <img src="https://img.shields.io/badge/PyTorch-1.8+-EE4C2C?logo=pytorch&style=flat-square" alt="PyTorch Badge"/>
  </a>
  <a href="https://www.sciencedirect.com/journal/neurocomputing">
    <img src="https://img.shields.io/badge/Neurocomputing-Published-success?style=flat-square" alt="Neurocomputing Badge"/>
  </a>
  <a href="https://www.elsevier.com/">
    <img src="https://img.shields.io/badge/Elsevier-Journal-orange?style=flat-square" alt="Elsevier Badge"/>
  </a>
</div>

---

## Neurocomputing — PyTorch Implementation

This repository contains a PyTorch implementation of **UD7**, the algorithm proposed in the paper:

> **Provable Generalization of Clipped Double Q-Learning for Variance Reduction and Sample Efficiency**
> Jangwon Kim, Jiseok Jeong, Soohee Han
> *Neurocomputing*, Volume 673, 7 April 2026, 132772

### Paper Link
https://www.sciencedirect.com/science/article/abs/pii/S0925231226001694

---

**UD7** is an off-policy actor–critic algorithm that builds on a TD7-style training pipeline, replacing the critic target **formulation** with **UBOC**.

---
## 1) Background: Clipped Double Q-Learning (CDQ)

Clipped double Q-learning is a widely used bias-correction technique in actor–critic methods (e.g., TD3). It maintains **two critics** and uses the **minimum** of the two as the TD target:

$$
y_{\text{CDQ}}(s_t,a_t)=r_t+\gamma \min_{i\in\{1,2\}} \bar Q_i(s_{t+1}, a_{t+1})
$$

### Strengths (why CDQ is popular)
- **Effective overestimation control:** taking a minimum is conservative, often preventing exploding Q-values.
- **Robust baseline behavior:** works well across many continuous-control tasks.

### Limitations (what the paper highlights)
- **High variance:** when the critics are poorly learned early in training, the min operator can yield high-variance TD targets, destabilizing TD learning and reducing sample efficiency.

**UBOC is motivated by a concrete question:**
> Can we obtain **the same expected target value as CDQ**, but with **smaller variance**?

---
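For a concrete reference point, the CDQ target above fits in a few lines of NumPy. This is an illustrative sketch with hypothetical names, not code from this repository:

```python
import numpy as np

def cdq_target(reward, gamma, q1, q2, done):
    # Clipped double-Q target: bootstrap from the minimum of the two critics
    return reward + gamma * (1.0 - done) * np.minimum(q1, q2)

# Two critics disagree about Q(s', a'); CDQ takes the pessimistic one
y = cdq_target(reward=1.0, gamma=0.99, q1=10.0, q2=9.5, done=0.0)
# → 1.0 + 0.99 * 9.5 = 10.405
```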
## 2) UBOC: Uncertainty-Based Overestimation Correction (Detailed)

UBOC views the critic outputs as a **distribution of Q estimates** (because function approximation is noisy).
Instead of using `min(Q1, Q2)`, UBOC uses **N critics** to estimate:
- a **mean** $m$,
- an **(unbiased) standard deviation** $\hat s$,

and then forms a corrected value:

$$
Q_{\text{corrected}} = m - x\cdot \hat s
$$

where $x>0$ controls conservativeness.

### 2.1 Expectation equivalence to clipped double-Q

Under the assumption that critic estimates behave like i.i.d. samples from a normal distribution, we can derive:

$$
\mathbb{E}\left[\min(Q_A, Q_B)\right]=\mathbb{E}\left[m - \frac{\hat s}{\sqrt{\pi}}\right]
$$

This is the key insight:
- choosing $$x=1/\sqrt{\pi}$$ makes the corrected estimate **match CDQ in expectation**.
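The equivalence rests on a classical order-statistic fact: for two i.i.d. normal estimates, the expected minimum sits exactly $\sigma/\sqrt{\pi}$ below the mean. A quick Monte Carlo check of that fact (a standalone sketch, not repository code):

```python
import numpy as np

# For two i.i.d. N(0, 1) critic estimates, E[min(Q_A, Q_B)] = -1/sqrt(pi)
# and Var(min) = 1 - 1/pi (the same 1 - 1/pi factor that appears in the
# variance bound of Section 2.2).
rng = np.random.default_rng(0)
q_a, q_b = rng.standard_normal((2, 1_000_000))
q_min = np.minimum(q_a, q_b)

empirical_mean = q_min.mean()   # ≈ -1/sqrt(pi) ≈ -0.5642
empirical_var = q_min.var()     # ≈ 1 - 1/pi    ≈ 0.6817
```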
### 2.2 Variance reduction (provable)

We can further prove that the estimator

$$
m - \frac{\hat s}{\sqrt{\pi}}
$$

has **strictly smaller variance** than the CDQ minimum-based target, and the **variance gap is strictly positive for all $$N\ge 2$$**.

As $$N\to\infty$$, the maximum achievable variance reduction is upper-bounded by:

$$
\sigma^2\left(1-\frac{1}{\pi}\right)
$$

**In other words:**
- UBOC does not merely correct bias; it **reduces noise** in TD targets.
- This is especially important early in training, where noisy targets can derail learning.

### 2.3 UBOC TD target (what you implement)

Using N target critics $$Q_1,\dots, Q_N$$, compute:

**Mean**

$$
m(s,a) = \frac{1}{N}\sum_{i=1}^N Q_i(s,a)
$$

**Unbiased sample variance (approximation)**

$$
\hat v(s,a)=\frac{1}{N-1}\sum_{i=1}^N \left( Q_i(s,a)-m(s,a)\right)^2
$$

Then the **UBOC target** is:

$$
y_{\text{UBOC}}(s_t,a_t)=r_t + \gamma\left(m(s_{t+1},a_{t+1}) - \sqrt{\frac{\hat v(s_{t+1},a_{t+1})}{\pi}}\right)
$$

where $$a_{t+1}$$ can be computed with target policy smoothing.

This gives a *dynamic* bias correction driven by critic uncertainty.

---
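The three formulas of Section 2.3 can be sketched directly in NumPy. This is an illustrative sketch (hypothetical names, not the repository's PyTorch code):

```python
import numpy as np

def uboc_target(reward, gamma, q_next, done):
    """UBOC TD target from an ensemble of N critic estimates.

    q_next: shape (batch, N), the N target-critic values Q_i(s', a').
    """
    m = q_next.mean(axis=1)             # ensemble mean
    v = q_next.var(axis=1, ddof=1)      # unbiased sample variance (N - 1 denominator)
    corrected = m - np.sqrt(v / np.pi)  # uncertainty-based correction
    return reward + gamma * (1.0 - done) * corrected

# N = 5 critic estimates for one transition
q_next = np.array([[10.0, 9.5, 10.5, 9.8, 10.2]])
y = uboc_target(reward=np.array([1.0]), gamma=0.99, q_next=q_next, done=np.array([0.0]))
```

Note how the correction term shrinks as the critics come to agree, which is what makes the bias correction *dynamic*.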
## 3) UD7: TD7 + UBOC Targets

**UD7** integrates UBOC into a TD7-style pipeline and emphasizes strong sample efficiency.

- UD7 keeps the TD7 backbone for practical stability and efficiency.
- **The main difference from TD7 is the critic training target:** UD7 uses **UBOC targets** and a multi-critic ensemble (commonly **N=5**).

> If you already have a TD7 baseline, UD7 is best viewed as:
> **"swap the target rule + use N critics, then keep the rest of the training recipe."**

---
## 4) Performance

<div align="center">
  <img src="figures/performance.png" alt="Fig. 1 — Performance comparison on MuJoCo benchmarks" width="800"/>
</div>

---
## 5) Computational Overhead

Runtime comparison (measured on an RTX 3090 Ti with an Intel i7-12700):

<div align="center">
  <img src="figures/runtime.png" alt="Fig. 2 — Runtime comparison" width="300"/>
</div>

---
## Citation
```bibtex
@article{kim2026provable,
  title={Provable generalization of clipped double Q-learning for variance reduction and sample efficiency},
  author={Kim, Jangwon and Jeong, Jiseok and Han, Soohee},
  journal={Neurocomputing},
  volume={673},
  pages={132772},
  year={2026},
  publisher={Elsevier}
}
```
figures/performance.png
ADDED (Git LFS)

figures/runtime.png
ADDED

figures/tmp.md
ADDED (empty placeholder)
main.py
ADDED

import argparse
import torch
from ud7 import UD7
from trainer import Trainer
from utils import set_seed, make_env


def get_parameters():
    parser = argparse.ArgumentParser()

    # Environment setting
    parser.add_argument('--env-name', default='Humanoid-v4')
    parser.add_argument('--random-seed', default=-1, type=int)

    # UBOC
    parser.add_argument('--num_critics', default=5, type=int)

    # Checkpointing
    parser.add_argument('--use_checkpoints', default=True, type=bool)
    parser.add_argument('--max-eps-when-checkpointing', default=20, type=int)
    parser.add_argument('--steps-before-checkpointing', default=75e4, type=int)
    parser.add_argument('--reset-weight', default=0.9, type=float)

    # LAP
    parser.add_argument('--alpha', default=0.4, type=float)
    parser.add_argument('--min_priority', default=1, type=float)

    # Generic
    parser.add_argument('--target-update-rate', default=250, type=int)
    parser.add_argument('--start-steps', default=25e3, type=int)
    parser.add_argument('--max-steps', default=5000000, type=int)
    parser.add_argument('--zs-dim', default=256, type=int)
    parser.add_argument('--critic-hidden-dims', default=(256, 256))
    parser.add_argument('--policy-hidden-dims', default=(256, 256))
    parser.add_argument('--encoder-hidden-dims', default=(256, 256))
    parser.add_argument('--hidden-dims', default=(256, 256))
    parser.add_argument('--batch-size', default=256, type=int)
    parser.add_argument('--buffer-size', default=1000000, type=int)
    parser.add_argument('--policy-update-delay', default=2, type=int)
    parser.add_argument('--gamma', default=0.99, type=float)
    parser.add_argument('--actor-lr', default=0.0003, type=float)
    parser.add_argument('--critic-lr', default=0.0003, type=float)
    parser.add_argument('--encoder-lr', default=0.0003, type=float)

    # TD3
    parser.add_argument('--act-noise-scale', default=0.1, type=float)
    parser.add_argument('--target-noise-scale', default=0.2, type=float)
    parser.add_argument('--target-noise-clip', default=0.5, type=float)

    # Log & evaluation
    parser.add_argument('--show-loss', default=False, type=bool)
    parser.add_argument('--eval_flag', default=True, type=bool)
    parser.add_argument('--eval-freq', default=5000, type=int)
    parser.add_argument('--eval-episode', default=10, type=int)

    return parser.parse_args()


def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    random_seed = set_seed(args.random_seed)
    env, eval_env = make_env(args.env_name, random_seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    action_bound = [env.action_space.low[0], env.action_space.high[0]]

    agent = UD7(state_dim, action_dim, action_bound, device, args)

    trainer = Trainer(env, eval_env, agent, args)
    trainer.run()


if __name__ == '__main__':
    args = get_parameters()
    main(args)
network.py
ADDED

import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import weight_init, AvgL1Norm


class EnsembleQNet(nn.Module):
    def __init__(self, num_critics, state_dim, action_dim, device, zs_dim=256, hidden_dims=(256, 256), activation_fc=F.elu):
        super(EnsembleQNet, self).__init__()
        self.device = device
        self.activation_fc = activation_fc

        self.num_critics = num_critics

        self.q_nets = nn.ModuleList()
        for _ in range(self.num_critics):
            q_net = self._build_q_net(state_dim, action_dim, zs_dim, hidden_dims)
            self.q_nets.append(q_net)

        self.apply(weight_init)

    def _build_q_net(self, state_dim, action_dim, zs_dim, hidden_dims):
        q_net = nn.ModuleDict({
            's_input_layer': nn.Linear(state_dim + action_dim, hidden_dims[0]),
            'emb_input_layer': nn.Linear(2 * zs_dim + hidden_dims[0], hidden_dims[0]),
            'emb_hidden_layers': nn.ModuleList([
                nn.Linear(hidden_dims[i], hidden_dims[i + 1]) for i in range(len(hidden_dims) - 1)
            ]),
            'output_layer': nn.Linear(hidden_dims[-1], 1)
        })
        return q_net

    def _format(self, state, action):
        x, u = state, action
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)

        if not isinstance(u, torch.Tensor):
            u = torch.tensor(u, device=self.device, dtype=torch.float32)
            u = u.unsqueeze(0)

        return x, u

    def forward(self, state, action, zsa, zs):
        s, a = self._format(state, action)
        sa = torch.cat([s, a], dim=1)
        embeddings = torch.cat([zsa, zs], dim=1)

        q_values = []
        for q_net in self.q_nets:
            q = AvgL1Norm(q_net['s_input_layer'](sa))
            q = torch.cat([q, embeddings], dim=1)
            q = self.activation_fc(q_net['emb_input_layer'](q))
            for hidden_layer in q_net['emb_hidden_layers']:
                q = self.activation_fc(hidden_layer(q))
            q = q_net['output_layer'](q)
            q_values.append(q)

        # Shape (batch, num_critics): one Q estimate per ensemble member
        return torch.cat(q_values, dim=1)


class Policy(nn.Module):
    def __init__(self, state_dim, action_dim, device, zs_dim=256, hidden_dims=(256, 256), activation_fc=F.relu):
        super(Policy, self).__init__()
        self.device = device
        self.activation_fc = activation_fc

        self.s_input_layer = nn.Linear(state_dim, hidden_dims[0])
        self.zss_input_layer = nn.Linear(zs_dim + hidden_dims[0], hidden_dims[0])
        self.zss_hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims) - 1):
            self.zss_hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
        self.zss_output_layer = nn.Linear(hidden_dims[-1], action_dim)

        # Apply weight initialization after the layers exist; the original upload
        # called self.apply(weight_init) before any layer was created, which is a no-op.
        self.apply(weight_init)

    def _format(self, state):
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)

        return x

    def forward(self, state, zs):
        state = self._format(state)

        state = AvgL1Norm(self.s_input_layer(state))
        zss = torch.cat([state, zs], 1)

        zss = self.activation_fc(self.zss_input_layer(zss))
        for hidden_layer in self.zss_hidden_layers:
            zss = self.activation_fc(hidden_layer(zss))
        zss = self.zss_output_layer(zss)
        action = torch.tanh(zss)
        return action


class Encoder(nn.Module):
    def __init__(self, state_dim, action_dim, device, zs_dim=256, hidden_dims=(256, 256), activation_fc=F.elu):
        super(Encoder, self).__init__()
        self.device = device
        self.activation_fc = activation_fc

        self.s_encoder_input_layer = nn.Linear(state_dim, hidden_dims[0])
        self.s_encoder_hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims) - 1):
            self.s_encoder_hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
        self.s_encoder_output_layer = nn.Linear(hidden_dims[-1], zs_dim)

        self.zsa_encoder_input_layer = nn.Linear(zs_dim + action_dim, hidden_dims[0])
        self.zsa_encoder_hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims) - 1):
            self.zsa_encoder_hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
        self.zsa_encoder_output_layer = nn.Linear(hidden_dims[-1], zs_dim)

    def _format(self, state):
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
            x = x.unsqueeze(0)
        return x

    def zs(self, state):
        state = self._format(state)

        zs = self.activation_fc(self.s_encoder_input_layer(state))
        for hidden_layer in self.s_encoder_hidden_layers:
            zs = self.activation_fc(hidden_layer(zs))
        zs = AvgL1Norm(self.s_encoder_output_layer(zs))
        return zs

    def zsa(self, zs, action):
        action = self._format(action)
        zsa = torch.cat([zs, action], 1)

        zsa = self.activation_fc(self.zsa_encoder_input_layer(zsa))
        for hidden_layer in self.zsa_encoder_hidden_layers:
            zsa = self.activation_fc(hidden_layer(zsa))
        zsa = self.zsa_encoder_output_layer(zsa)
        return zsa
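`AvgL1Norm` is imported from `utils.py`, which is not shown in this upload. In TD7 it normalizes each row of activations by its mean absolute value; a NumPy sketch under that assumption (the actual `utils.py` definition may differ):

```python
import numpy as np

def avg_l1_norm(x, eps=1e-8):
    # Scale each row so its mean absolute value is 1
    # (assumed TD7-style definition; clamp avoids division by zero)
    return x / np.maximum(np.abs(x).mean(axis=-1, keepdims=True), eps)

x = np.array([[2.0, -4.0, 6.0]])   # mean |x| = 4
y = avg_l1_norm(x)                 # → [[0.5, -1.0, 1.5]]
```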
replay_memory.py
ADDED

import torch
import numpy as np


class LAP:
    def __init__(self, state_dim, action_dim, device, capacity=1e6, normalize_action=True, max_action=1, prioritized=True):
        self.device = device

        # Replay buffer capacity and write cursor
        self.capacity = int(capacity)
        self.size = 0
        self.position = 0

        # Action normalization factor
        self.do_normalize_action = normalize_action
        self.normalize_action = max_action if normalize_action else 1
        self.max_action = max_action

        # Prioritized (LAP) sampling state
        self.prioritized = prioritized
        if prioritized:
            self.priority = torch.zeros(self.capacity, device=device)
            self.max_priority = 1

        # Pre-allocated transition storage
        self.state_buffer = np.empty(shape=(self.capacity, state_dim), dtype=np.float32)
        self.action_buffer = np.empty(shape=(self.capacity, action_dim), dtype=np.float32)
        self.reward_buffer = np.empty(shape=(self.capacity, 1), dtype=np.float32)
        self.next_state_buffer = np.empty(shape=(self.capacity, state_dim), dtype=np.float32)
        self.done_buffer = np.empty(shape=(self.capacity, 1), dtype=np.float32)

    def push(self, state, action, reward, next_state, done):
        self.state_buffer[self.position] = state
        self.action_buffer[self.position] = action / self.normalize_action
        self.reward_buffer[self.position] = reward
        self.next_state_buffer[self.position] = next_state
        self.done_buffer[self.position] = done

        if self.prioritized:
            # New transitions enter with the current maximum priority
            self.priority[self.position] = self.max_priority

        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        if self.prioritized:
            # Inverse-CDF sampling: draw indices proportionally to priority
            csum = torch.cumsum(self.priority[:self.size], 0)
            val = torch.rand(size=(batch_size,), device=self.device) * csum[-1]
            self.ind = torch.searchsorted(csum, val).cpu().data.numpy()
        else:
            self.ind = np.random.randint(0, self.size, size=batch_size)

        states = torch.FloatTensor(self.state_buffer[self.ind]).to(self.device)
        actions = torch.FloatTensor(self.action_buffer[self.ind]).to(self.device)
        rewards = torch.FloatTensor(self.reward_buffer[self.ind]).to(self.device)
        next_states = torch.FloatTensor(self.next_state_buffer[self.ind]).to(self.device)
        dones = torch.FloatTensor(self.done_buffer[self.ind]).to(self.device)

        return states, actions, rewards, next_states, dones

    def update_priority(self, priority):
        self.priority[self.ind] = priority.reshape(-1).detach()
        self.max_priority = max(float(priority.max()), self.max_priority)

    def reset_max_priority(self):
        self.max_priority = float(self.priority[:self.size].max())
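The cumulative-sum trick in `LAP.sample` is an inverse-CDF draw: each index is selected with probability proportional to its priority. A standalone NumPy sketch of the same idea (hypothetical names, separate from the class above):

```python
import numpy as np

def sample_indices(priorities, batch_size, rng):
    # Inverse-CDF sampling: uniform draws are mapped through the priority CDF,
    # mirroring the torch.cumsum + torch.searchsorted pair in LAP.sample
    csum = np.cumsum(priorities)
    vals = rng.random(batch_size) * csum[-1]
    return np.searchsorted(csum, vals)

rng = np.random.default_rng(0)
priorities = np.array([1.0, 1.0, 8.0])      # third index holds 80% of the mass
idx = sample_indices(priorities, 10_000, rng)
frac_third = (idx == 2).mean()              # ≈ 0.8
```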
trainer.py
ADDED

import numpy as np


class Trainer:
    def __init__(self, env, eval_env, agent, args):
        self.args = args

        self.agent = agent
        self.env_name = args.env_name
        self.env = env
        self.eval_env = eval_env

        self.start_steps = args.start_steps
        self.max_steps = args.max_steps
        self.batch_size = args.batch_size

        self.eval_flag = args.eval_flag
        self.eval_episode = args.eval_episode
        self.eval_freq = args.eval_freq

        self.episode = 0
        self.episode_reward = 0
        self.total_steps = 0
        self.eval_num = 0
        self.finish_flag = False

        self.target_noise_scale = args.target_noise_scale
        self.policy_update_delay = args.policy_update_delay

    def evaluate(self):
        # Evaluation process
        self.eval_num += 1
        reward_list = []

        for epi in range(self.eval_episode):
            epi_reward = 0
            state, _ = self.eval_env.reset()
            done = False

            while not done:
                action = self.agent.get_action(state, use_checkpoint=self.args.use_checkpoints, add_noise=False)
                next_state, reward, terminated, truncated, _ = self.eval_env.step(action)
                done = terminated or truncated
                epi_reward += reward
                state = next_state
            reward_list.append(epi_reward)

        print("Eval | total_step {} | episode {} | Average Reward {:.2f} | Max reward: {:.2f} | "
              "Min reward: {:.2f} | Std: {:.2f}".format(self.total_steps, self.episode,
                                                        sum(reward_list) / len(reward_list),
                                                        max(reward_list), min(reward_list),
                                                        np.std(reward_list)))

    def run(self):
        # Train process starts here
        allow_train = False

        while not self.finish_flag:
            self.episode += 1
            done = False
            ep_total_reward, ep_timesteps = 0, 0

            state, _ = self.env.reset()
            # Episode start
            while not done:
                self.total_steps += 1
                ep_timesteps += 1

                if allow_train:
                    action = self.agent.get_action(state, use_checkpoint=False, add_noise=True)
                else:
                    action = self.env.action_space.sample()
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated

                ep_total_reward += reward

                # Do not treat time-limit truncation as a terminal state for bootstrapping
                done_mask = 0.0 if ep_timesteps == self.env._max_episode_steps else float(done)
                self.agent.buffer.push(state, action, reward, next_state, done_mask)

                state = next_state

                if allow_train and not self.args.use_checkpoints:
                    actor_loss, critic_loss, encoder_loss = self.agent.train()
                    # Print loss
                    if self.args.show_loss:
                        print("Loss | Actor loss {:.3f} | Critic loss {:.3f} | Encoder loss {:.3f}"
                              .format(actor_loss, critic_loss, encoder_loss))

                if done:
                    if allow_train and self.args.use_checkpoints:
                        self.agent.maybe_train_and_checkpoint(ep_timesteps, ep_total_reward)

                    if self.total_steps >= self.args.start_steps:
                        allow_train = True

                # Evaluation
                if self.eval_flag and self.total_steps % self.eval_freq == 0:
                    self.evaluate()

                # Raise finish_flag
                if self.total_steps == self.max_steps:
                    self.finish_flag = True
ud7.py
ADDED

import numpy as np
import torch
import torch.nn.functional as F
import copy
from replay_memory import LAP
from network import Policy, Encoder, EnsembleQNet
from utils import hard_update, LAP_huber


class UD7:
    def __init__(self, state_dim, action_dim, action_bound, device, args):
        self.args = args

        self.state_dim = state_dim
        self.action_dim = action_dim

        self.device = device
        self.buffer = LAP(self.state_dim, self.action_dim, device, args.buffer_size, normalize_action=True,
                          max_action=action_bound[1], prioritized=True)
        self.batch_size = args.batch_size

        self.gamma = args.gamma
        self.act_noise_scale = args.act_noise_scale

        self.num_critics = args.num_critics

        self.actor = Policy(self.state_dim, self.action_dim, self.device, args.zs_dim, args.policy_hidden_dims).to(self.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=args.actor_lr)
        self.target_actor = Policy(self.state_dim, self.action_dim, self.device, args.zs_dim, args.policy_hidden_dims).to(self.device)

        self.critic = EnsembleQNet(self.num_critics, self.state_dim, self.action_dim,
                                   self.device, args.zs_dim, args.critic_hidden_dims).to(self.device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=args.critic_lr)
        self.target_critic = EnsembleQNet(self.num_critics, self.state_dim, self.action_dim,
                                          self.device, args.zs_dim, args.critic_hidden_dims).to(self.device)

        self.encoder = Encoder(state_dim, action_dim, self.device, args.zs_dim, args.encoder_hidden_dims).to(self.device)
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(), lr=args.encoder_lr)
        self.fixed_encoder = copy.deepcopy(self.encoder)
        self.fixed_encoder_target = copy.deepcopy(self.encoder)

        self.checkpoint_actor = copy.deepcopy(self.actor)
        self.checkpoint_encoder = copy.deepcopy(self.encoder)

        self.training_steps = 0

        self.max_action = action_bound[1]

        # Checkpointing tracked values
        self.eps_since_update = 0
        self.timesteps_since_update = 0
        self.max_eps_before_update = 1
        self.min_return = 1e8
        self.best_min_return = -1e8

        # Value clipping tracked values
        self.max = -1e8
        self.min = 1e8
        self.max_target = 0
        self.min_target = 0

        hard_update(self.actor, self.target_actor)
        hard_update(self.critic, self.target_critic)

    def get_action(self, state, use_checkpoint=False, add_noise=True):
        with torch.no_grad():
            # Select encoder/actor pair: the checkpointed one for evaluation,
            # the current one for environment interaction
            if use_checkpoint:
                zs = self.checkpoint_encoder.zs(state)
                action = self.checkpoint_actor(state, zs)
            else:
                zs = self.fixed_encoder.zs(state)
                action = self.actor(state, zs)

            if add_noise:
                action = action + torch.randn_like(action) * self.act_noise_scale

            action = np.clip(action.cpu().numpy()[0], -1, 1)

        action = action * self.max_action
        return action

    def train(self):
        self.training_steps += 1
|
| 91 |
+
|
| 92 |
+
# Sample from LAP
|
| 93 |
+
states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)
|
| 94 |
+
|
| 95 |
+
# Update Encoder
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
next_zs = self.encoder.zs(next_states)
|
| 98 |
+
zs = self.encoder.zs(states)
|
| 99 |
+
pred_zs = self.encoder.zsa(zs, actions)
|
| 100 |
+
encoder_loss = F.mse_loss(pred_zs, next_zs)
|
| 101 |
+
|
| 102 |
+
self.encoder_optimizer.zero_grad()
|
| 103 |
+
encoder_loss.backward()
|
| 104 |
+
self.encoder_optimizer.step()
|
| 105 |
+
|
| 106 |
+
# Update Critic
|
| 107 |
+
with torch.no_grad():
|
| 108 |
+
fixed_target_zs = self.fixed_encoder_target.zs(next_states)
|
| 109 |
+
|
| 110 |
+
target_act_noise = (torch.randn_like(actions) * self.args.target_noise_scale).clamp(-self.args.target_noise_clip, self.args.target_noise_clip).to(self.device)
|
| 111 |
+
|
| 112 |
+
if self.buffer.do_normalize_action is True:
|
| 113 |
+
next_target_actions = (self.target_actor(next_states, fixed_target_zs) + target_act_noise).clamp(-1, 1)
|
| 114 |
+
else:
|
| 115 |
+
next_target_actions = (self.target_actor(next_states, fixed_target_zs) + target_act_noise).clamp(-self.max_action, self.max_action)
|
| 116 |
+
|
| 117 |
+
fixed_target_zsa = self.fixed_encoder_target.zsa(fixed_target_zs, next_target_actions)
|
| 118 |
+
|
| 119 |
+
Q_target = self.target_critic(next_states, next_target_actions, fixed_target_zsa, fixed_target_zs)
|
| 120 |
+
m = Q_target.mean(dim=1, keepdim=True) # Sample mean
|
| 121 |
+
b = Q_target.var(dim=1, unbiased=True, keepdim=True) # Sample variance
|
| 122 |
+
Bias_Corrected_Q_target = m - 0.5641896 * torch.sqrt(b) # bias-corrected target Q
|
| 123 |
+
|
| 124 |
+
Q_target = rewards + (1 - dones) * self.gamma * Bias_Corrected_Q_target.clamp(self.min_target, self.max_target)
|
| 125 |
+
|
| 126 |
+
self.max = max(self.max, float(Q_target.max()))
|
| 127 |
+
self.min = min(self.min, float(Q_target.min()))
|
| 128 |
+
|
| 129 |
+
fixed_zs = self.fixed_encoder.zs(states)
|
| 130 |
+
fixed_zsa = self.fixed_encoder.zsa(fixed_zs, actions)
|
| 131 |
+
|
| 132 |
+
Q = self.critic(states, actions, fixed_zsa, fixed_zs)
|
| 133 |
+
|
| 134 |
+
td_loss = (Q - Q_target).abs()
|
| 135 |
+
critic_loss = LAP_huber(td_loss)
|
| 136 |
+
|
| 137 |
+
self.critic_optimizer.zero_grad()
|
| 138 |
+
critic_loss.backward()
|
| 139 |
+
self.critic_optimizer.step()
|
| 140 |
+
|
| 141 |
+
# Update LAP
|
| 142 |
+
priority = td_loss.max(1)[0].clamp(min=self.args.min_priority).pow(self.args.alpha)
|
| 143 |
+
self.buffer.update_priority(priority)
|
| 144 |
+
|
| 145 |
+
# Update Actor
|
| 146 |
+
if self.training_steps % self.args.policy_update_delay == 0:
|
| 147 |
+
actor_actions = self.actor(states, fixed_zs)
|
| 148 |
+
fixed_zsa = self.fixed_encoder.zsa(fixed_zs, actor_actions)
|
| 149 |
+
Q = self.critic(states, actor_actions, fixed_zsa, fixed_zs)
|
| 150 |
+
|
| 151 |
+
actor_loss = -Q.mean(dim=1, keepdim=True).mean()
|
| 152 |
+
|
| 153 |
+
self.actor_optimizer.zero_grad()
|
| 154 |
+
actor_loss.backward()
|
| 155 |
+
self.actor_optimizer.step()
|
| 156 |
+
else:
|
| 157 |
+
actor_loss = torch.tensor(0.0)
|
| 158 |
+
|
| 159 |
+
# Update Iteration
|
| 160 |
+
if self.training_steps % self.args.target_update_rate == 0:
|
| 161 |
+
self.target_actor.load_state_dict(self.actor.state_dict())
|
| 162 |
+
self.target_critic.load_state_dict(self.critic.state_dict())
|
| 163 |
+
self.fixed_encoder_target.load_state_dict(self.fixed_encoder.state_dict())
|
| 164 |
+
self.fixed_encoder.load_state_dict(self.encoder.state_dict())
|
| 165 |
+
|
| 166 |
+
self.buffer.reset_max_priority()
|
| 167 |
+
|
| 168 |
+
self.max_target = self.max
|
| 169 |
+
self.min_target = self.min
|
| 170 |
+
|
| 171 |
+
return actor_loss.item(), critic_loss.item(), encoder_loss.item()
|
| 172 |
+
|
| 173 |
+
def maybe_train_and_checkpoint(self, ep_timesteps, ep_return):
|
| 174 |
+
self.eps_since_update += 1
|
| 175 |
+
self.timesteps_since_update += ep_timesteps
|
| 176 |
+
|
| 177 |
+
self.min_return = min(self.min_return, ep_return)
|
| 178 |
+
|
| 179 |
+
# End evaluation of current policy early
|
| 180 |
+
if self.min_return < self.best_min_return:
|
| 181 |
+
self.train_and_reset()
|
| 182 |
+
|
| 183 |
+
# Update checkpoint
|
| 184 |
+
elif self.eps_since_update == self.max_eps_before_update:
|
| 185 |
+
self.best_min_return = self.min_return
|
| 186 |
+
self.checkpoint_actor.load_state_dict(self.actor.state_dict())
|
| 187 |
+
self.checkpoint_encoder.load_state_dict(self.fixed_encoder.state_dict())
|
| 188 |
+
|
| 189 |
+
self.train_and_reset()
|
| 190 |
+
|
| 191 |
+
# Batch training
|
| 192 |
+
def train_and_reset(self):
|
| 193 |
+
for _ in range(self.timesteps_since_update):
|
| 194 |
+
if self.training_steps == self.args.steps_before_checkpointing:
|
| 195 |
+
self.best_min_return *= self.args.reset_weight
|
| 196 |
+
self.max_eps_before_update = self.args.max_eps_when_checkpointing
|
| 197 |
+
|
| 198 |
+
self.train()
|
| 199 |
+
|
| 200 |
+
self.eps_since_update = 0
|
| 201 |
+
self.timesteps_since_update = 0
|
| 202 |
+
self.min_return = 1e8
|
utils.py
ADDED
@@ -0,0 +1,63 @@
import numpy as np
import random
import torch
import torch.nn as nn


def AvgL1Norm(x, eps=1e-8):
    return x / x.abs().mean(-1, keepdim=True).clamp(min=eps)


def LAP_huber(x, min_priority=1):
    return torch.where(x < min_priority, 0.5 * x.pow(2), min_priority * x).sum(1).mean()


def weight_init(m):
    """Custom weight init for Conv2D and Linear layers.
    Reference: https://github.com/MishaLaskin/rad/blob/master/curl_sac.py"""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)


def hard_update(network, target_network):
    for param, target_param in zip(network.parameters(), target_network.parameters()):
        target_param.data.copy_(param.data)


def soft_update(network, target_network, tau):
    for param, target_param in zip(network.parameters(), target_network.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)


def set_seed(random_seed):
    # Non-positive seeds request a randomly drawn seed
    if random_seed <= 0:
        random_seed = np.random.randint(1, 9999)

    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    return random_seed


def make_env(env_name, random_seed):
    import gymnasium as gym
    env = gym.make(env_name)
    env.action_space.seed(random_seed)

    eval_env = gym.make(env_name)
    eval_env.action_space.seed(random_seed + 100)

    return env, eval_env
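The piecewise loss in `LAP_huber` is a Huber loss whose quadratic-to-linear switch sits at `min_priority`: errors below the threshold are penalized quadratically, larger ones linearly, which caps the gradient of large TD errors. A framework-free NumPy sketch of the same formula (the function name `lap_huber_np` is illustrative, not part of the repo):

```python
import numpy as np

def lap_huber_np(td_error, min_priority=1.0):
    """Huber loss with the quadratic/linear switch at min_priority,
    summed over the critic ensemble and averaged over the batch."""
    x = np.abs(td_error)
    per_element = np.where(x < min_priority, 0.5 * x ** 2, min_priority * x)
    return per_element.sum(axis=1).mean()

# One batch element, two critics: |td| = 0.5 (quadratic) and 2.0 (linear)
print(lap_huber_np(np.array([[0.5, 2.0]])))  # 0.5*0.5**2 + 1.0*2.0 = 2.125
```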