leonepson commited on
Commit
5960497
·
verified ·
1 Parent(s): 6876901

Upload 254 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +32 -0
  2. LICENSE +21 -0
  3. README.md +194 -12
  4. baselines/__init__.py +0 -0
  5. baselines/a2c/README.md +13 -0
  6. baselines/a2c/__init__.py +0 -0
  7. baselines/a2c/a2c.py +232 -0
  8. baselines/a2c/runner.py +76 -0
  9. baselines/a2c/utils.py +291 -0
  10. baselines/acer/README.md +6 -0
  11. baselines/acer/__init__.py +0 -0
  12. baselines/acer/acer.py +381 -0
  13. baselines/acer/buffer.py +156 -0
  14. baselines/acer/defaults.py +4 -0
  15. baselines/acer/policies.py +81 -0
  16. baselines/acer/runner.py +61 -0
  17. baselines/acktr/README.md +9 -0
  18. baselines/acktr/__init__.py +0 -0
  19. baselines/acktr/acktr.py +158 -0
  20. baselines/acktr/defaults.py +5 -0
  21. baselines/acktr/kfac.py +928 -0
  22. baselines/acktr/kfac_utils.py +86 -0
  23. baselines/acktr/utils.py +28 -0
  24. baselines/bench/__init__.py +3 -0
  25. baselines/bench/benchmarks.py +164 -0
  26. baselines/bench/monitor.py +162 -0
  27. baselines/bench/test_monitor.py +31 -0
  28. baselines/common/__init__.py +5 -0
  29. baselines/common/atari_wrappers.py +290 -0
  30. baselines/common/cg.py +34 -0
  31. baselines/common/cmd_util.py +206 -0
  32. baselines/common/console_util.py +80 -0
  33. baselines/common/dataset.py +60 -0
  34. baselines/common/distributions.py +355 -0
  35. baselines/common/input.py +65 -0
  36. baselines/common/math_util.py +85 -0
  37. baselines/common/misc_util.py +243 -0
  38. baselines/common/models.py +406 -0
  39. baselines/common/mpi_adam.py +103 -0
  40. baselines/common/mpi_adam_optimizer.py +90 -0
  41. baselines/common/mpi_fork.py +23 -0
  42. baselines/common/mpi_moments.py +61 -0
  43. baselines/common/mpi_running_mean_std.py +112 -0
  44. baselines/common/mpi_util.py +133 -0
  45. baselines/common/plot_util.py +434 -0
  46. baselines/common/policies.py +186 -0
  47. baselines/common/retro_wrappers.py +280 -0
  48. baselines/common/runners.py +19 -0
  49. baselines/common/running_mean_std.py +187 -0
  50. baselines/common/schedules.py +99 -0
.gitattributes CHANGED
@@ -36,3 +36,35 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  checkpoints/sppo-fruitbot_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
37
  checkpoints/sppo-jumper_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
38
  checkpoints/sppo-ninja_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  checkpoints/sppo-fruitbot_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
37
  checkpoints/sppo-jumper_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
38
  checkpoints/sppo-ninja_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
39
+ baselines/gail/result/halfcheetah-training.png filter=lfs diff=lfs merge=lfs -text
40
+ baselines/gail/result/hopper-training.png filter=lfs diff=lfs merge=lfs -text
41
+ baselines/gail/result/humanoid-training.png filter=lfs diff=lfs merge=lfs -text
42
+ baselines/gail/result/humanoidstandup-training.png filter=lfs diff=lfs merge=lfs -text
43
+ baselines/gail/result/walker2d-training.png filter=lfs diff=lfs merge=lfs -text
44
+ train_procgen/checkpoints/sppo-fruitbot_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
45
+ train_procgen/checkpoints/sppo-jumper_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
46
+ train_procgen/checkpoints/sppo-ninja_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
47
+ videos/fruitbot_skills/cluster_0.gif filter=lfs diff=lfs merge=lfs -text
48
+ videos/fruitbot_skills/cluster_1.gif filter=lfs diff=lfs merge=lfs -text
49
+ videos/fruitbot_skills/cluster_2.gif filter=lfs diff=lfs merge=lfs -text
50
+ videos/fruitbot_skills/cluster_3.gif filter=lfs diff=lfs merge=lfs -text
51
+ videos/fruitbot_skills/cluster_4.gif filter=lfs diff=lfs merge=lfs -text
52
+ videos/fruitbot_skills/cluster_5.gif filter=lfs diff=lfs merge=lfs -text
53
+ videos/fruitbot_skills/cluster_6.gif filter=lfs diff=lfs merge=lfs -text
54
+ videos/fruitbot_skills/cluster_7.gif filter=lfs diff=lfs merge=lfs -text
55
+ videos/jumper_skills/cluster_0.gif filter=lfs diff=lfs merge=lfs -text
56
+ videos/jumper_skills/cluster_1.gif filter=lfs diff=lfs merge=lfs -text
57
+ videos/jumper_skills/cluster_2.gif filter=lfs diff=lfs merge=lfs -text
58
+ videos/jumper_skills/cluster_3.gif filter=lfs diff=lfs merge=lfs -text
59
+ videos/jumper_skills/cluster_4.gif filter=lfs diff=lfs merge=lfs -text
60
+ videos/jumper_skills/cluster_5.gif filter=lfs diff=lfs merge=lfs -text
61
+ videos/jumper_skills/cluster_6.gif filter=lfs diff=lfs merge=lfs -text
62
+ videos/jumper_skills/cluster_7.gif filter=lfs diff=lfs merge=lfs -text
63
+ videos/ninja_skills/cluster_0.gif filter=lfs diff=lfs merge=lfs -text
64
+ videos/ninja_skills/cluster_1.gif filter=lfs diff=lfs merge=lfs -text
65
+ videos/ninja_skills/cluster_2.gif filter=lfs diff=lfs merge=lfs -text
66
+ videos/ninja_skills/cluster_3.gif filter=lfs diff=lfs merge=lfs -text
67
+ videos/ninja_skills/cluster_4.gif filter=lfs diff=lfs merge=lfs -text
68
+ videos/ninja_skills/cluster_5.gif filter=lfs diff=lfs merge=lfs -text
69
+ videos/ninja_skills/cluster_6.gif filter=lfs diff=lfs merge=lfs -text
70
+ videos/ninja_skills/cluster_7.gif filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License
2
+
3
+ Copyright (c) 2019 OpenAI (http://openai.com)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in
13
+ all copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21
+ THE SOFTWARE.
README.md CHANGED
@@ -1,12 +1,194 @@
1
- ---
2
- license: mit
3
- language:
4
- - en
5
- pipeline_tag: reinforcement-learning
6
- tags:
7
- - clustering
8
- - interpretablity
9
- - visualization
10
- - https://arxiv.org/abs/2409.17411
11
- paper: https://arxiv.org/abs/2409.17411
12
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Enhancing Interpretability in Deep Reinforcement Learning through Semantic Clustering
2
+
3
+ **Authors:** Liang Zhang, Justin Lieffers, Adarsh Pyarelal
4
+ **Conference:** NeurIPS 2025 Main Track
5
+ **Paper:** [arXiv:2409.17411](https://arxiv.org/abs/2409.17411)
6
+
7
+ This repository contains the official implementation of our research on enhancing interpretability in deep reinforcement learning through semantic clustering techniques. Our work extends the [OpenAI train-procgen](https://github.com/openai/train-procgen) framework to incorporate semantic clustering methods for improved understanding and visualization of learned policies in procedural environments.
8
+
9
+ ## 📋 Abstract
10
+
11
+ This work presents a novel approach to enhancing interpretability in deep reinforcement learning by leveraging semantic clustering techniques. We demonstrate how semantic clustering can provide insights into learned policies, enabling better understanding of agent behavior and decision-making processes in complex procedural environments.
12
+
13
+ ## 🚀 Quick Start
14
+
15
+ ### Installation
16
+
17
+ Prerequisite: Python 3.8.
18
+
19
+ 1. **Clone the repository:**
20
+ ```bash
21
+ git clone https://github.com/ualiangzhang/semantic_rl.git
22
+ cd semantic_rl
23
+ ```
24
+
25
+ 2. **Install dependencies (Python 3.8):**
26
+ ```bash
27
+ pip install -r requirements.txt
28
+ ```
29
+
30
+ 3. **Install Procgen environments:**
31
+ Follow the installation steps in the [Procgen repository](https://github.com/openai/procgen).
32
+
33
+ ### Basic Usage
34
+
35
+ **Train a semantic clustering model:**
36
+ ```bash
37
+ python -m train_procgen.train_sppo --env_name <ENV_NAME> --num_levels 0 --distribution_mode easy --timesteps_per_proc 25000000 --rand_seed <RAND_SEED>
38
+ ```
39
+
40
+ **Train a baseline model:**
41
+ ```bash
42
+ python -m train_procgen.train_ppo --env_name <ENV_NAME> --num_levels 0 --distribution_mode easy --timesteps_per_proc 25000000 --rand_seed <RAND_SEED>
43
+ ```
44
+
45
+ ## 📊 Visualization and Analysis
46
+
47
+ ### Performance Analysis
48
+
49
+ **Generate generalization figures for a single game:**
50
+ ```bash
51
+ cd train_procgen
52
+ python single_graph.py --env_name <ENV_NAME>
53
+ # Example:
54
+ python single_graph.py --env_name coinrun
55
+ ```
56
+
57
+ ### Semantic Clustering Visualization
58
+
59
+ **Generate embedding space visualizations:**
60
+ ```bash
61
+ python -m train_procgen.enjoy_sppo --env_name <ENV_NAME> --mode 1
62
+ ```
63
+
64
+ **Generate skill demonstration videos:**
65
+ ```bash
66
+ python -m train_procgen.enjoy_sppo --env_name <ENV_NAME> --mode 0
67
+ ```
68
+
69
+ **Interactive cluster exploration:**
70
+ ```bash
71
+ python -m train_procgen.hover_clusters --env_name <ENV_NAME>
72
+ # Example:
73
+ python -m train_procgen.hover_clusters --env_name fruitbot
74
+ ```
75
+
76
+ ## 🎮 Supported Environments
77
+
78
+ Our implementation supports four Procgen environments:
79
+ - **CoinRun**
80
+ - **FruitBot**
81
+ - **Jumper**
82
+ - **Ninja**
83
+
84
+ ## 🎬 Semantic Clustering Demonstration
85
+
86
+ ### Ninja Environment - 8 Semantic Clusters
87
+
88
+ The following videos demonstrate the 8 distinct semantic clusters learned by our model in the Ninja environment. Each cluster represents a different behavioral pattern and skill set:
89
+
90
+ #### 📹 Semantic Cluster Demonstrations
91
+
92
+ <table>
93
+ <tr>
94
+ <td align="center">
95
+ <strong>Cluster 0</strong><br/>
96
+ <img src="videos/ninja_skills/cluster_0.gif" width="180" height="180" />
97
+ </td>
98
+ <td align="center">
99
+ <strong>Cluster 1</strong><br/>
100
+ <img src="videos/ninja_skills/cluster_1.gif" width="180" height="180" />
101
+ </td>
102
+ <td align="center">
103
+ <strong>Cluster 2</strong><br/>
104
+ <img src="videos/ninja_skills/cluster_2.gif" width="180" height="180" />
105
+ </td>
106
+ <td align="center">
107
+ <strong>Cluster 3</strong><br/>
108
+ <img src="videos/ninja_skills/cluster_3.gif" width="180" height="180" />
109
+ </td>
110
+ </tr>
111
+ <tr>
112
+ <td align="center">
113
+ <strong>Cluster 4</strong><br/>
114
+ <img src="videos/ninja_skills/cluster_4.gif" width="180" height="180" />
115
+ </td>
116
+ <td align="center">
117
+ <strong>Cluster 5</strong><br/>
118
+ <img src="videos/ninja_skills/cluster_5.gif" width="180" height="180" />
119
+ </td>
120
+ <td align="center">
121
+ <strong>Cluster 6</strong><br/>
122
+ <img src="videos/ninja_skills/cluster_6.gif" width="180" height="180" />
123
+ </td>
124
+ <td align="center">
125
+ <strong>Cluster 7</strong><br/>
126
+ <img src="videos/ninja_skills/cluster_7.gif" width="180" height="180" />
127
+ </td>
128
+ </tr>
129
+ </table>
130
+
131
+ #### 🧭 Behavior Descriptions (Ninja)
132
+
133
+ | Cluster | Behavior |
134
+ |---------|----------|
135
+ | 0 | The agent starts by walking through the first platform and then performs a high jump to reach a higher ledge. |
136
+ | 1 | The agent makes small jumps in the middle of the scene. |
137
+ | 2 | Two interpretations are present: (1) the agent starts from the leftmost end of the scene and walks to the starting position of Cluster 0; (2) when there are no higher ledges to jump to, the agent begins from the scene, walks over the first platform, and prepares to jump to the subsequent ledge. |
138
+ | 3 | The agent walks on the ledge and prepares to jump to a higher ledge. |
139
+ | 4 | After performing a high jump, the agent loses sight of the ledge below. |
140
+ | 5 | The agent walks on the ledge and prepares to jump onto a ledge at the same height or lower. |
141
+ | 6 | The agent executes a high jump while keeping the ledge below in sight. |
142
+ | 7 | The agent moves towards the right edge of the scene and touches the mushroom. |
143
+
144
+ #### 📊 Alternative: Generate Your Own Videos
145
+
146
+ You can also generate these videos yourself using our code:
147
+
148
+ ```bash
149
+ # Generate Ninja skill cluster videos
150
+ python -m train_procgen.enjoy_sppo --env_name ninja --mode 0 --num_embeddings 8
151
+ ```
152
+
153
+ **Note:** These videos showcase the distinct behavioral patterns learned by our semantic clustering approach. Each cluster demonstrates different combat strategies, movement patterns, and decision-making processes in the Ninja environment.
154
+
155
+ ## 📁 Output Structure
156
+
157
+ ```
158
+ baseline/ # Required RL training package
159
+ train_procgen/
160
+ ├── checkpoints/ # Trained model checkpoints
161
+ ├── figures/ # Generated visualizations and videos
162
+ videos/ # video clips corresponding to the clusters in the paper
163
+ ```
164
+
165
+ ## 📈 Reproducing Results
166
+
167
+ To reproduce the results from our paper:
168
+
169
+ 1. **(Optional) Use existing checkpoints**: We have provided pre-trained checkpoints for Ninja, FruitBot, and Jumper (random seed 2021) in this repository under `train_procgen/checkpoints/`. You can skip training and directly run the visualization scripts. Otherwise, **train models** using the commands above.
170
+ 2. **Generate visualizations** using the provided scripts
171
+ 3. **Analyze results** using the interactive tools
172
+
173
+ **Note:** Video generation may take 30-60 minutes depending on machine performance, as it ensures comprehensive exploration of all clusters.
174
+
175
+ <!-- ## 🤝 Citation
176
+
177
+ If you use this code in your research, please cite our paper: -->
178
+
179
+ <!-- ```bibtex
180
+ @article{zhang2025enhancing,
181
+ title={Enhancing Interpretability in Deep Reinforcement Learning through Semantic Clustering},
182
+ author={Zhang, Liang and Lieffers, Justin and Pyarelal, Adarsh},
183
+ journal={Advances in Neural Information Processing Systems},
184
+ year={2025}
185
+ }
186
+ ``` -->
187
+
188
+ ## 📄 License
189
+
190
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
191
+
192
+ ## 🙏 Acknowledgments
193
+
194
+ This work builds upon the [OpenAI train-procgen](https://github.com/openai/train-procgen) framework. We thank the original authors for their excellent work on procedural generation for reinforcement learning benchmarking.
baselines/__init__.py ADDED
File without changes
baselines/a2c/README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A2C
2
+
3
+ - Original paper: https://arxiv.org/abs/1602.01783
4
+ - Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/
5
+ - `python -m baselines.run --alg=a2c --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options
6
+ - also refer to the repo-wide [README.md](../../README.md#training-models)
7
+
8
+ ## Files
9
+ - `run_atari`: file used to run the algorithm.
10
+ - `policies.py`: contains the different versions of the A2C architecture (MlpPolicy, CNNPolicy, LstmPolicy...).
11
+ - `a2c.py`: - Model : class used to initialize the step_model (sampling) and train_model (training)
12
+ - learn : Main entrypoint for A2C algorithm. Train a policy with given network architecture on a given environment using a2c algorithm.
13
+ `runner.py`: class used to generate a batch of experiences
baselines/a2c/__init__.py ADDED
File without changes
baselines/a2c/a2c.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import functools
3
+ import tensorflow as tf
4
+
5
+ from baselines import logger
6
+
7
+ from baselines.common import set_global_seeds, explained_variance
8
+ from baselines.common import tf_util
9
+ from baselines.common.policies import build_policy
10
+
11
+
12
+ from baselines.a2c.utils import Scheduler, find_trainable_variables
13
+ from baselines.a2c.runner import Runner
14
+ from baselines.ppo2.ppo2 import safemean
15
+ from collections import deque
16
+
17
+ from tensorflow import losses
18
+
19
+ class Model(object):
20
+
21
+ """
22
+ We use this class to :
23
+ __init__:
24
+ - Creates the step_model
25
+ - Creates the train_model
26
+
27
+ train():
28
+ - Make the training part (feedforward and backpropagation of gradients)
29
+
30
+ save/load():
31
+ - Save load the model
32
+ """
33
+ def __init__(self, policy, env, nsteps,
34
+ ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
35
+ alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):
36
+
37
+ sess = tf_util.get_session()
38
+ nenvs = env.num_envs
39
+ nbatch = nenvs*nsteps
40
+
41
+
42
+ with tf.compat.v1.variable_scope('a2c_model', reuse=tf.compat.v1.AUTO_REUSE):
43
+ # step_model is used for sampling
44
+ step_model = policy(nenvs, 1, sess)
45
+
46
+ # train_model is used to train our network
47
+ train_model = policy(nbatch, nsteps, sess)
48
+
49
+ A = tf.compat.v1.placeholder(train_model.action.dtype, train_model.action.shape)
50
+ ADV = tf.compat.v1.placeholder(tf.float32, [nbatch])
51
+ R = tf.compat.v1.placeholder(tf.float32, [nbatch])
52
+ LR = tf.compat.v1.placeholder(tf.float32, [])
53
+
54
+ # Calculate the loss
55
+ # Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss
56
+
57
+ # Policy loss
58
+ neglogpac = train_model.pd.neglogp(A)
59
+ # L = A(s,a) * -logpi(a|s)
60
+ pg_loss = tf.reduce_mean(input_tensor=ADV * neglogpac)
61
+
62
+ # Entropy is used to improve exploration by limiting the premature convergence to suboptimal policy.
63
+ entropy = tf.reduce_mean(input_tensor=train_model.pd.entropy())
64
+
65
+ # Value loss
66
+ vf_loss = losses.mean_squared_error(tf.squeeze(train_model.vf), R)
67
+
68
+ loss = pg_loss - entropy*ent_coef + vf_loss * vf_coef
69
+
70
+ # Update parameters using loss
71
+ # 1. Get the model parameters
72
+ params = find_trainable_variables("a2c_model")
73
+
74
+ # 2. Calculate the gradients
75
+ grads = tf.gradients(ys=loss, xs=params)
76
+ if max_grad_norm is not None:
77
+ # Clip the gradients (normalize)
78
+ grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
79
+ grads = list(zip(grads, params))
80
+ # zip aggregate each gradient with parameters associated
81
+ # For instance zip(ABCD, xyza) => Ax, By, Cz, Da
82
+
83
+ # 3. Make op for one policy and value update step of A2C
84
+ trainer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
85
+
86
+ _train = trainer.apply_gradients(grads)
87
+
88
+ lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)
89
+
90
+ def train(obs, states, rewards, masks, actions, values):
91
+ # Here we calculate advantage A(s,a) = R + yV(s') - V(s)
92
+ # rewards = R + yV(s')
93
+ advs = rewards - values
94
+ for step in range(len(obs)):
95
+ cur_lr = lr.value()
96
+
97
+ td_map = {train_model.X:obs, A:actions, ADV:advs, R:rewards, LR:cur_lr}
98
+ if states is not None:
99
+ td_map[train_model.S] = states
100
+ td_map[train_model.M] = masks
101
+ policy_loss, value_loss, policy_entropy, _ = sess.run(
102
+ [pg_loss, vf_loss, entropy, _train],
103
+ td_map
104
+ )
105
+ return policy_loss, value_loss, policy_entropy
106
+
107
+
108
+ self.train = train
109
+ self.train_model = train_model
110
+ self.step_model = step_model
111
+ self.step = step_model.step
112
+ self.value = step_model.value
113
+ self.initial_state = step_model.initial_state
114
+ self.save = functools.partial(tf_util.save_variables, sess=sess)
115
+ self.load = functools.partial(tf_util.load_variables, sess=sess)
116
+ tf.compat.v1.global_variables_initializer().run(session=sess)
117
+
118
+
119
+ def learn(
120
+ network,
121
+ env,
122
+ seed=None,
123
+ nsteps=5,
124
+ total_timesteps=int(80e6),
125
+ vf_coef=0.5,
126
+ ent_coef=0.01,
127
+ max_grad_norm=0.5,
128
+ lr=7e-4,
129
+ lrschedule='linear',
130
+ epsilon=1e-5,
131
+ alpha=0.99,
132
+ gamma=0.99,
133
+ log_interval=100,
134
+ load_path=None,
135
+ **network_kwargs):
136
+
137
+ '''
138
+ Main entrypoint for A2C algorithm. Train a policy with given network architecture on a given environment using a2c algorithm.
139
+
140
+ Parameters:
141
+ -----------
142
+
143
+ network: policy network architecture. Either string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for full list)
144
+ specifying the standard network architecture, or a function that takes tensorflow tensor as input and returns
145
+ tuple (output_tensor, extra_feed) where output tensor is the last network layer output, extra_feed is None for feed-forward
146
+ neural nets, and extra_feed is a dictionary describing how to feed state into the network for recurrent neural nets.
147
+ See baselines.common/policies.py/lstm for more details on using recurrent nets in policies
148
+
149
+
150
+ env: RL environment. Should implement interface similar to VecEnv (baselines.common/vec_env) or be wrapped with DummyVecEnv (baselines.common/vec_env/dummy_vec_env.py)
151
+
152
+
153
+ seed: seed to make random number sequence in the algorithm reproducible. By default is None which means seed from system noise generator (not reproducible)
154
+
155
+ nsteps: int, number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
156
+ nenv is number of environment copies simulated in parallel)
157
+
158
+ total_timesteps: int, total number of timesteps to train on (default: 80M)
159
+
160
+ vf_coef: float, coefficient in front of value function loss in the total loss function (default: 0.5)
161
+
162
+ ent_coef: float, coefficient in front of the policy entropy in the total loss function (default: 0.01)
163
+
164
+ max_grad_norm: float, gradient is clipped to have global L2 norm no more than this value (default: 0.5)
165
+
166
+ lr: float, learning rate for RMSProp (current implementation has RMSProp hardcoded in) (default: 7e-4)
167
+
168
+ lrschedule: schedule of learning rate. Can be 'linear', 'constant', or a function [0..1] -> [0..1] that takes fraction of the training progress as input and
169
+ returns fraction of the learning rate (specified as lr) as output
170
+
171
+ epsilon: float, RMSProp epsilon (stabilizes square root computation in denominator of RMSProp update) (default: 1e-5)
172
+
173
+ alpha: float, RMSProp decay parameter (default: 0.99)
174
+
175
+ gamma: float, reward discounting parameter (default: 0.99)
176
+
177
+ log_interval: int, specifies how frequently the logs are printed out (default: 100)
178
+
179
+ **network_kwargs: keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
180
+ For instance, 'mlp' network architecture has arguments num_hidden and num_layers.
181
+
182
+ '''
183
+
184
+
185
+
186
+ set_global_seeds(seed)
187
+
188
+ # Get the nb of env
189
+ nenvs = env.num_envs
190
+ policy = build_policy(env, network, **network_kwargs)
191
+
192
+ # Instantiate the model object (that creates step_model and train_model)
193
+ model = Model(policy=policy, env=env, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
194
+ max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
195
+ if load_path is not None:
196
+ model.load(load_path)
197
+
198
+ # Instantiate the runner object
199
+ runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
200
+ epinfobuf = deque(maxlen=100)
201
+
202
+ # Calculate the batch_size
203
+ nbatch = nenvs*nsteps
204
+
205
+ # Start total timer
206
+ tstart = time.time()
207
+
208
+ for update in range(1, total_timesteps//nbatch+1):
209
+ # Get mini batch of experiences
210
+ obs, states, rewards, masks, actions, values, epinfos = runner.run()
211
+ epinfobuf.extend(epinfos)
212
+
213
+ policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
214
+ nseconds = time.time()-tstart
215
+
216
+ # Calculate the fps (frame per second)
217
+ fps = int((update*nbatch)/nseconds)
218
+ if update % log_interval == 0 or update == 1:
219
+ # Calculates if value function is a good predicator of the returns (ev > 1)
220
+ # or if it's just worse than predicting nothing (ev =< 0)
221
+ ev = explained_variance(values, rewards)
222
+ logger.record_tabular("nupdates", update)
223
+ logger.record_tabular("total_timesteps", update*nbatch)
224
+ logger.record_tabular("fps", fps)
225
+ logger.record_tabular("policy_entropy", float(policy_entropy))
226
+ logger.record_tabular("value_loss", float(value_loss))
227
+ logger.record_tabular("explained_variance", float(ev))
228
+ logger.record_tabular("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
229
+ logger.record_tabular("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
230
+ logger.dump_tabular()
231
+ return model
232
+
baselines/a2c/runner.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from baselines.a2c.utils import discount_with_dones
3
+ from baselines.common.runners import AbstractEnvRunner
4
+
5
+ class Runner(AbstractEnvRunner):
6
+ """
7
+ We use this class to generate batches of experiences
8
+
9
+ __init__:
10
+ - Initialize the runner
11
+
12
+ run():
13
+ - Make a mini batch of experiences
14
+ """
15
+ def __init__(self, env, model, nsteps=5, gamma=0.99):
16
+ super().__init__(env=env, model=model, nsteps=nsteps)
17
+ self.gamma = gamma
18
+ self.batch_action_shape = [x if x is not None else -1 for x in model.train_model.action.shape.as_list()]
19
+ self.ob_dtype = model.train_model.X.dtype.as_numpy_dtype
20
+
21
+ def run(self):
22
+ # We initialize the lists that will contain the mb of experiences
23
+ mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
24
+ mb_states = self.states
25
+ epinfos = []
26
+ for n in range(self.nsteps):
27
+ # Given observations, take action and value (V(s))
28
+ # We already have self.obs because Runner superclass run self.obs[:] = env.reset() on init
29
+ actions, values, states, _ = self.model.step(self.obs, S=self.states, M=self.dones)
30
+
31
+ # Append the experiences
32
+ mb_obs.append(np.copy(self.obs))
33
+ mb_actions.append(actions)
34
+ mb_values.append(values)
35
+ mb_dones.append(self.dones)
36
+
37
+ # Take actions in env and look the results
38
+ obs, rewards, dones, infos = self.env.step(actions)
39
+ for info in infos:
40
+ maybeepinfo = info.get('episode')
41
+ if maybeepinfo: epinfos.append(maybeepinfo)
42
+ self.states = states
43
+ self.dones = dones
44
+ self.obs = obs
45
+ mb_rewards.append(rewards)
46
+ mb_dones.append(self.dones)
47
+
48
+ # Batch of steps to batch of rollouts
49
+ mb_obs = np.asarray(mb_obs, dtype=self.ob_dtype).swapaxes(1, 0).reshape(self.batch_ob_shape)
50
+ mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
51
+ mb_actions = np.asarray(mb_actions, dtype=self.model.train_model.action.dtype.name).swapaxes(1, 0)
52
+ mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(1, 0)
53
+ mb_dones = np.asarray(mb_dones, dtype=np.bool).swapaxes(1, 0)
54
+ mb_masks = mb_dones[:, :-1]
55
+ mb_dones = mb_dones[:, 1:]
56
+
57
+
58
+ if self.gamma > 0.0:
59
+ # Discount/bootstrap off value fn
60
+ last_values = self.model.value(self.obs, S=self.states, M=self.dones).tolist()
61
+ for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
62
+ rewards = rewards.tolist()
63
+ dones = dones.tolist()
64
+ if dones[-1] == 0:
65
+ rewards = discount_with_dones(rewards+[value], dones+[0], self.gamma)[:-1]
66
+ else:
67
+ rewards = discount_with_dones(rewards, dones, self.gamma)
68
+
69
+ mb_rewards[n] = rewards
70
+
71
+ mb_actions = mb_actions.reshape(self.batch_action_shape)
72
+
73
+ mb_rewards = mb_rewards.flatten()
74
+ mb_values = mb_values.flatten()
75
+ mb_masks = mb_masks.flatten()
76
+ return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values, epinfos
baselines/a2c/utils.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from collections import deque
5
+
6
def sample(logits):
    """Draw one action index per row from the categorical distribution
    given by unnormalized `logits`, via the Gumbel-max trick."""
    u = tf.random.uniform(tf.shape(input=logits))
    gumbel = -tf.math.log(-tf.math.log(u))
    return tf.argmax(input=logits + gumbel, axis=1)
9
+
10
def cat_entropy(logits):
    # Entropy of a categorical distribution parameterized by unnormalized
    # logits, computed stably: shift by the per-row max so exp() cannot
    # overflow (standard log-sum-exp trick).
    a0 = logits - tf.reduce_max(input_tensor=logits, axis=1, keepdims=True)
    ea0 = tf.exp(a0)
    z0 = tf.reduce_sum(input_tensor=ea0, axis=1, keepdims=True)
    p0 = ea0 / z0
    # sum_i p_i * (log Z - a_i) == -sum_i p_i * log p_i
    return tf.reduce_sum(input_tensor=p0 * (tf.math.log(z0) - a0), axis=1)
16
+
17
def cat_entropy_softmax(p0):
    """Entropy of already-normalized probabilities `p0`; the small epsilon
    keeps log() finite when a probability is exactly zero."""
    log_p = tf.math.log(p0 + 1e-6)
    return -tf.reduce_sum(input_tensor=p0 * log_p, axis=1)
19
+
20
def ortho_init(scale=1.0):
    """Return an initializer producing (scaled) orthogonal weight matrices
    (Lasagne-style orthogonal init).

    Supports 2-D (fully-connected) and 4-D (conv kernel, assumed NHWC)
    shapes; anything else raises NotImplementedError.
    """
    def _ortho_init(shape, dtype, partition_info=None):
        shape = tuple(shape)
        if len(shape) == 2:
            flat_shape = shape
        elif len(shape) == 4:  # conv kernel, assumed NHWC
            flat_shape = (np.prod(shape[:-1]), shape[-1])
        else:
            raise NotImplementedError
        gaussian = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(gaussian, full_matrices=False)
        # u and v are both orthonormal; keep whichever matches flat_shape
        q = u if u.shape == flat_shape else v
        q = q.reshape(shape)
        return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
    return _ortho_init
36
+
37
def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC', one_dim_bias=False):
    """2-D convolution layer with orthogonal weight initialization.

    x: input tensor; scope: variable-scope name; nf: number of filters;
    rf: square receptive-field size; stride: spatial stride.
    Returns conv2d(x, w) + b.
    """
    if data_format == 'NHWC':
        channel_ax = 3
        strides = [1, stride, stride, 1]
        bshape = [1, 1, 1, nf]
    elif data_format == 'NCHW':
        channel_ax = 1
        strides = [1, 1, stride, stride]
        bshape = [1, nf, 1, 1]
    else:
        raise NotImplementedError
    bias_var_shape = [nf] if one_dim_bias else [1, nf, 1, 1]
    # TF1 Dimension objects expose .value; TF2 shapes hold plain ints.
    # Catch only AttributeError — the original bare `except:` also swallowed
    # KeyboardInterrupt/SystemExit and hid real errors.
    try:
        nin = x.get_shape()[channel_ax].value
    except AttributeError:
        nin = x.get_shape()[channel_ax]
    wshape = [rf, rf, nin, nf]
    with tf.compat.v1.variable_scope(scope):
        w = tf.compat.v1.get_variable("w", wshape, initializer=ortho_init(init_scale))
        b = tf.compat.v1.get_variable("b", bias_var_shape, initializer=tf.compat.v1.constant_initializer(0.0))
        if not one_dim_bias and data_format == 'NHWC':
            b = tf.reshape(b, bshape)
        return tf.nn.conv2d(input=x, filters=w, strides=strides, padding=pad, data_format=data_format) + b
60
+
61
def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):
    """Fully-connected layer: returns x @ w + b with orthogonal weight init.

    x: [batch, nin] input; scope: variable-scope name; nh: output width.
    """
    with tf.compat.v1.variable_scope(scope):
        # TF1 Dimension objects expose .value; TF2 shapes hold plain ints.
        # Catch only AttributeError — a bare `except:` would also swallow
        # KeyboardInterrupt/SystemExit.
        try:
            nin = x.get_shape()[1].value
        except AttributeError:
            nin = x.get_shape()[1]
        w = tf.compat.v1.get_variable("w", [nin, nh], initializer=ortho_init(init_scale))
        b = tf.compat.v1.get_variable("b", [nh], initializer=tf.compat.v1.constant_initializer(init_bias))
        return tf.matmul(x, w)+b
70
+
71
def batch_to_seq(h, nbatch, nsteps, flat=False):
    # Split a flat batch tensor of shape [nbatch*nsteps, ...] into a python
    # list of `nsteps` tensors of shape [nbatch, ...] (one per time step),
    # for feeding recurrent cells step by step.  `flat=True` treats h as a
    # scalar-per-timestep signal (e.g. rewards or masks).
    if flat:
        h = tf.reshape(h, [nbatch, nsteps])
    else:
        h = tf.reshape(h, [nbatch, nsteps, -1])
    return [tf.squeeze(v, [1]) for v in tf.split(axis=1, num_or_size_splits=nsteps, value=h)]
77
+
78
def seq_to_batch(h, flat = False):
    """Inverse of batch_to_seq: concatenate a list of per-step tensors back
    into a single [nbatch*nsteps, ...] batch tensor."""
    shape = h[0].get_shape().as_list()
    if not flat:
        assert(len(shape) > 1)
        # TF1 Dimensions have .value; TF2 shapes hold plain ints.  Guard it
        # the same way conv/fc/conv_to_fc do — the unguarded `.value` here
        # raised AttributeError under TF2 while its siblings did not.
        try:
            nh = h[0].get_shape()[-1].value
        except AttributeError:
            nh = h[0].get_shape()[-1]
        return tf.reshape(tf.concat(axis=1, values=h), [-1, nh])
    else:
        return tf.reshape(tf.stack(values=h, axis=1), [-1])
86
+
87
def lstm(xs, ms, s, scope, nh, init_scale=1.0):
    """Standard LSTM unrolled over a python list of per-step inputs.

    xs: list (length nsteps) of [nbatch, nin] tensors; entries are
        overwritten in place with the per-step hidden states.
    ms: list of per-step reset masks; multiplying by (1-m) zeroes the
        carried c/h state where a new episode starts.
    s:  packed [nbatch, 2*nh] state — first half cell state c, second half h.
    Returns (xs, s): the hidden-state list and the final packed state.
    """
    nbatch, nin = [v.value for v in xs[0].get_shape()]  # NOTE: TF1-style Dimension.value
    with tf.compat.v1.variable_scope(scope):
        wx = tf.compat.v1.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
        wh = tf.compat.v1.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
        b = tf.compat.v1.get_variable("b", [nh*4], initializer=tf.compat.v1.constant_initializer(0.0))

    c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
    for idx, (x, m) in enumerate(zip(xs, ms)):
        c = c*(1-m)  # reset state at episode boundaries
        h = h*(1-m)
        z = tf.matmul(x, wx) + tf.matmul(h, wh) + b
        i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
        i = tf.nn.sigmoid(i)  # input gate
        f = tf.nn.sigmoid(f)  # forget gate
        o = tf.nn.sigmoid(o)  # output gate
        u = tf.tanh(u)        # candidate cell update
        c = f*c + i*u
        h = o*tf.tanh(c)
        xs[idx] = h
    s = tf.concat(axis=1, values=[c, h])
    return xs, s
109
+
110
def _ln(x, g, b, e=1e-5, axes=None):
    """Layer-normalize x over `axes` (default: axis 1), then apply gain g
    and bias b.

    The default for `axes` was a mutable list literal; use the None-sentinel
    idiom instead to avoid the shared-mutable-default pitfall (behavior is
    unchanged for all callers).
    """
    if axes is None:
        axes = [1]
    u, s = tf.nn.moments(x=x, axes=axes, keepdims=True)
    x = (x-u)/tf.sqrt(s+e)
    x = x*g+b
    return x
115
+
116
def lnlstm(xs, ms, s, scope, nh, init_scale=1.0):
    """LSTM with layer normalization applied to the input projection, the
    recurrent projection, and the cell state before the output tanh.

    Same contract as `lstm`: xs is mutated in place with per-step hidden
    states and the final packed [nbatch, 2*nh] state is returned.
    """
    nbatch, nin = [v.value for v in xs[0].get_shape()]  # NOTE: TF1-style Dimension.value
    with tf.compat.v1.variable_scope(scope):
        wx = tf.compat.v1.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
        gx = tf.compat.v1.get_variable("gx", [nh*4], initializer=tf.compat.v1.constant_initializer(1.0))
        bx = tf.compat.v1.get_variable("bx", [nh*4], initializer=tf.compat.v1.constant_initializer(0.0))

        wh = tf.compat.v1.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
        gh = tf.compat.v1.get_variable("gh", [nh*4], initializer=tf.compat.v1.constant_initializer(1.0))
        bh = tf.compat.v1.get_variable("bh", [nh*4], initializer=tf.compat.v1.constant_initializer(0.0))

        b = tf.compat.v1.get_variable("b", [nh*4], initializer=tf.compat.v1.constant_initializer(0.0))

        # gain/bias for the layer-norm of the cell state
        gc = tf.compat.v1.get_variable("gc", [nh], initializer=tf.compat.v1.constant_initializer(1.0))
        bc = tf.compat.v1.get_variable("bc", [nh], initializer=tf.compat.v1.constant_initializer(0.0))

    c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
    for idx, (x, m) in enumerate(zip(xs, ms)):
        c = c*(1-m)  # reset state at episode boundaries
        h = h*(1-m)
        z = _ln(tf.matmul(x, wx), gx, bx) + _ln(tf.matmul(h, wh), gh, bh) + b
        i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
        i = tf.nn.sigmoid(i)  # input gate
        f = tf.nn.sigmoid(f)  # forget gate
        o = tf.nn.sigmoid(o)  # output gate
        u = tf.tanh(u)        # candidate cell update
        c = f*c + i*u
        h = o*tf.tanh(_ln(c, gc, bc))
        xs[idx] = h
    s = tf.concat(axis=1, values=[c, h])
    return xs, s
147
+
148
def conv_to_fc(x):
    """Flatten a conv feature map [batch, h, w, c] to [batch, h*w*c]."""
    # TF1 Dimensions have .value; TF2 shapes hold plain ints.  Catch only
    # AttributeError — the original bare `except:` swallowed everything.
    try:
        nh = np.prod([v.value for v in x.get_shape()[1:]])
    except AttributeError:
        nh = np.prod([v for v in x.get_shape()[1:]])
    x = tf.reshape(x, [-1, nh])
    return x
155
+
156
def discount_with_dones(rewards, dones, gamma):
    """Compute discounted returns over one trajectory, resetting the
    running accumulator wherever done == 1 (episode boundary)."""
    out = []
    running = 0
    for reward, done in zip(reversed(rewards), reversed(dones)):
        # (1 - done) zeroes the carried return at episode boundaries
        running = reward + gamma*running*(1.-done)
        out.append(running)
    out.reverse()
    return out
163
+
164
def find_trainable_variables(key):
    # All trainable variables whose name falls under the scope `key`.
    return tf.compat.v1.trainable_variables(key)
166
+
167
def make_path(f):
    """Create directory `f` (and any missing parents); no-op if it already
    exists.  Returns None, like os.makedirs."""
    os.makedirs(f, exist_ok=True)
169
+
170
def constant(p):
    """Schedule: always the full value, regardless of progress."""
    return 1

def linear(p):
    """Schedule: linear decay from 1 to 0 as progress p goes 0 -> 1."""
    return 1-p

def middle_drop(p):
    """Linear decay that drops to a fixed 0.075 floor once the remaining
    fraction (1-p) falls below 0.75."""
    eps = 0.75
    return eps*0.1 if 1-p < eps else 1-p

def double_linear_con(p):
    """Linear decay at double speed, clamped below at 0.125."""
    p *= 2
    eps = 0.125
    return eps if 1-p < eps else 1-p

def double_middle_drop(p):
    """Linear decay with two successive floors: 0.075 once 1-p < 0.75,
    then 0.125 once 1-p < 0.25."""
    eps1 = 0.75
    eps2 = 0.25
    if 1-p < eps1:
        return eps2*0.5 if 1-p < eps2 else eps1*0.1
    return 1-p

# Name -> schedule function lookup used by Scheduler.
schedules = {
    'linear':linear,
    'constant':constant,
    'double_linear_con': double_linear_con,
    'middle_drop': middle_drop,
    'double_middle_drop': double_middle_drop
}
205
+
206
class Scheduler(object):
    """Tracks training progress and yields the current (possibly decayed)
    value of a hyperparameter such as the learning rate."""

    def __init__(self, v, nvalues, schedule):
        # v: initial value; nvalues: total steps the schedule spans;
        # schedule: key into the module-level `schedules` dict.
        self.n = 0.
        self.v = v
        self.nvalues = nvalues
        self.schedule = schedules[schedule]

    def value(self):
        """Return the value for the current internal step, then advance."""
        out = self.v * self.schedule(self.n / self.nvalues)
        self.n += 1.
        return out

    def value_steps(self, steps):
        """Return the value at an arbitrary step count without mutating
        the internal counter."""
        return self.v * self.schedule(steps / self.nvalues)
221
+
222
+
223
class EpisodeStats:
    """Rolling statistics (length and total reward) over recently finished
    episodes, fed from flattened per-batch reward and done-mask arrays."""

    def __init__(self, nsteps, nenvs):
        # One in-progress reward list per environment; an episode closes
        # (and is summarized) whenever the corresponding mask is set.
        self.episode_rewards = [[] for _ in range(nenvs)]
        self.lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
        self.rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
        self.nsteps = nsteps
        self.nenvs = nenvs

    def feed(self, rewards, masks):
        """Accumulate one batch of per-step rewards; masks mark episode ends."""
        rewards = np.reshape(rewards, [self.nenvs, self.nsteps])
        masks = np.reshape(masks, [self.nenvs, self.nsteps])
        for env_idx in range(self.nenvs):
            for step in range(self.nsteps):
                self.episode_rewards[env_idx].append(rewards[env_idx][step])
                if masks[env_idx][step]:
                    self.lenbuffer.append(len(self.episode_rewards[env_idx]))
                    self.rewbuffer.append(sum(self.episode_rewards[env_idx]))
                    self.episode_rewards[env_idx] = []

    def mean_length(self):
        # 0 until the first episode finishes (nothing buffered yet)
        return np.mean(self.lenbuffer) if self.lenbuffer else 0

    def mean_reward(self):
        return np.mean(self.rewbuffer) if self.rewbuffer else 0
257
+
258
+
259
+ # For ACER
260
def get_by_index(x, idx):
    # Row-wise gather: returns x[i, idx[i]] for each row i of a 2-D tensor,
    # implemented by flattening x and computing flat indices.
    assert(len(x.get_shape()) == 2)
    assert(len(idx.get_shape()) == 1)
    idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
    y = tf.gather(tf.reshape(x, [-1]),  # flatten input
                  idx_flattened)  # use flattened indices
    return y
267
+
268
def check_shape(ts, shapes):
    """Assert that each tensor in `ts` has the corresponding static shape
    in `shapes` (given as plain python lists)."""
    for i, (t, shape) in enumerate(zip(ts, shapes)):
        assert t.get_shape().as_list() == shape, "id " + str(i) + " shape " + str(t.get_shape()) + str(shape)
273
+
274
def avg_norm(t):
    # Mean (over all leading dims) of the L2 norm taken along the last axis.
    return tf.reduce_mean(input_tensor=tf.sqrt(tf.reduce_sum(input_tensor=tf.square(t), axis=-1)))
276
+
277
def gradient_add(g1, g2, param):
    """Sum two possibly-None gradients for `param`; at least one must be
    present (used to merge per-variable policy and value gradients).

    NOTE: dropped a leftover debug `print([g1, g2, param.name])` — it
    flooded stdout with one line per variable on every graph build.
    """
    assert (not (g1 is None and g2 is None)), param.name
    if g1 is None:
        return g2
    elif g2 is None:
        return g1
    else:
        return g1 + g2
286
+
287
def q_explained_variance(qpred, q):
    # Explained variance of Q predictions: 1 - Var[q - qpred] / Var[q].
    # 1 means perfect prediction; 0 means no better than predicting the mean.
    _, vary = tf.nn.moments(x=q, axes=[0, 1])
    _, varpred = tf.nn.moments(x=q - qpred, axes=[0, 1])
    check_shape([vary, varpred], [[]] * 2)
    return 1.0 - (varpred / vary)
baselines/acer/README.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # ACER
2
+
3
+ - Original paper: https://arxiv.org/abs/1611.01224
4
+ - `python -m baselines.run --alg=acer --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options.
5
+ - also refer to the repo-wide [README.md](../../README.md#training-models)
6
+
baselines/acer/__init__.py ADDED
File without changes
baselines/acer/acer.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import functools
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from baselines import logger
6
+
7
+ from baselines.common import set_global_seeds
8
+ from baselines.common.policies import build_policy
9
+ from baselines.common.tf_util import get_session, save_variables, load_variables
10
+ from baselines.common.vec_env.vec_frame_stack import VecFrameStack
11
+
12
+ from baselines.a2c.utils import batch_to_seq, seq_to_batch
13
+ from baselines.a2c.utils import cat_entropy_softmax
14
+ from baselines.a2c.utils import Scheduler, find_trainable_variables
15
+ from baselines.a2c.utils import EpisodeStats
16
+ from baselines.a2c.utils import get_by_index, check_shape, avg_norm, gradient_add, q_explained_variance
17
+ from baselines.acer.buffer import Buffer
18
+ from baselines.acer.runner import Runner
19
+
20
+ # remove last step
21
def strip(var, nenvs, nsteps, flat = False):
    """Drop the last time step: reshape an [nenvs*(nsteps+1), ...] batch
    into a per-step sequence, cut the final step, and flatten back to
    [nenvs*nsteps, ...].

    (Renamed the local from `vars` — it shadowed the `vars` builtin.)
    """
    seq = batch_to_seq(var, nenvs, nsteps + 1, flat)
    return seq_to_batch(seq[:-1], flat)
24
+
25
def q_retrace(R, D, q_i, v, rho_i, nenvs, nsteps, gamma):
    """
    Calculates q_retrace targets

    :param R: Rewards
    :param D: Dones
    :param q_i: Q values for actions taken
    :param v: V values
    :param rho_i: Importance weight for each action
    :param nenvs: number of parallel environments
    :param nsteps: rollout length per environment
    :param gamma: discount factor
    :return: Q_retrace values
    """
    rho_bar = batch_to_seq(tf.minimum(1.0, rho_i), nenvs, nsteps, True)  # list of len steps, shape [nenvs]
    rs = batch_to_seq(R, nenvs, nsteps, True)  # list of len steps, shape [nenvs]
    ds = batch_to_seq(D, nenvs, nsteps, True)  # list of len steps, shape [nenvs]
    q_is = batch_to_seq(q_i, nenvs, nsteps, True)
    vs = batch_to_seq(v, nenvs, nsteps + 1, True)
    v_final = vs[-1]
    qret = v_final
    qrets = []
    # Backward recursion of the Retrace target: accumulate discounted reward
    # (zeroed at episode ends), record it, then correct towards V using the
    # truncated importance weight rho_bar before moving to step i-1.
    for i in range(nsteps - 1, -1, -1):
        check_shape([qret, ds[i], rs[i], rho_bar[i], q_is[i], vs[i]], [[nenvs]] * 6)
        qret = rs[i] + gamma * qret * (1.0 - ds[i])
        qrets.append(qret)
        qret = (rho_bar[i] * (qret - q_is[i])) + vs[i]
    qrets = qrets[::-1]  # was built backwards in time
    qret = seq_to_batch(qrets, flat=True)
    return qret
52
+
53
+ # For ACER with PPO clipping instead of trust region
54
+ # def clip(ratio, eps_clip):
55
+ # # assume 0 <= eps_clip <= 1
56
+ # return tf.minimum(1 + eps_clip, tf.maximum(1 - eps_clip, ratio))
57
+
58
class Model(object):
    """ACER model: builds the TF1-style graph for the step/train/polyak
    policies, the Retrace-based losses (with truncated importance sampling,
    bias correction and optional trust-region gradient adjustment), and the
    RMSProp training op.  Exposes train/step/save/load callables."""

    def __init__(self, policy, ob_space, ac_space, nenvs, nsteps, ent_coef, q_coef, gamma, max_grad_norm, lr,
                 rprop_alpha, rprop_epsilon, total_timesteps, lrschedule,
                 c, trust_region, alpha, delta):

        sess = get_session()
        nact = ac_space.n
        nbatch = nenvs * nsteps

        A = tf.compat.v1.placeholder(tf.int32, [nbatch])  # actions
        D = tf.compat.v1.placeholder(tf.float32, [nbatch])  # dones
        R = tf.compat.v1.placeholder(tf.float32, [nbatch])  # rewards, not returns
        MU = tf.compat.v1.placeholder(tf.float32, [nbatch, nact])  # mu's (behavior-policy probabilities)
        LR = tf.compat.v1.placeholder(tf.float32, [])  # current learning rate
        eps = 1e-6

        # Separate observation placeholders: step model sees one frame per
        # env, train model sees the whole rollout plus one bootstrap step.
        step_ob_placeholder = tf.compat.v1.placeholder(dtype=ob_space.dtype, shape=(nenvs,) + ob_space.shape)
        train_ob_placeholder = tf.compat.v1.placeholder(dtype=ob_space.dtype, shape=(nenvs*(nsteps+1),) + ob_space.shape)
        with tf.compat.v1.variable_scope('acer_model', reuse=tf.compat.v1.AUTO_REUSE):

            step_model = policy(nbatch=nenvs, nsteps=1, observ_placeholder=step_ob_placeholder, sess=sess)
            train_model = policy(nbatch=nbatch, nsteps=nsteps, observ_placeholder=train_ob_placeholder, sess=sess)


        params = find_trainable_variables("acer_model")
        print("Params {}".format(len(params)))
        for var in params:
            print(var)

        # create polyak averaged model
        ema = tf.train.ExponentialMovingAverage(alpha)
        ema_apply_op = ema.apply(params)

        def custom_getter(getter, *args, **kwargs):
            # Substitute each variable with its EMA shadow so polyak_model
            # evaluates with averaged weights.
            v = ema.average(getter(*args, **kwargs))
            print(v.name)
            return v

        with tf.compat.v1.variable_scope("acer_model", custom_getter=custom_getter, reuse=True):
            polyak_model = policy(nbatch=nbatch, nsteps=nsteps, observ_placeholder=train_ob_placeholder, sess=sess)

        # Notation: (var) = batch variable, (var)s = sequence variable, (var)_i = variable indexed by action at step i

        # action probability distributions according to train_model, polyak_model and step_model
        # policy.pi is probability distribution parameters; to obtain distribution that sums to 1 need to take softmax
        train_model_p = tf.nn.softmax(train_model.pi)
        polyak_model_p = tf.nn.softmax(polyak_model.pi)
        step_model_p = tf.nn.softmax(step_model.pi)
        v = tf.reduce_sum(input_tensor=train_model_p * train_model.q, axis = -1)  # shape is [nenvs * (nsteps + 1)]

        # strip off last step
        f, f_pol, q = map(lambda var: strip(var, nenvs, nsteps), [train_model_p, polyak_model_p, train_model.q])
        # Get pi and q values for actions taken
        f_i = get_by_index(f, A)
        q_i = get_by_index(q, A)

        # Compute ratios for importance truncation
        rho = f / (MU + eps)
        rho_i = get_by_index(rho, A)

        # Calculate Q_retrace targets
        qret = q_retrace(R, D, q_i, v, rho_i, nenvs, nsteps, gamma)

        # Calculate losses
        # Entropy
        # entropy = tf.reduce_mean(strip(train_model.pd.entropy(), nenvs, nsteps))
        entropy = tf.reduce_mean(input_tensor=cat_entropy_softmax(f))

        # Policy Gradient loss, with truncated importance sampling & bias correction
        v = strip(v, nenvs, nsteps, True)
        check_shape([qret, v, rho_i, f_i], [[nenvs * nsteps]] * 4)
        check_shape([rho, f, q], [[nenvs * nsteps, nact]] * 2)

        # Truncated importance sampling
        adv = qret - v
        logf = tf.math.log(f_i + eps)
        gain_f = logf * tf.stop_gradient(adv * tf.minimum(c, rho_i))  # [nenvs * nsteps]
        loss_f = -tf.reduce_mean(input_tensor=gain_f)

        # Bias correction for the truncation
        adv_bc = (q - tf.reshape(v, [nenvs * nsteps, 1]))  # [nenvs * nsteps, nact]
        logf_bc = tf.math.log(f + eps)  # / (f_old + eps)
        check_shape([adv_bc, logf_bc], [[nenvs * nsteps, nact]]*2)
        gain_bc = tf.reduce_sum(input_tensor=logf_bc * tf.stop_gradient(adv_bc * tf.nn.relu(1.0 - (c / (rho + eps))) * f), axis = 1)  # IMP: This is sum, as expectation wrt f
        loss_bc= -tf.reduce_mean(input_tensor=gain_bc)

        loss_policy = loss_f + loss_bc

        # Value/Q function loss, and explained variance
        check_shape([qret, q_i], [[nenvs * nsteps]]*2)
        ev = q_explained_variance(tf.reshape(q_i, [nenvs, nsteps]), tf.reshape(qret, [nenvs, nsteps]))
        loss_q = tf.reduce_mean(input_tensor=tf.square(tf.stop_gradient(qret) - q_i)*0.5)

        # Net loss
        check_shape([loss_policy, loss_q, entropy], [[]] * 3)
        loss = loss_policy + q_coef * loss_q - ent_coef * entropy

        if trust_region:
            # Gradients wrt the policy statistics f, adjusted so the step
            # stays within a KL ball of radius `delta` around the polyak
            # (averaged) policy.
            g = tf.gradients(ys=- (loss_policy - ent_coef * entropy) * nsteps * nenvs, xs=f)  # [nenvs * nsteps, nact]
            # k = tf.gradients(KL(f_pol || f), f)
            k = - f_pol / (f + eps)  # [nenvs * nsteps, nact] # Directly computed gradient of KL divergence wrt f
            k_dot_g = tf.reduce_sum(input_tensor=k * g, axis=-1)
            adj = tf.maximum(0.0, (tf.reduce_sum(input_tensor=k * g, axis=-1) - delta) / (tf.reduce_sum(input_tensor=tf.square(k), axis=-1) + eps))  # [nenvs * nsteps]

            # Calculate stats (before doing adjustment) for logging.
            avg_norm_k = avg_norm(k)
            avg_norm_g = avg_norm(g)
            avg_norm_k_dot_g = tf.reduce_mean(input_tensor=tf.abs(k_dot_g))
            avg_norm_adj = tf.reduce_mean(input_tensor=tf.abs(adj))

            g = g - tf.reshape(adj, [nenvs * nsteps, 1]) * k
            grads_f = -g/(nenvs*nsteps)  # These are trust region adjusted gradients wrt f ie statistics of policy pi
            grads_policy = tf.gradients(ys=f, xs=params, grad_ys=grads_f)
            grads_q = tf.gradients(ys=loss_q * q_coef, xs=params)
            grads = [gradient_add(g1, g2, param) for (g1, g2, param) in zip(grads_policy, grads_q, params)]

            avg_norm_grads_f = avg_norm(grads_f) * (nsteps * nenvs)
            norm_grads_q = tf.linalg.global_norm(grads_q)
            norm_grads_policy = tf.linalg.global_norm(grads_policy)
        else:
            grads = tf.gradients(ys=loss, xs=params)

        if max_grad_norm is not None:
            grads, norm_grads = tf.clip_by_global_norm(grads, max_grad_norm)
        grads = list(zip(grads, params))
        trainer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=LR, decay=rprop_alpha, epsilon=rprop_epsilon)
        _opt_op = trainer.apply_gradients(grads)

        # so when you call _train, you first do the gradient step, then you apply ema
        with tf.control_dependencies([_opt_op]):
            _train = tf.group(ema_apply_op)

        lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)

        # Ops/Summaries to run, and their names for logging
        run_ops = [_train, loss, loss_q, entropy, loss_policy, loss_f, loss_bc, ev, norm_grads]
        names_ops = ['loss', 'loss_q', 'entropy', 'loss_policy', 'loss_f', 'loss_bc', 'explained_variance',
                     'norm_grads']
        if trust_region:
            run_ops = run_ops + [norm_grads_q, norm_grads_policy, avg_norm_grads_f, avg_norm_k, avg_norm_g, avg_norm_k_dot_g,
                                 avg_norm_adj]
            names_ops = names_ops + ['norm_grads_q', 'norm_grads_policy', 'avg_norm_grads_f', 'avg_norm_k', 'avg_norm_g',
                                     'avg_norm_k_dot_g', 'avg_norm_adj']

        def train(obs, actions, rewards, dones, mus, states, masks, steps):
            # One optimizer step on a batch; returns op names and values for logging.
            cur_lr = lr.value_steps(steps)
            td_map = {train_model.X: obs, polyak_model.X: obs, A: actions, R: rewards, D: dones, MU: mus, LR: cur_lr}
            if states is not None:
                # recurrent policy: also feed state and episode masks
                td_map[train_model.S] = states
                td_map[train_model.M] = masks
                td_map[polyak_model.S] = states
                td_map[polyak_model.M] = masks

            return names_ops, sess.run(run_ops, td_map)[1:]  # strip off _train

        def _step(observation, **kwargs):
            return step_model._evaluate([step_model.action, step_model_p, step_model.state], observation, **kwargs)



        self.train = train
        self.save = functools.partial(save_variables, sess=sess)
        self.load = functools.partial(load_variables, sess=sess)
        self.train_model = train_model
        self.step_model = step_model
        self._step = _step
        self.step = self.step_model.step

        self.initial_state = step_model.initial_state
        tf.compat.v1.global_variables_initializer().run(session=sess)
228
+
229
+
230
class Acer():
    """Training driver: alternates on-policy rollouts (which also feed the
    replay buffer) with off-policy batches sampled from it, and logs
    periodic statistics."""

    def __init__(self, runner, model, buffer, log_interval):
        self.runner = runner
        self.model = model
        self.buffer = buffer  # may be None when replay is disabled
        self.log_interval = log_interval
        self.tstart = None  # set by the caller before training starts
        self.episode_stats = EpisodeStats(runner.nsteps, runner.nenv)
        self.steps = None  # set by the caller (env steps taken so far)

    def call(self, on_policy):
        """Run one training call.

        on_policy=True: collect a fresh rollout, store it in the buffer
        (if any), and train on it.  on_policy=False: train on a batch
        sampled from the replay buffer.
        """
        runner, model, buffer, steps = self.runner, self.model, self.buffer, self.steps
        if on_policy:
            enc_obs, obs, actions, rewards, mus, dones, masks = runner.run()
            self.episode_stats.feed(rewards, dones)
            if buffer is not None:
                buffer.put(enc_obs, actions, rewards, mus, dones, masks)
        else:
            # get obs, actions, rewards, mus, dones from buffer.
            obs, actions, rewards, mus, dones, masks = buffer.get()


        # reshape stuff correctly (flatten [nenv, nsteps] -> [nbatch])
        obs = obs.reshape(runner.batch_ob_shape)
        actions = actions.reshape([runner.nbatch])
        rewards = rewards.reshape([runner.nbatch])
        mus = mus.reshape([runner.nbatch, runner.nact])
        dones = dones.reshape([runner.nbatch])
        masks = masks.reshape([runner.batch_ob_shape[0]])

        names_ops, values_ops = model.train(obs, actions, rewards, dones, mus, model.initial_state, masks, steps)

        # Only log on on-policy updates, every `log_interval` updates.
        if on_policy and (int(steps/runner.nbatch) % self.log_interval == 0):
            logger.record_tabular("total_timesteps", steps)
            logger.record_tabular("fps", int(steps/(time.time() - self.tstart)))
            # IMP: In EpisodicLife env, during training, we get done=True at each loss of life, not just at the terminal state.
            # Thus, this is mean until end of life, not end of episode.
            # For true episode rewards, see the monitor files in the log folder.
            logger.record_tabular("mean_episode_length", self.episode_stats.mean_length())
            logger.record_tabular("mean_episode_reward", self.episode_stats.mean_reward())
            for name, val in zip(names_ops, values_ops):
                logger.record_tabular(name, float(val))
            logger.dump_tabular()
273
+
274
+
275
def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=0.5, ent_coef=0.01,
          max_grad_norm=10, lr=7e-4, lrschedule='linear', rprop_epsilon=1e-5, rprop_alpha=0.99, gamma=0.99,
          log_interval=100, buffer_size=50000, replay_ratio=4, replay_start=10000, c=10.0,
          trust_region=True, alpha=0.99, delta=1, load_path=None, **network_kwargs):

    '''
    Main entrypoint for ACER (Actor-Critic with Experience Replay) algorithm (https://arxiv.org/pdf/1611.01224.pdf)
    Train an agent with given network architecture on a given environment using ACER.

    Parameters:
    ----------

    network: policy network architecture. Either string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for full list)
            specifying the standard network architecture, or a function that takes tensorflow tensor as input and returns
            tuple (output_tensor, extra_feed) where output tensor is the last network layer output, extra_feed is None for feed-forward
            neural nets, and extra_feed is a dictionary describing how to feed state into the network for recurrent neural nets.
            See baselines.common/policies.py/lstm for more details on using recurrent nets in policies

    env: environment. Needs to be vectorized for parallel environment simulation.
        The environments produced by gym.make can be wrapped using baselines.common.vec_env.DummyVecEnv class.

    nsteps: int, number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
            nenv is number of environment copies simulated in parallel) (default: 20)

    total_timesteps: int, number of timesteps (i.e. number of actions taken in the environment) (default: 80M)

    q_coef: float, value function loss coefficient in the optimization objective (analog of vf_coef for other actor-critic methods)

    ent_coef: float, policy entropy coefficient in the optimization objective (default: 0.01)

    max_grad_norm: float, gradient norm clipping coefficient. If set to None, no clipping. (default: 10),

    lr: float, learning rate for RMSProp (current implementation has RMSProp hardcoded in) (default: 7e-4)

    lrschedule: schedule of learning rate. Can be 'linear', 'constant', or a function [0..1] -> [0..1] that takes fraction of the training progress as input and
                returns fraction of the learning rate (specified as lr) as output

    rprop_epsilon: float, RMSProp epsilon (stabilizes square root computation in denominator of RMSProp update) (default: 1e-5)

    rprop_alpha: float, RMSProp decay parameter (default: 0.99)

    gamma: float, reward discounting factor (default: 0.99)

    log_interval: int, number of updates between logging events (default: 100)

    buffer_size: int, size of the replay buffer (default: 50k)

    replay_ratio: int, how many (on average) batches of data to sample from the replay buffer take after batch from the environment (default: 4)

    replay_start: int, the sampling from the replay buffer does not start until replay buffer has at least that many samples (default: 10k)

    c: float, importance weight clipping factor (default: 10)

    trust_region: bool, whether or not algorithms estimates the gradient KL divergence between the old and updated policy and uses it to determine step size (default: True)

    delta: float, max KL divergence between the old policy and updated policy (default: 1)

    alpha: float, momentum factor in the Polyak (exponential moving average) averaging of the model parameters (default: 0.99)

    load_path: str, path to load the model from (default: None)

    **network_kwargs: keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
                      For instance, 'mlp' network architecture has arguments num_hidden and num_layers.

    '''

    print("Running Acer Simple")
    print(locals())
    set_global_seeds(seed)
    # ACER's buffer/runner assume a VecFrameStack interface (env.nstack);
    # wrap with a no-op stack of 1 if the caller did not stack frames.
    if not isinstance(env, VecFrameStack):
        env = VecFrameStack(env, 1)

    policy = build_policy(env, network, estimate_q=True, **network_kwargs)
    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space

    # NOTE: removed an unused local `nstack = env.nstack` — nothing below
    # read it (Buffer queries env.nstack itself).
    model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps,
                  ent_coef=ent_coef, q_coef=q_coef, gamma=gamma,
                  max_grad_norm=max_grad_norm, lr=lr, rprop_alpha=rprop_alpha, rprop_epsilon=rprop_epsilon,
                  total_timesteps=total_timesteps, lrschedule=lrschedule, c=c,
                  trust_region=trust_region, alpha=alpha, delta=delta)

    if load_path is not None:
        model.load(load_path)

    runner = Runner(env=env, model=model, nsteps=nsteps)
    if replay_ratio > 0:
        buffer = Buffer(env=env, nsteps=nsteps, size=buffer_size)
    else:
        buffer = None
    nbatch = nenvs*nsteps
    acer = Acer(runner, model, buffer, log_interval)
    acer.tstart = time.time()

    for acer.steps in range(0, total_timesteps, nbatch):  # nbatch samples, 1 on_policy call and multiple off-policy calls
        acer.call(on_policy=True)
        if replay_ratio > 0 and buffer.has_atleast(replay_start):
            n = np.random.poisson(replay_ratio)
            for _ in range(n):
                acer.call(on_policy=False)  # no simulation steps in this

    return model
baselines/acer/buffer.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ class Buffer(object):
4
+ # gets obs, actions, rewards, mu's, (states, masks), dones
5
    def __init__(self, env, nsteps, size=50000):
        # env: vectorized (frame-stacked) environment; nsteps: rollout
        # length; size: total buffer capacity in time steps (converted to
        # whole-rollout slots below).
        self.nenv = env.num_envs
        self.nsteps = nsteps
        # self.nh, self.nw, self.nc = env.observation_space.shape
        self.obs_shape = env.observation_space.shape
        self.obs_dtype = env.observation_space.dtype
        self.ac_dtype = env.action_space.dtype
        self.nc = self.obs_shape[-1]
        self.nstack = env.nstack
        self.nc //= self.nstack  # channels of a single (un-stacked) frame
        self.nbatch = self.nenv * self.nsteps
        self.size = size // (self.nsteps)  # Each loc contains nenv * nsteps frames, thus total buffer is nenv * size frames

        # Memory (arrays are allocated lazily on the first put())
        self.enc_obs = None
        self.actions = None
        self.rewards = None
        self.mus = None
        self.dones = None
        self.masks = None

        # Size indexes
        self.next_idx = 0       # ring-buffer write position
        self.num_in_buffer = 0  # number of filled slots (saturates at self.size)
29
+
30
+ def has_atleast(self, frames):
31
+ # Frames per env, so total (nenv * frames) Frames needed
32
+ # Each buffer loc has nenv * nsteps frames
33
+ return self.num_in_buffer >= (frames // self.nsteps)
34
+
35
    def can_sample(self):
        # True once at least one rollout has been stored.
        return self.num_in_buffer > 0
37
+
38
    # Generate stacked frames
    def decode(self, enc_obs, dones):
        # enc_obs has shape [nenvs, nsteps + nstack, nh, nw, nc]
        # dones has shape [nenvs, nsteps]
        # returns stacked obs of shape [nenv, (nsteps + 1), nh, nw, nstack*nc]
        # Delegates the actual frame stacking to the module-level _stack_obs
        # helper (defined elsewhere in this file).
        return _stack_obs(enc_obs, dones,
                          nsteps=self.nsteps)
46
+
47
+ def put(self, enc_obs, actions, rewards, mus, dones, masks):
48
+ # enc_obs [nenv, (nsteps + nstack), nh, nw, nc]
49
+ # actions, rewards, dones [nenv, nsteps]
50
+ # mus [nenv, nsteps, nact]
51
+
52
+ if self.enc_obs is None:
53
+ self.enc_obs = np.empty([self.size] + list(enc_obs.shape), dtype=self.obs_dtype)
54
+ self.actions = np.empty([self.size] + list(actions.shape), dtype=self.ac_dtype)
55
+ self.rewards = np.empty([self.size] + list(rewards.shape), dtype=np.float32)
56
+ self.mus = np.empty([self.size] + list(mus.shape), dtype=np.float32)
57
+ self.dones = np.empty([self.size] + list(dones.shape), dtype=np.bool)
58
+ self.masks = np.empty([self.size] + list(masks.shape), dtype=np.bool)
59
+
60
+ self.enc_obs[self.next_idx] = enc_obs
61
+ self.actions[self.next_idx] = actions
62
+ self.rewards[self.next_idx] = rewards
63
+ self.mus[self.next_idx] = mus
64
+ self.dones[self.next_idx] = dones
65
+ self.masks[self.next_idx] = masks
66
+
67
+ self.next_idx = (self.next_idx + 1) % self.size
68
+ self.num_in_buffer = min(self.size, self.num_in_buffer + 1)
69
+
70
+ def take(self, x, idx, envx):
71
+ nenv = self.nenv
72
+ out = np.empty([nenv] + list(x.shape[2:]), dtype=x.dtype)
73
+ for i in range(nenv):
74
+ out[i] = x[idx[i], envx[i]]
75
+ return out
76
+
77
+ def get(self):
78
+ # returns
79
+ # obs [nenv, (nsteps + 1), nh, nw, nstack*nc]
80
+ # actions, rewards, dones [nenv, nsteps]
81
+ # mus [nenv, nsteps, nact]
82
+ nenv = self.nenv
83
+ assert self.can_sample()
84
+
85
+ # Sample exactly one id per env. If you sample across envs, then higher correlation in samples from same env.
86
+ idx = np.random.randint(0, self.num_in_buffer, nenv)
87
+ envx = np.arange(nenv)
88
+
89
+ take = lambda x: self.take(x, idx, envx) # for i in range(nenv)], axis = 0)
90
+ dones = take(self.dones)
91
+ enc_obs = take(self.enc_obs)
92
+ obs = self.decode(enc_obs, dones)
93
+ actions = take(self.actions)
94
+ rewards = take(self.rewards)
95
+ mus = take(self.mus)
96
+ masks = take(self.masks)
97
+ return obs, actions, rewards, mus, dones, masks
98
+
99
+
100
+
101
+ def _stack_obs_ref(enc_obs, dones, nsteps):
102
+ nenv = enc_obs.shape[0]
103
+ nstack = enc_obs.shape[1] - nsteps
104
+ nh, nw, nc = enc_obs.shape[2:]
105
+ obs_dtype = enc_obs.dtype
106
+ obs_shape = (nh, nw, nc*nstack)
107
+
108
+ mask = np.empty([nsteps + nstack - 1, nenv, 1, 1, 1], dtype=np.float32)
109
+ obs = np.zeros([nstack, nsteps + nstack, nenv, nh, nw, nc], dtype=obs_dtype)
110
+ x = np.reshape(enc_obs, [nenv, nsteps + nstack, nh, nw, nc]).swapaxes(1, 0) # [nsteps + nstack, nenv, nh, nw, nc]
111
+
112
+ mask[nstack-1:] = np.reshape(1.0 - dones, [nenv, nsteps, 1, 1, 1]).swapaxes(1, 0) # keep
113
+ mask[:nstack-1] = 1.0
114
+
115
+ # y = np.reshape(1 - dones, [nenvs, nsteps, 1, 1, 1])
116
+ for i in range(nstack):
117
+ obs[-(i + 1), i:] = x
118
+ # obs[:,i:,:,:,-(i+1),:] = x
119
+ x = x[:-1] * mask
120
+ mask = mask[1:]
121
+
122
+ return np.reshape(obs[:, (nstack-1):].transpose((2, 1, 3, 4, 0, 5)), (nenv, (nsteps + 1)) + obs_shape)
123
+
124
+ def _stack_obs(enc_obs, dones, nsteps):
125
+ nenv = enc_obs.shape[0]
126
+ nstack = enc_obs.shape[1] - nsteps
127
+ nc = enc_obs.shape[-1]
128
+
129
+ obs_ = np.zeros((nenv, nsteps + 1) + enc_obs.shape[2:-1] + (enc_obs.shape[-1] * nstack, ), dtype=enc_obs.dtype)
130
+ mask = np.ones((nenv, nsteps+1), dtype=enc_obs.dtype)
131
+ mask[:, 1:] = 1.0 - dones
132
+ mask = mask.reshape(mask.shape + tuple(np.ones(len(enc_obs.shape)-2, dtype=np.uint8)))
133
+
134
+ for i in range(nstack-1, -1, -1):
135
+ obs_[..., i * nc : (i + 1) * nc] = enc_obs[:, i : i + nsteps + 1, :]
136
+ if i < nstack-1:
137
+ obs_[..., i * nc : (i + 1) * nc] *= mask
138
+ mask[:, 1:, ...] *= mask[:, :-1, ...]
139
+
140
+ return obs_
141
+
142
+ def test_stack_obs():
143
+ nstack = 7
144
+ nenv = 1
145
+ nsteps = 5
146
+
147
+ obs_shape = (2, 3, nstack)
148
+
149
+ enc_obs_shape = (nenv, nsteps + nstack) + obs_shape[:-1] + (1,)
150
+ enc_obs = np.random.random(enc_obs_shape)
151
+ dones = np.random.randint(low=0, high=2, size=(nenv, nsteps))
152
+
153
+ stacked_obs_ref = _stack_obs_ref(enc_obs, dones, nsteps=nsteps)
154
+ stacked_obs_test = _stack_obs(enc_obs, dones, nsteps=nsteps)
155
+
156
+ np.testing.assert_allclose(stacked_obs_ref, stacked_obs_test)
baselines/acer/defaults.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ def atari():
2
+ return dict(
3
+ lrschedule='constant'
4
+ )
baselines/acer/policies.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from baselines.common.policies import nature_cnn
4
+ from baselines.a2c.utils import fc, batch_to_seq, seq_to_batch, lstm, sample
5
+
6
+
7
+ class AcerCnnPolicy(object):
8
+
9
+ def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, reuse=False):
10
+ nbatch = nenv * nsteps
11
+ nh, nw, nc = ob_space.shape
12
+ ob_shape = (nbatch, nh, nw, nc * nstack)
13
+ nact = ac_space.n
14
+ X = tf.compat.v1.placeholder(tf.uint8, ob_shape) # obs
15
+ with tf.compat.v1.variable_scope("model", reuse=reuse):
16
+ h = nature_cnn(X)
17
+ pi_logits = fc(h, 'pi', nact, init_scale=0.01)
18
+ pi = tf.nn.softmax(pi_logits)
19
+ q = fc(h, 'q', nact)
20
+
21
+ a = sample(tf.nn.softmax(pi_logits)) # could change this to use self.pi instead
22
+ self.initial_state = [] # not stateful
23
+ self.X = X
24
+ self.pi = pi # actual policy params now
25
+ self.pi_logits = pi_logits
26
+ self.q = q
27
+ self.vf = q
28
+
29
+ def step(ob, *args, **kwargs):
30
+ # returns actions, mus, states
31
+ a0, pi0 = sess.run([a, pi], {X: ob})
32
+ return a0, pi0, [] # dummy state
33
+
34
+ def out(ob, *args, **kwargs):
35
+ pi0, q0 = sess.run([pi, q], {X: ob})
36
+ return pi0, q0
37
+
38
+ def act(ob, *args, **kwargs):
39
+ return sess.run(a, {X: ob})
40
+
41
+ self.step = step
42
+ self.out = out
43
+ self.act = act
44
+
45
+ class AcerLstmPolicy(object):
46
+
47
+ def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, reuse=False, nlstm=256):
48
+ nbatch = nenv * nsteps
49
+ nh, nw, nc = ob_space.shape
50
+ ob_shape = (nbatch, nh, nw, nc * nstack)
51
+ nact = ac_space.n
52
+ X = tf.compat.v1.placeholder(tf.uint8, ob_shape) # obs
53
+ M = tf.compat.v1.placeholder(tf.float32, [nbatch]) #mask (done t-1)
54
+ S = tf.compat.v1.placeholder(tf.float32, [nenv, nlstm*2]) #states
55
+ with tf.compat.v1.variable_scope("model", reuse=reuse):
56
+ h = nature_cnn(X)
57
+
58
+ # lstm
59
+ xs = batch_to_seq(h, nenv, nsteps)
60
+ ms = batch_to_seq(M, nenv, nsteps)
61
+ h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm)
62
+ h5 = seq_to_batch(h5)
63
+
64
+ pi_logits = fc(h5, 'pi', nact, init_scale=0.01)
65
+ pi = tf.nn.softmax(pi_logits)
66
+ q = fc(h5, 'q', nact)
67
+
68
+ a = sample(pi_logits) # could change this to use self.pi instead
69
+ self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32)
70
+ self.X = X
71
+ self.M = M
72
+ self.S = S
73
+ self.pi = pi # actual policy params now
74
+ self.q = q
75
+
76
+ def step(ob, state, mask, *args, **kwargs):
77
+ # returns actions, mus, states
78
+ a0, pi0, s = sess.run([a, pi, snew], {X: ob, S: state, M: mask})
79
+ return a0, pi0, s
80
+
81
+ self.step = step
baselines/acer/runner.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from baselines.common.runners import AbstractEnvRunner
3
+ from baselines.common.vec_env.vec_frame_stack import VecFrameStack
4
+ from gym import spaces
5
+
6
+
7
+ class Runner(AbstractEnvRunner):
8
+
9
+ def __init__(self, env, model, nsteps):
10
+ super().__init__(env=env, model=model, nsteps=nsteps)
11
+ assert isinstance(env.action_space, spaces.Discrete), 'This ACER implementation works only with discrete action spaces!'
12
+ assert isinstance(env, VecFrameStack)
13
+
14
+ self.nact = env.action_space.n
15
+ nenv = self.nenv
16
+ self.nbatch = nenv * nsteps
17
+ self.batch_ob_shape = (nenv*(nsteps+1),) + env.observation_space.shape
18
+
19
+ self.obs = env.reset()
20
+ self.obs_dtype = env.observation_space.dtype
21
+ self.ac_dtype = env.action_space.dtype
22
+ self.nstack = self.env.nstack
23
+ self.nc = self.batch_ob_shape[-1] // self.nstack
24
+
25
+
26
+ def run(self):
27
+ # enc_obs = np.split(self.obs, self.nstack, axis=3) # so now list of obs steps
28
+ enc_obs = np.split(self.env.stackedobs, self.env.nstack, axis=-1)
29
+ mb_obs, mb_actions, mb_mus, mb_dones, mb_rewards = [], [], [], [], []
30
+ for _ in range(self.nsteps):
31
+ actions, mus, states = self.model._step(self.obs, S=self.states, M=self.dones)
32
+ mb_obs.append(np.copy(self.obs))
33
+ mb_actions.append(actions)
34
+ mb_mus.append(mus)
35
+ mb_dones.append(self.dones)
36
+ obs, rewards, dones, _ = self.env.step(actions)
37
+ # states information for statefull models like LSTM
38
+ self.states = states
39
+ self.dones = dones
40
+ self.obs = obs
41
+ mb_rewards.append(rewards)
42
+ enc_obs.append(obs[..., -self.nc:])
43
+ mb_obs.append(np.copy(self.obs))
44
+ mb_dones.append(self.dones)
45
+
46
+ enc_obs = np.asarray(enc_obs, dtype=self.obs_dtype).swapaxes(1, 0)
47
+ mb_obs = np.asarray(mb_obs, dtype=self.obs_dtype).swapaxes(1, 0)
48
+ mb_actions = np.asarray(mb_actions, dtype=self.ac_dtype).swapaxes(1, 0)
49
+ mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
50
+ mb_mus = np.asarray(mb_mus, dtype=np.float32).swapaxes(1, 0)
51
+
52
+ mb_dones = np.asarray(mb_dones, dtype=np.bool).swapaxes(1, 0)
53
+
54
+ mb_masks = mb_dones # Used for statefull models like LSTM's to mask state when done
55
+ mb_dones = mb_dones[:, 1:] # Used for calculating returns. The dones array is now aligned with rewards
56
+
57
+ # shapes are now [nenv, nsteps, []]
58
+ # When pulling from buffer, arrays will now be reshaped in place, preventing a deep copy.
59
+
60
+ return enc_obs, mb_obs, mb_actions, mb_rewards, mb_mus, mb_dones, mb_masks
61
+
baselines/acktr/README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # ACKTR
2
+
3
+ - Original paper: https://arxiv.org/abs/1708.05144
4
+ - Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/
5
+ - `python -m baselines.run --alg=acktr --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options.
6
+ - also refer to the repo-wide [README.md](../../README.md#training-models)
7
+
8
+ ## ACKTR with continuous action spaces
9
 + The code of ACKTR has been refactored to handle both discrete and continuous action spaces uniformly. In the original version, discrete and continuous action spaces were handled by different code (acktr_disc.py and acktr_cont.py) with little overlap. If you are interested in the original version of ACKTR for continuous action spaces, use the `old_acktr_cont` branch. Note that the original code performs better on the mujoco tasks than the refactored version; we are still investigating why.
baselines/acktr/__init__.py ADDED
File without changes
baselines/acktr/acktr.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import time
3
+ import functools
4
+ import tensorflow as tf
5
+ from baselines import logger
6
+
7
+ from baselines.common import set_global_seeds, explained_variance
8
+ from baselines.common.policies import build_policy
9
+ from baselines.common.tf_util import get_session, save_variables, load_variables
10
+
11
+ from baselines.a2c.runner import Runner
12
+ from baselines.a2c.utils import Scheduler, find_trainable_variables
13
+ from baselines.acktr import kfac
14
+ from baselines.ppo2.ppo2 import safemean
15
+ from collections import deque
16
+
17
+
18
+ class Model(object):
19
+
20
+ def __init__(self, policy, ob_space, ac_space, nenvs,total_timesteps, nprocs=32, nsteps=20,
21
+ ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
22
+ kfac_clip=0.001, lrschedule='linear', is_async=True):
23
+
24
+ self.sess = sess = get_session()
25
+ nbatch = nenvs * nsteps
26
+ with tf.compat.v1.variable_scope('acktr_model', reuse=tf.compat.v1.AUTO_REUSE):
27
+ self.model = step_model = policy(nenvs, 1, sess=sess)
28
+ self.model2 = train_model = policy(nenvs*nsteps, nsteps, sess=sess)
29
+
30
+ A = train_model.pdtype.sample_placeholder([None])
31
+ ADV = tf.compat.v1.placeholder(tf.float32, [nbatch])
32
+ R = tf.compat.v1.placeholder(tf.float32, [nbatch])
33
+ PG_LR = tf.compat.v1.placeholder(tf.float32, [])
34
+ VF_LR = tf.compat.v1.placeholder(tf.float32, [])
35
+
36
+ neglogpac = train_model.pd.neglogp(A)
37
+ self.logits = train_model.pi
38
+
39
+ ##training loss
40
+ pg_loss = tf.reduce_mean(input_tensor=ADV*neglogpac)
41
+ entropy = tf.reduce_mean(input_tensor=train_model.pd.entropy())
42
+ pg_loss = pg_loss - ent_coef * entropy
43
+ vf_loss = tf.compat.v1.losses.mean_squared_error(tf.squeeze(train_model.vf), R)
44
+ train_loss = pg_loss + vf_coef * vf_loss
45
+
46
+
47
+ ##Fisher loss construction
48
+ self.pg_fisher = pg_fisher_loss = -tf.reduce_mean(input_tensor=neglogpac)
49
+ sample_net = train_model.vf + tf.random.normal(tf.shape(input=train_model.vf))
50
+ self.vf_fisher = vf_fisher_loss = - vf_fisher_coef*tf.reduce_mean(input_tensor=tf.pow(train_model.vf - tf.stop_gradient(sample_net), 2))
51
+ self.joint_fisher = joint_fisher_loss = pg_fisher_loss + vf_fisher_loss
52
+
53
+ self.params=params = find_trainable_variables("acktr_model")
54
+
55
+ self.grads_check = grads = tf.gradients(ys=train_loss,xs=params)
56
+
57
+ with tf.device('/gpu:0'):
58
+ self.optim = optim = kfac.KfacOptimizer(learning_rate=PG_LR, clip_kl=kfac_clip,\
59
+ momentum=0.9, kfac_update=1, epsilon=0.01,\
60
+ stats_decay=0.99, is_async=is_async, cold_iter=10, max_grad_norm=max_grad_norm)
61
+
62
+ # update_stats_op = optim.compute_and_apply_stats(joint_fisher_loss, var_list=params)
63
+ optim.compute_and_apply_stats(joint_fisher_loss, var_list=params)
64
+ train_op, q_runner = optim.apply_gradients(list(zip(grads,params)))
65
+ self.q_runner = q_runner
66
+ self.lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)
67
+
68
+ def train(obs, states, rewards, masks, actions, values):
69
+ advs = rewards - values
70
+ for step in range(len(obs)):
71
+ cur_lr = self.lr.value()
72
+
73
+ td_map = {train_model.X:obs, A:actions, ADV:advs, R:rewards, PG_LR:cur_lr, VF_LR:cur_lr}
74
+ if states is not None:
75
+ td_map[train_model.S] = states
76
+ td_map[train_model.M] = masks
77
+
78
+ policy_loss, value_loss, policy_entropy, _ = sess.run(
79
+ [pg_loss, vf_loss, entropy, train_op],
80
+ td_map
81
+ )
82
+ return policy_loss, value_loss, policy_entropy
83
+
84
+
85
+ self.train = train
86
+ self.save = functools.partial(save_variables, sess=sess)
87
+ self.load = functools.partial(load_variables, sess=sess)
88
+ self.train_model = train_model
89
+ self.step_model = step_model
90
+ self.step = step_model.step
91
+ self.value = step_model.value
92
+ self.initial_state = step_model.initial_state
93
+ tf.compat.v1.global_variables_initializer().run(session=sess)
94
+
95
+ def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=100, nprocs=32, nsteps=20,
96
+ ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
97
+ kfac_clip=0.001, save_interval=None, lrschedule='linear', load_path=None, is_async=True, **network_kwargs):
98
+ set_global_seeds(seed)
99
+
100
+
101
+ if network == 'cnn':
102
+ network_kwargs['one_dim_bias'] = True
103
+
104
+ policy = build_policy(env, network, **network_kwargs)
105
+
106
+ nenvs = env.num_envs
107
+ ob_space = env.observation_space
108
+ ac_space = env.action_space
109
+ make_model = lambda : Model(policy, ob_space, ac_space, nenvs, total_timesteps, nprocs=nprocs, nsteps
110
+ =nsteps, ent_coef=ent_coef, vf_coef=vf_coef, vf_fisher_coef=
111
+ vf_fisher_coef, lr=lr, max_grad_norm=max_grad_norm, kfac_clip=kfac_clip,
112
+ lrschedule=lrschedule, is_async=is_async)
113
+ if save_interval and logger.get_dir():
114
+ import cloudpickle
115
+ with open(osp.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
116
+ fh.write(cloudpickle.dumps(make_model))
117
+ model = make_model()
118
+
119
+ if load_path is not None:
120
+ model.load(load_path)
121
+
122
+ runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
123
+ epinfobuf = deque(maxlen=100)
124
+ nbatch = nenvs*nsteps
125
+ tstart = time.time()
126
+ coord = tf.train.Coordinator()
127
+ if is_async:
128
+ enqueue_threads = model.q_runner.create_threads(model.sess, coord=coord, start=True)
129
+ else:
130
+ enqueue_threads = []
131
+
132
+ for update in range(1, total_timesteps//nbatch+1):
133
+ obs, states, rewards, masks, actions, values, epinfos = runner.run()
134
+ epinfobuf.extend(epinfos)
135
+ policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
136
+ model.old_obs = obs
137
+ nseconds = time.time()-tstart
138
+ fps = int((update*nbatch)/nseconds)
139
+ if update % log_interval == 0 or update == 1:
140
+ ev = explained_variance(values, rewards)
141
+ logger.record_tabular("nupdates", update)
142
+ logger.record_tabular("total_timesteps", update*nbatch)
143
+ logger.record_tabular("fps", fps)
144
+ logger.record_tabular("policy_entropy", float(policy_entropy))
145
+ logger.record_tabular("policy_loss", float(policy_loss))
146
+ logger.record_tabular("value_loss", float(value_loss))
147
+ logger.record_tabular("explained_variance", float(ev))
148
+ logger.record_tabular("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
149
+ logger.record_tabular("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
150
+ logger.dump_tabular()
151
+
152
+ if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir():
153
+ savepath = osp.join(logger.get_dir(), 'checkpoint%.5i'%update)
154
+ print('Saving to', savepath)
155
+ model.save(savepath)
156
+ coord.request_stop()
157
+ coord.join(enqueue_threads)
158
+ return model
baselines/acktr/defaults.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ def mujoco():
2
+ return dict(
3
+ nsteps=2500,
4
+ value_network='copy'
5
+ )
baselines/acktr/kfac.py ADDED
@@ -0,0 +1,928 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ import re
4
+
5
+ # flake8: noqa F403, F405
6
+ from baselines.acktr.kfac_utils import *
7
+ from functools import reduce
8
+
9
+ KFAC_OPS = ['MatMul', 'Conv2D', 'BiasAdd']
10
+ KFAC_DEBUG = False
11
+
12
+
13
+ class KfacOptimizer():
14
+ # note that KfacOptimizer will be truly synchronous (and thus deterministic) only if a single-threaded session is used
15
+ def __init__(self, learning_rate=0.01, momentum=0.9, clip_kl=0.01, kfac_update=2, stats_accum_iter=60, full_stats_init=False, cold_iter=100, cold_lr=None, is_async=False, async_stats=False, epsilon=1e-2, stats_decay=0.95, blockdiag_bias=False, channel_fac=False, factored_damping=False, approxT2=False, use_float64=False, weight_decay_dict={},max_grad_norm=0.5):
16
+ self.max_grad_norm = max_grad_norm
17
+ self._lr = learning_rate
18
+ self._momentum = momentum
19
+ self._clip_kl = clip_kl
20
+ self._channel_fac = channel_fac
21
+ self._kfac_update = kfac_update
22
+ self._async = is_async
23
+ self._async_stats = async_stats
24
+ self._epsilon = epsilon
25
+ self._stats_decay = stats_decay
26
+ self._blockdiag_bias = blockdiag_bias
27
+ self._approxT2 = approxT2
28
+ self._use_float64 = use_float64
29
+ self._factored_damping = factored_damping
30
+ self._cold_iter = cold_iter
31
+ if cold_lr == None:
32
+ # good heuristics
33
+ self._cold_lr = self._lr# * 3.
34
+ else:
35
+ self._cold_lr = cold_lr
36
+ self._stats_accum_iter = stats_accum_iter
37
+ self._weight_decay_dict = weight_decay_dict
38
+ self._diag_init_coeff = 0.
39
+ self._full_stats_init = full_stats_init
40
+ if not self._full_stats_init:
41
+ self._stats_accum_iter = self._cold_iter
42
+
43
+ self.sgd_step = tf.Variable(0, name='KFAC/sgd_step', trainable=False)
44
+ self.global_step = tf.Variable(
45
+ 0, name='KFAC/global_step', trainable=False)
46
+ self.cold_step = tf.Variable(0, name='KFAC/cold_step', trainable=False)
47
+ self.factor_step = tf.Variable(
48
+ 0, name='KFAC/factor_step', trainable=False)
49
+ self.stats_step = tf.Variable(
50
+ 0, name='KFAC/stats_step', trainable=False)
51
+ self.vFv = tf.Variable(0., name='KFAC/vFv', trainable=False)
52
+
53
+ self.factors = {}
54
+ self.param_vars = []
55
+ self.stats = {}
56
+ self.stats_eigen = {}
57
+
58
+ def getFactors(self, g, varlist):
59
+ graph = tf.compat.v1.get_default_graph()
60
+ factorTensors = {}
61
+ fpropTensors = []
62
+ bpropTensors = []
63
+ opTypes = []
64
+ fops = []
65
+
66
+ def searchFactors(gradient, graph):
67
+ # hard coded search stratergy
68
+ bpropOp = gradient.op
69
+ bpropOp_name = bpropOp.name
70
+
71
+ bTensors = []
72
+ fTensors = []
73
+
74
+ # combining additive gradient, assume they are the same op type and
75
+ # indepedent
76
+ if 'AddN' in bpropOp_name:
77
+ factors = []
78
+ for g in gradient.op.inputs:
79
+ factors.append(searchFactors(g, graph))
80
+ op_names = [item['opName'] for item in factors]
81
+ # TO-DO: need to check all the attribute of the ops as well
82
+ print (gradient.name)
83
+ print (op_names)
84
+ print (len(np.unique(op_names)))
85
+ assert len(np.unique(op_names)) == 1, gradient.name + \
86
+ ' is shared among different computation OPs'
87
+
88
+ bTensors = reduce(lambda x, y: x + y,
89
+ [item['bpropFactors'] for item in factors])
90
+ if len(factors[0]['fpropFactors']) > 0:
91
+ fTensors = reduce(
92
+ lambda x, y: x + y, [item['fpropFactors'] for item in factors])
93
+ fpropOp_name = op_names[0]
94
+ fpropOp = factors[0]['op']
95
+ else:
96
+ fpropOp_name = re.search(
97
+ 'gradientsSampled(_[0-9]+|)/(.+?)_grad', bpropOp_name).group(2)
98
+ fpropOp = graph.get_operation_by_name(fpropOp_name)
99
+ if fpropOp.op_def.name in KFAC_OPS:
100
+ # Known OPs
101
+ ###
102
+ bTensor = [
103
+ i for i in bpropOp.inputs if 'gradientsSampled' in i.name][-1]
104
+ bTensorShape = fpropOp.outputs[0].get_shape()
105
+ if bTensor.get_shape()[0].value == None:
106
+ bTensor.set_shape(bTensorShape)
107
+ bTensors.append(bTensor)
108
+ ###
109
+ if fpropOp.op_def.name == 'BiasAdd':
110
+ fTensors = []
111
+ else:
112
+ fTensors.append(
113
+ [i for i in fpropOp.inputs if param.op.name not in i.name][0])
114
+ fpropOp_name = fpropOp.op_def.name
115
+ else:
116
+ # unknown OPs, block approximation used
117
+ bInputsList = [i for i in bpropOp.inputs[
118
+ 0].op.inputs if 'gradientsSampled' in i.name if 'Shape' not in i.name]
119
+ if len(bInputsList) > 0:
120
+ bTensor = bInputsList[0]
121
+ bTensorShape = fpropOp.outputs[0].get_shape()
122
+ if len(bTensor.get_shape()) > 0 and bTensor.get_shape()[0].value == None:
123
+ bTensor.set_shape(bTensorShape)
124
+ bTensors.append(bTensor)
125
+ fpropOp_name = opTypes.append('UNK-' + fpropOp.op_def.name)
126
+
127
+ return {'opName': fpropOp_name, 'op': fpropOp, 'fpropFactors': fTensors, 'bpropFactors': bTensors}
128
+
129
+ for t, param in zip(g, varlist):
130
+ if KFAC_DEBUG:
131
+ print(('get factor for '+param.name))
132
+ factors = searchFactors(t, graph)
133
+ factorTensors[param] = factors
134
+
135
+ ########
136
+ # check associated weights and bias for homogeneous coordinate representation
137
+ # and check redundent factors
138
+ # TO-DO: there may be a bug to detect associate bias and weights for
139
+ # forking layer, e.g. in inception models.
140
+ for param in varlist:
141
+ factorTensors[param]['assnWeights'] = None
142
+ factorTensors[param]['assnBias'] = None
143
+ for param in varlist:
144
+ if factorTensors[param]['opName'] == 'BiasAdd':
145
+ factorTensors[param]['assnWeights'] = None
146
+ for item in varlist:
147
+ if len(factorTensors[item]['bpropFactors']) > 0:
148
+ if (set(factorTensors[item]['bpropFactors']) == set(factorTensors[param]['bpropFactors'])) and (len(factorTensors[item]['fpropFactors']) > 0):
149
+ factorTensors[param]['assnWeights'] = item
150
+ factorTensors[item]['assnBias'] = param
151
+ factorTensors[param]['bpropFactors'] = factorTensors[
152
+ item]['bpropFactors']
153
+
154
+ ########
155
+
156
+ ########
157
+ # concatenate the additive gradients along the batch dimension, i.e.
158
+ # assuming independence structure
159
+ for key in ['fpropFactors', 'bpropFactors']:
160
+ for i, param in enumerate(varlist):
161
+ if len(factorTensors[param][key]) > 0:
162
+ if (key + '_concat') not in factorTensors[param]:
163
+ name_scope = factorTensors[param][key][0].name.split(':')[
164
+ 0]
165
+ with tf.compat.v1.name_scope(name_scope):
166
+ factorTensors[param][
167
+ key + '_concat'] = tf.concat(factorTensors[param][key], 0)
168
+ else:
169
+ factorTensors[param][key + '_concat'] = None
170
+ for j, param2 in enumerate(varlist[(i + 1):]):
171
+ if (len(factorTensors[param][key]) > 0) and (set(factorTensors[param2][key]) == set(factorTensors[param][key])):
172
+ factorTensors[param2][key] = factorTensors[param][key]
173
+ factorTensors[param2][
174
+ key + '_concat'] = factorTensors[param][key + '_concat']
175
+ ########
176
+
177
+ if KFAC_DEBUG:
178
+ for items in zip(varlist, fpropTensors, bpropTensors, opTypes):
179
+ print((items[0].name, factorTensors[item]))
180
+ self.factors = factorTensors
181
+ return factorTensors
182
+
183
+ def getStats(self, factors, varlist):
184
+ if len(self.stats) == 0:
185
+ # initialize stats variables on CPU because eigen decomp is
186
+ # computed on CPU
187
+ with tf.device('/cpu'):
188
+ tmpStatsCache = {}
189
+
190
+ # search for tensor factors and
191
+ # use block diag approx for the bias units
192
+ for var in varlist:
193
+ fpropFactor = factors[var]['fpropFactors_concat']
194
+ bpropFactor = factors[var]['bpropFactors_concat']
195
+ opType = factors[var]['opName']
196
+ if opType == 'Conv2D':
197
+ Kh = var.get_shape()[0]
198
+ Kw = var.get_shape()[1]
199
+ C = fpropFactor.get_shape()[-1]
200
+
201
+ Oh = bpropFactor.get_shape()[1]
202
+ Ow = bpropFactor.get_shape()[2]
203
+ if Oh == 1 and Ow == 1 and self._channel_fac:
204
+ # factorization along the channels do not support
205
+ # homogeneous coordinate
206
+ var_assnBias = factors[var]['assnBias']
207
+ if var_assnBias:
208
+ factors[var]['assnBias'] = None
209
+ factors[var_assnBias]['assnWeights'] = None
210
+ ##
211
+
212
+ for var in varlist:
213
+ fpropFactor = factors[var]['fpropFactors_concat']
214
+ bpropFactor = factors[var]['bpropFactors_concat']
215
+ opType = factors[var]['opName']
216
+ self.stats[var] = {'opName': opType,
217
+ 'fprop_concat_stats': [],
218
+ 'bprop_concat_stats': [],
219
+ 'assnWeights': factors[var]['assnWeights'],
220
+ 'assnBias': factors[var]['assnBias'],
221
+ }
222
+ if fpropFactor is not None:
223
+ if fpropFactor not in tmpStatsCache:
224
+ if opType == 'Conv2D':
225
+ Kh = var.get_shape()[0]
226
+ Kw = var.get_shape()[1]
227
+ C = fpropFactor.get_shape()[-1]
228
+
229
+ Oh = bpropFactor.get_shape()[1]
230
+ Ow = bpropFactor.get_shape()[2]
231
+ if Oh == 1 and Ow == 1 and self._channel_fac:
232
+ # factorization along the channels
233
+ # assume independence between input channels and spatial
234
+ # 2K-1 x 2K-1 covariance matrix and C x C covariance matrix
235
+ # factorization along the channels do not
236
+ # support homogeneous coordinate, assnBias
237
+ # is always None
238
+ fpropFactor2_size = Kh * Kw
239
+ slot_fpropFactor_stats2 = tf.Variable(tf.linalg.tensor_diag(tf.ones(
240
+ [fpropFactor2_size])) * self._diag_init_coeff, name='KFAC_STATS/' + fpropFactor.op.name, trainable=False)
241
+ self.stats[var]['fprop_concat_stats'].append(
242
+ slot_fpropFactor_stats2)
243
+
244
+ fpropFactor_size = C
245
+ else:
246
+ # 2K-1 x 2K-1 x C x C covariance matrix
247
+ # assume BHWC
248
+ fpropFactor_size = Kh * Kw * C
249
+ else:
250
+ # D x D covariance matrix
251
+ fpropFactor_size = fpropFactor.get_shape()[-1]
252
+
253
+ # use homogeneous coordinate
254
+ if not self._blockdiag_bias and self.stats[var]['assnBias']:
255
+ fpropFactor_size += 1
256
+
257
+ slot_fpropFactor_stats = tf.Variable(tf.linalg.tensor_diag(tf.ones(
258
+ [fpropFactor_size])) * self._diag_init_coeff, name='KFAC_STATS/' + fpropFactor.op.name, trainable=False)
259
+ self.stats[var]['fprop_concat_stats'].append(
260
+ slot_fpropFactor_stats)
261
+ if opType != 'Conv2D':
262
+ tmpStatsCache[fpropFactor] = self.stats[
263
+ var]['fprop_concat_stats']
264
+ else:
265
+ self.stats[var][
266
+ 'fprop_concat_stats'] = tmpStatsCache[fpropFactor]
267
+
268
+ if bpropFactor is not None:
269
+ # no need to collect backward stats for bias vectors if
270
+ # using homogeneous coordinates
271
+ if not((not self._blockdiag_bias) and self.stats[var]['assnWeights']):
272
+ if bpropFactor not in tmpStatsCache:
273
+ slot_bpropFactor_stats = tf.Variable(tf.linalg.tensor_diag(tf.ones([bpropFactor.get_shape(
274
+ )[-1]])) * self._diag_init_coeff, name='KFAC_STATS/' + bpropFactor.op.name, trainable=False)
275
+ self.stats[var]['bprop_concat_stats'].append(
276
+ slot_bpropFactor_stats)
277
+ tmpStatsCache[bpropFactor] = self.stats[
278
+ var]['bprop_concat_stats']
279
+ else:
280
+ self.stats[var][
281
+ 'bprop_concat_stats'] = tmpStatsCache[bpropFactor]
282
+
283
+ return self.stats
284
+
285
+ def compute_and_apply_stats(self, loss_sampled, var_list=None):
286
+ varlist = var_list
287
+ if varlist is None:
288
+ varlist = tf.compat.v1.trainable_variables()
289
+
290
+ stats = self.compute_stats(loss_sampled, var_list=varlist)
291
+ return self.apply_stats(stats)
292
+
293
    def compute_stats(self, loss_sampled, var_list=None):
        """Build ops computing the Kronecker-factor covariance statistics.

        For each variable, the forward (activation) and backward (gradient)
        second-moment matrices are computed from the gradients of the
        sampled loss. Returns a dict mapping each stats Variable to the
        tensor holding its freshly computed covariance.
        """
        varlist = var_list
        if varlist is None:
            varlist = tf.compat.v1.trainable_variables()

        # gradients of the sampled loss drive the backward (bprop) factors
        gs = tf.gradients(ys=loss_sampled, xs=varlist, name='gradientsSampled')
        self.gs = gs
        factors = self.getFactors(gs, varlist)
        stats = self.getStats(factors, varlist)

        updateOps = []
        statsUpdates = {}
        # avoids rebuilding the same cov op for stats shared across layers
        statsUpdates_cache = {}
        for var in varlist:
            opType = factors[var]['opName']
            fops = factors[var]['op']
            fpropFactor = factors[var]['fpropFactors_concat']
            fpropStats_vars = stats[var]['fprop_concat_stats']
            bpropFactor = factors[var]['bpropFactors_concat']
            bpropStats_vars = stats[var]['bprop_concat_stats']
            SVD_factors = {}
            for stats_var in fpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_fpropFactor = fpropFactor
                    B = (tf.shape(input=fpropFactor)[0])  # batch size
                    if opType == 'Conv2D':
                        strides = fops.get_attr("strides")
                        padding = fops.get_attr("padding")
                        convkernel_size = var.get_shape()[0:3]

                        KH = int(convkernel_size[0])
                        KW = int(convkernel_size[1])
                        C = int(convkernel_size[2])
                        flatten_size = int(KH * KW * C)

                        # spatial size of the conv output (assumes BHWC)
                        Oh = int(bpropFactor.get_shape()[1])
                        Ow = int(bpropFactor.get_shape()[2])

                        if Oh == 1 and Ow == 1 and self._channel_fac:
                            # factorization along the channels
                            # assume independence among input channels
                            # factor = B x 1 x 1 x (KH xKW x C)
                            # patches = B x Oh x Ow x (KH xKW x C)
                            if len(SVD_factors) == 0:
                                if KFAC_DEBUG:
                                    print(('approx %s act factor with rank-1 SVD factors' % (var.name)))
                                # find closest rank-1 approx to the feature map
                                S, U, V = tf.batch_svd(tf.reshape(
                                    fpropFactor, [-1, KH * KW, C]))
                                # get rank-1 approx slides
                                sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
                                patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                                full_factor_shape = fpropFactor.get_shape()
                                patches_k.set_shape(
                                    [full_factor_shape[0], KH * KW])
                                patches_c = V[:, :, 0] * sqrtS1  # B x C
                                patches_c.set_shape([full_factor_shape[0], C])
                                SVD_factors[C] = patches_c
                                SVD_factors[KH * KW] = patches_k
                            # pick the SVD factor matching this stats var's dim
                            fpropFactor = SVD_factors[stats_var_dim]
                        else:
                            # poor mem usage implementation
                            patches = tf.image.extract_patches(fpropFactor, sizes=[1, convkernel_size[
                                0], convkernel_size[1], 1], strides=strides, rates=[1, 1, 1, 1], padding=padding)

                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 act fisher for %s' % (var.name)))
                                # T^2 terms * 1/T^2, size: B x C
                                fpropFactor = tf.reduce_mean(input_tensor=patches, axis=[1, 2])
                            else:
                                # size: (B x Oh x Ow) x C
                                fpropFactor = tf.reshape(
                                    patches, [-1, flatten_size]) / Oh / Ow
                    fpropFactor_size = int(fpropFactor.get_shape()[-1])
                    if stats_var_dim == (fpropFactor_size + 1) and not self._blockdiag_bias:
                        if opType == 'Conv2D' and not self._approxT2:
                            # correct padding for numerical stability (we
                            # divided out OhxOw from activations for T1 approx)
                            fpropFactor = tf.concat([fpropFactor, tf.ones(
                                [tf.shape(input=fpropFactor)[0], 1]) / Oh / Ow], 1)
                        else:
                            # use homogeneous coordinates
                            fpropFactor = tf.concat(
                                [fpropFactor, tf.ones([tf.shape(input=fpropFactor)[0], 1])], 1)

                    # average over the number of data points in a batch
                    # divided by B
                    cov = tf.matmul(fpropFactor, fpropFactor,
                                    transpose_a=True) / tf.cast(B, tf.float32)
                    updateOps.append(cov)
                    statsUpdates[stats_var] = cov
                    if opType != 'Conv2D':
                        # HACK: for convolution we recompute fprop stats for
                        # every layer including forking layers
                        statsUpdates_cache[stats_var] = cov

            for stats_var in bpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_bpropFactor = bpropFactor
                    bpropFactor_shape = bpropFactor.get_shape()
                    B = tf.shape(input=bpropFactor)[0]  # batch size
                    C = int(bpropFactor_shape[-1])  # num channels
                    if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                        if fpropFactor is not None:
                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 grad fisher for %s' % (var.name)))
                                bpropFactor = tf.reduce_sum(
                                    input_tensor=bpropFactor, axis=[1, 2])  # T^2 terms * 1/T^2
                            else:
                                bpropFactor = tf.reshape(
                                    bpropFactor, [-1, C]) * Oh * Ow  # T * 1/T terms
                        else:
                            # just doing block diag approx. spatial independent
                            # structure does not apply here. summing over
                            # spatial locations
                            if KFAC_DEBUG:
                                print(('block diag approx fisher for %s' % (var.name)))
                            bpropFactor = tf.reduce_sum(input_tensor=bpropFactor, axis=[1, 2])

                    # assume sampled loss is averaged. TO-DO:figure out better
                    # way to handle this
                    bpropFactor *= tf.cast(B, dtype=tf.float32)
                    ##

                    cov_b = tf.matmul(
                        bpropFactor, bpropFactor, transpose_a=True) / tf.cast(tf.shape(input=bpropFactor)[0], dtype=tf.float32)

                    updateOps.append(cov_b)
                    statsUpdates[stats_var] = cov_b
                    statsUpdates_cache[stats_var] = cov_b

        if KFAC_DEBUG:
            aKey = list(statsUpdates.keys())[0]
            statsUpdates[aKey] = tf.compat.v1.Print(statsUpdates[aKey],
                                                    [tf.convert_to_tensor(value='step:'),
                                                     self.global_step,
                                                     tf.convert_to_tensor(
                                                         value='computing stats'),
                                                     ])
        self.statsUpdates = statsUpdates
        return statsUpdates
439
+
440
    def apply_stats(self, statsUpdates):
        """Build the op that folds `statsUpdates` into the running-average
        stats Variables.

        During the warm-up phase the updates are accumulated (superbatch
        average); afterwards an exponential running average is used. With
        `self._async_stats` the update runs on a QueueRunner thread instead.
        Caches the resulting op in `self._update_stats_op`.
        """

        def updateAccumStats():
            # warm-up path: accumulate 1/stats_accum_iter of each update
            if self._full_stats_init:
                return tf.cond(pred=tf.greater(self.sgd_step, self._cold_iter), true_fn=lambda: tf.group(*self._apply_stats(statsUpdates, accumulate=True, accumulateCoeff=1. / self._stats_accum_iter)), false_fn=tf.no_op)
            else:
                return tf.group(*self._apply_stats(statsUpdates, accumulate=True, accumulateCoeff=1. / self._stats_accum_iter))

        def updateRunningAvgStats(statsUpdates, fac_iter=1):
            # return tf.cond(tf.greater_equal(self.factor_step,
            # tf.convert_to_tensor(fac_iter)), lambda:
            # tf.group(*self._apply_stats(stats_list, varlist)), tf.no_op)
            return tf.group(*self._apply_stats(statsUpdates))

        if self._async_stats:
            # asynchronous stats update
            update_stats = self._apply_stats(statsUpdates)

            queue = tf.queue.FIFOQueue(1, [item.dtype for item in update_stats], shapes=[
                item.get_shape() for item in update_stats])
            enqueue_op = queue.enqueue(update_stats)

            def dequeue_stats_op():
                return queue.dequeue()
            self.qr_stats = tf.compat.v1.train.QueueRunner(queue, [enqueue_op])
            # only consume from the queue when a fresh update is available
            update_stats_op = tf.cond(pred=tf.equal(queue.size(), tf.convert_to_tensor(
                value=0)), true_fn=tf.no_op, false_fn=lambda: tf.group(*[dequeue_stats_op(), ]))
        else:
            # synchronous stats update
            update_stats_op = tf.cond(pred=tf.greater_equal(
                self.stats_step, self._stats_accum_iter), true_fn=lambda: updateRunningAvgStats(statsUpdates), false_fn=updateAccumStats)
        self._update_stats_op = update_stats_op
        return update_stats_op
475
+
476
    def _apply_stats(self, statsUpdates, accumulate=False, accumulateCoeff=0.):
        """Emit the assign ops that merge `statsUpdates` into the stats
        Variables, then bump `self.stats_step`.

        With `accumulate=True` each update is added scaled by
        `accumulateCoeff`; otherwise an exponential moving average with
        decay `self._stats_decay` is applied. Returns a one-element list
        containing the stats-step increment op (which depends on all the
        assigns).
        """
        updateOps = []
        # obtain the stats var list
        for stats_var in statsUpdates:
            stats_new = statsUpdates[stats_var]
            if accumulate:
                # simple superbatch averaging
                update_op = tf.compat.v1.assign_add(
                    stats_var, accumulateCoeff * stats_new, use_locking=True)
            else:
                # exponential running averaging
                update_op = tf.compat.v1.assign(
                    stats_var, stats_var * self._stats_decay, use_locking=True)
                # chained on the decayed value returned by the assign above
                update_op = tf.compat.v1.assign_add(
                    update_op, (1. - self._stats_decay) * stats_new, use_locking=True)
            updateOps.append(update_op)

        with tf.control_dependencies(updateOps):
            stats_step_op = tf.compat.v1.assign_add(self.stats_step, 1)

        if KFAC_DEBUG:
            stats_step_op = (tf.compat.v1.Print(stats_step_op,
                                                [tf.convert_to_tensor(value='step:'),
                                                 self.global_step,
                                                 tf.convert_to_tensor(value='fac step:'),
                                                 self.factor_step,
                                                 tf.convert_to_tensor(value='sgd step:'),
                                                 self.sgd_step,
                                                 tf.convert_to_tensor(value='Accum:'),
                                                 tf.convert_to_tensor(value=accumulate),
                                                 tf.convert_to_tensor(value='Accum coeff:'),
                                                 tf.convert_to_tensor(value=accumulateCoeff),
                                                 tf.convert_to_tensor(value='stat step:'),
                                                 self.stats_step, updateOps[0], updateOps[1]]))
        return [stats_step_op, ]
511
+
512
    def getStatsEigen(self, stats=None):
        """Lazily create the Variables that hold each stats matrix's eigen
        decomposition (`e` = eigenvalues init to ones, `Q` = eigenvectors
        init to identity), on CPU.

        Shared stats matrices share one eigen pair via a temporary cache.
        Returns the dict mapping stats Variable -> {'e': ..., 'Q': ...}.
        """
        if len(self.stats_eigen) == 0:
            stats_eigen = {}
            if stats is None:
                stats = self.stats

            tmpEigenCache = {}
            # eigen decomposition is done on CPU; keep its buffers there too
            with tf.device('/cpu:0'):
                for var in stats:
                    for key in ['fprop_concat_stats', 'bprop_concat_stats']:
                        for stats_var in stats[var][key]:
                            if stats_var not in tmpEigenCache:
                                stats_dim = stats_var.get_shape()[1].value
                                e = tf.Variable(tf.ones(
                                    [stats_dim]), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/e', trainable=False)
                                Q = tf.Variable(tf.linalg.tensor_diag(tf.ones(
                                    [stats_dim])), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/Q', trainable=False)
                                stats_eigen[stats_var] = {'e': e, 'Q': Q}
                                tmpEigenCache[
                                    stats_var] = stats_eigen[stats_var]
                            else:
                                stats_eigen[stats_var] = tmpEigenCache[
                                    stats_var]
            self.stats_eigen = stats_eigen
        return self.stats_eigen
537
+
538
    def computeStatsEigen(self):
        """ compute the eigen decomp using copied var stats to avoid concurrent read/write from other queue """
        # TO-DO: figure out why this op has delays (possibly moving
        # eigenvectors around?)
        with tf.device('/cpu:0'):
            def removeNone(tensor_list):
                # drop None entries; kept for the (commented-out) copied-stats path
                local_list = []
                for item in tensor_list:
                    if item is not None:
                        local_list.append(item)
                return local_list

            def copyStats(var_list):
                # snapshot stats (optionally in float64) so eigh reads a
                # consistent copy; currently unused (see commented line below)
                print("copying stats to buffer tensors before eigen decomp")
                redundant_stats = {}
                copied_list = []
                for item in var_list:
                    if item is not None:
                        if item not in redundant_stats:
                            if self._use_float64:
                                redundant_stats[item] = tf.cast(
                                    tf.identity(item), tf.float64)
                            else:
                                redundant_stats[item] = tf.identity(item)
                        copied_list.append(redundant_stats[item])
                    else:
                        copied_list.append(None)
                return copied_list
            #stats = [copyStats(self.fStats), copyStats(self.bStats)]
            #stats = [self.fStats, self.bStats]

            stats_eigen = self.stats_eigen
            computedEigen = {}
            eigen_reverse_lookup = {}
            updateOps = []
            # sync copied stats
            # with tf.control_dependencies(removeNone(stats[0]) +
            # removeNone(stats[1])):
            with tf.control_dependencies([]):
                for stats_var in stats_eigen:
                    if stats_var not in computedEigen:
                        # eigh returns (eigenvalues, eigenvectors)
                        eigens = tf.linalg.eigh(stats_var)
                        e = eigens[0]
                        Q = eigens[1]
                        if self._use_float64:
                            e = tf.cast(e, tf.float32)
                            Q = tf.cast(Q, tf.float32)
                        updateOps.append(e)
                        updateOps.append(Q)
                        computedEigen[stats_var] = {'e': e, 'Q': Q}
                        # maps computed tensors back to their storage Variables
                        eigen_reverse_lookup[e] = stats_eigen[stats_var]['e']
                        eigen_reverse_lookup[Q] = stats_eigen[stats_var]['Q']

                self.eigen_reverse_lookup = eigen_reverse_lookup
                self.eigen_update_list = updateOps

                if KFAC_DEBUG:
                    self.eigen_update_list = [item for item in updateOps]
                    with tf.control_dependencies(updateOps):
                        updateOps.append(tf.compat.v1.Print(tf.constant(
                            0.), [tf.convert_to_tensor(value='computed factor eigen')]))

        return updateOps
601
+
602
+ def applyStatsEigen(self, eigen_list):
603
+ updateOps = []
604
+ print(('updating %d eigenvalue/vectors' % len(eigen_list)))
605
+ for i, (tensor, mark) in enumerate(zip(eigen_list, self.eigen_update_list)):
606
+ stats_eigen_var = self.eigen_reverse_lookup[mark]
607
+ updateOps.append(
608
+ tf.compat.v1.assign(stats_eigen_var, tensor, use_locking=True))
609
+
610
+ with tf.control_dependencies(updateOps):
611
+ factor_step_op = tf.compat.v1.assign_add(self.factor_step, 1)
612
+ updateOps.append(factor_step_op)
613
+ if KFAC_DEBUG:
614
+ updateOps.append(tf.compat.v1.Print(tf.constant(
615
+ 0.), [tf.convert_to_tensor(value='updated kfac factors')]))
616
+ return updateOps
617
+
618
    def getKfacPrecondUpdates(self, gradlist, varlist):
        """Precondition `gradlist` with the KFAC approximate Fisher.

        Each gradient is reshaped to 2-D (or a 3-D tensor under channel
        factorization), projected into the Kronecker eigenbasis, whitened by
        the damped eigenvalues, projected back, and finally rescaled by a
        trust-region (vFv) clipping factor. Requires stats, stats_eigen and
        factors to be populated. Returns the preconditioned gradients in the
        same order as `varlist`.
        """
        updatelist = []
        vg = 0.

        assert len(self.stats) > 0
        assert len(self.stats_eigen) > 0
        assert len(self.factors) > 0
        counter = 0

        grad_dict = {var: grad for grad, var in zip(gradlist, varlist)}

        for grad, var in zip(gradlist, varlist):
            GRAD_RESHAPE = False
            GRAD_TRANSPOSE = False

            fpropFactoredFishers = self.stats[var]['fprop_concat_stats']
            bpropFactoredFishers = self.stats[var]['bprop_concat_stats']

            if (len(fpropFactoredFishers) + len(bpropFactoredFishers)) > 0:
                counter += 1
                GRAD_SHAPE = grad.get_shape()
                if len(grad.get_shape()) > 2:
                    # reshape conv kernel parameters
                    KW = int(grad.get_shape()[0])
                    KH = int(grad.get_shape()[1])
                    C = int(grad.get_shape()[2])
                    D = int(grad.get_shape()[3])

                    if len(fpropFactoredFishers) > 1 and self._channel_fac:
                        # reshape conv kernel parameters into tensor
                        grad = tf.reshape(grad, [KW * KH, C, D])
                    else:
                        # reshape conv kernel parameters into 2D grad
                        grad = tf.reshape(grad, [-1, D])
                    GRAD_RESHAPE = True
                elif len(grad.get_shape()) == 1:
                    # reshape bias or 1D parameters
                    D = int(grad.get_shape()[0])

                    grad = tf.expand_dims(grad, 0)
                    GRAD_RESHAPE = True
                else:
                    # 2D parameters
                    C = int(grad.get_shape()[0])
                    D = int(grad.get_shape()[1])

                if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
                    # use homogeneous coordinates only works for 2D grad.
                    # TO-DO: figure out how to factorize bias grad
                    # stack bias grad
                    var_assnBias = self.stats[var]['assnBias']
                    grad = tf.concat(
                        [grad, tf.expand_dims(grad_dict[var_assnBias], 0)], 0)

                # project gradient to eigen space and reshape the eigenvalues
                # for broadcasting
                eigVals = []

                for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    e = detectMinVal(self.stats_eigen[stats][
                                     'e'], var, name='act', debug=KFAC_DEBUG)

                    Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='act')
                    eigVals.append(e)
                    grad = gmatmul(Q, grad, transpose_a=True, reduce_dim=idx)

                for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    e = detectMinVal(self.stats_eigen[stats][
                                     'e'], var, name='grad', debug=KFAC_DEBUG)

                    Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='grad')
                    eigVals.append(e)
                    grad = gmatmul(grad, Q, transpose_b=False, reduce_dim=idx)
                ##

                #####
                # whiten using eigenvalues
                weightDecayCoeff = 0.
                if var in self._weight_decay_dict:
                    weightDecayCoeff = self._weight_decay_dict[var]
                    if KFAC_DEBUG:
                        print(('weight decay coeff for %s is %f' % (var.name, weightDecayCoeff)))

                if self._factored_damping:
                    if KFAC_DEBUG:
                        print(('use factored damping for %s' % (var.name)))
                    coeffs = 1.
                    num_factors = len(eigVals)
                    # compute the ratio of two trace norm of the left and right
                    # KFac matrices, and their generalization
                    if len(eigVals) == 1:
                        damping = self._epsilon + weightDecayCoeff
                    else:
                        damping = tf.pow(
                            self._epsilon + weightDecayCoeff, 1. / num_factors)
                    eigVals_tnorm_avg = [tf.reduce_mean(
                        input_tensor=tf.abs(e)) for e in eigVals]
                    for e, e_tnorm in zip(eigVals, eigVals_tnorm_avg):
                        eig_tnorm_negList = [
                            item for item in eigVals_tnorm_avg if item != e_tnorm]
                        if len(eigVals) == 1:
                            adjustment = 1.
                        elif len(eigVals) == 2:
                            adjustment = tf.sqrt(
                                e_tnorm / eig_tnorm_negList[0])
                        else:
                            eig_tnorm_negList_prod = reduce(
                                lambda x, y: x * y, eig_tnorm_negList)
                            adjustment = tf.pow(
                                tf.pow(e_tnorm, num_factors - 1.) / eig_tnorm_negList_prod, 1. / num_factors)
                        coeffs *= (e + adjustment * damping)
                else:
                    # plain (non-factored) damping: damp the eigenvalue product
                    coeffs = 1.
                    damping = (self._epsilon + weightDecayCoeff)
                    for e in eigVals:
                        coeffs *= e
                    coeffs += damping

                #grad = tf.Print(grad, [tf.convert_to_tensor('1'), tf.convert_to_tensor(var.name), grad.get_shape()])

                grad /= coeffs

                #grad = tf.Print(grad, [tf.convert_to_tensor('2'), tf.convert_to_tensor(var.name), grad.get_shape()])
                #####
                # project gradient back to euclidean space
                for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    grad = gmatmul(Q, grad, transpose_a=False, reduce_dim=idx)

                for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    grad = gmatmul(grad, Q, transpose_b=True, reduce_dim=idx)
                ##

                #grad = tf.Print(grad, [tf.convert_to_tensor('3'), tf.convert_to_tensor(var.name), grad.get_shape()])
                if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
                    # use homogeneous coordinates only works for 2D grad.
                    # TO-DO: figure out how to factorize bias grad
                    # un-stack bias grad
                    var_assnBias = self.stats[var]['assnBias']
                    C_plus_one = int(grad.get_shape()[0])
                    grad_assnBias = tf.reshape(tf.slice(grad,
                                                        begin=[
                                                            C_plus_one - 1, 0],
                                                        size=[1, -1]), var_assnBias.get_shape())
                    grad_assnWeights = tf.slice(grad,
                                                begin=[0, 0],
                                                size=[C_plus_one - 1, -1])
                    grad_dict[var_assnBias] = grad_assnBias
                    grad = grad_assnWeights

                #grad = tf.Print(grad, [tf.convert_to_tensor('4'), tf.convert_to_tensor(var.name), grad.get_shape()])
                if GRAD_RESHAPE:
                    grad = tf.reshape(grad, GRAD_SHAPE)

                grad_dict[var] = grad

        print(('projecting %d gradient matrices' % counter))

        for g, var in zip(gradlist, varlist):
            grad = grad_dict[var]
            ### clipping ###
            if KFAC_DEBUG:
                print(('apply clipping to %s' % (var.name)))
                tf.compat.v1.Print(grad, [tf.sqrt(tf.reduce_sum(input_tensor=tf.pow(grad, 2)))], "Euclidean norm of new grad")
            # vFv accumulates the (lr-scaled) inner product of raw and
            # preconditioned gradients for trust-region scaling below
            local_vg = tf.reduce_sum(input_tensor=grad * g * (self._lr * self._lr))
            vg += local_vg

        # recale everything
        if KFAC_DEBUG:
            print('apply vFv clipping')

        scaling = tf.minimum(1., tf.sqrt(self._clip_kl / vg))
        if KFAC_DEBUG:
            scaling = tf.compat.v1.Print(scaling, [tf.convert_to_tensor(
                value='clip: '), scaling, tf.convert_to_tensor(value=' vFv: '), vg])
        with tf.control_dependencies([tf.compat.v1.assign(self.vFv, vg)]):
            updatelist = [grad_dict[var] for var in varlist]
            for i, item in enumerate(updatelist):
                updatelist[i] = scaling * item

        return updatelist
802
+
803
+ def compute_gradients(self, loss, var_list=None):
804
+ varlist = var_list
805
+ if varlist is None:
806
+ varlist = tf.compat.v1.trainable_variables()
807
+ g = tf.gradients(ys=loss, xs=varlist)
808
+
809
+ return [(a, b) for a, b in zip(g, varlist)]
810
+
811
    def apply_gradients_kfac(self, grads):
        """Build the KFAC training op for (grad, var) pairs `grads`.

        Wires together: stats update -> (sync or async) eigen decomposition
        -> preconditioned gradient computation -> Momentum optimizer step,
        each gated by step counters. Returns (train_op, queue_runner) where
        the queue runner is None unless async eigen decomposition is on.
        Requires compute_and_apply_stats() to have run (sets
        self._update_stats_op).
        """
        g, varlist = list(zip(*grads))

        if len(self.stats_eigen) == 0:
            self.getStatsEigen()

        qr = None
        # launch eigen-decomp on a queue thread
        if self._async:
            print('Use async eigen decomp')
            # get a list of factor loading tensors
            factorOps_dummy = self.computeStatsEigen()

            # define a queue for the list of factor loading tensors
            queue = tf.queue.FIFOQueue(1, [item.dtype for item in factorOps_dummy], shapes=[
                item.get_shape() for item in factorOps_dummy])
            # enqueue only every _kfac_update steps, after stats accumulation
            enqueue_op = tf.cond(pred=tf.logical_and(tf.equal(tf.math.floormod(self.stats_step, self._kfac_update), tf.convert_to_tensor(
                value=0)), tf.greater_equal(self.stats_step, self._stats_accum_iter)), true_fn=lambda: queue.enqueue(self.computeStatsEigen()), false_fn=tf.no_op)

            def dequeue_op():
                return queue.dequeue()

            qr = tf.compat.v1.train.QueueRunner(queue, [enqueue_op])

        updateOps = []
        global_step_op = tf.compat.v1.assign_add(self.global_step, 1)
        updateOps.append(global_step_op)

        with tf.control_dependencies([global_step_op]):

            # compute updates
            assert self._update_stats_op != None
            updateOps.append(self._update_stats_op)
            dependency_list = []
            if not self._async:
                dependency_list.append(self._update_stats_op)

            with tf.control_dependencies(dependency_list):
                def no_op_wrapper():
                    return tf.group(*[tf.compat.v1.assign_add(self.cold_step, 1)])

                if not self._async:
                    # synchronous eigen-decomp updates
                    updateFactorOps = tf.cond(pred=tf.logical_and(tf.equal(tf.math.floormod(self.stats_step, self._kfac_update),
                                                                           tf.convert_to_tensor(value=0)),
                                                                  tf.greater_equal(self.stats_step, self._stats_accum_iter)), true_fn=lambda: tf.group(*self.applyStatsEigen(self.computeStatsEigen())), false_fn=no_op_wrapper)
                else:
                    # asynchronous eigen-decomp updates using queue
                    updateFactorOps = tf.cond(pred=tf.greater_equal(self.stats_step, self._stats_accum_iter),
                                              true_fn=lambda: tf.cond(pred=tf.equal(queue.size(), tf.convert_to_tensor(value=0)),
                                                                      true_fn=tf.no_op,

                                                                      false_fn=lambda: tf.group(
                                                                          *self.applyStatsEigen(dequeue_op())),
                                                                      ),
                                              false_fn=no_op_wrapper)

                updateOps.append(updateFactorOps)

                with tf.control_dependencies([updateFactorOps]):
                    def gradOp():
                        # before the first factor update: raw gradients
                        return list(g)

                    def getKfacGradOp():
                        return self.getKfacPrecondUpdates(g, varlist)
                    u = tf.cond(pred=tf.greater(self.factor_step,
                                                tf.convert_to_tensor(value=0)), true_fn=getKfacGradOp, false_fn=gradOp)

                    # lr is pre-scaled by (1 - momentum) so the effective
                    # step size matches the momentum accumulation
                    optim = tf.compat.v1.train.MomentumOptimizer(
                        self._lr * (1. - self._momentum), self._momentum)
                    #optim = tf.train.AdamOptimizer(self._lr, epsilon=0.01)

                    def optimOp():
                        def updateOptimOp():
                            if self._full_stats_init:
                                return tf.cond(pred=tf.greater(self.factor_step, tf.convert_to_tensor(value=0)), true_fn=lambda: optim.apply_gradients(list(zip(u, varlist))), false_fn=tf.no_op)
                            else:
                                return optim.apply_gradients(list(zip(u, varlist)))
                        if self._full_stats_init:
                            return tf.cond(pred=tf.greater_equal(self.stats_step, self._stats_accum_iter), true_fn=updateOptimOp, false_fn=tf.no_op)
                        else:
                            return tf.cond(pred=tf.greater_equal(self.sgd_step, self._cold_iter), true_fn=updateOptimOp, false_fn=tf.no_op)
                    updateOps.append(optimOp())

        return tf.group(*updateOps), qr
896
+
897
    def apply_gradients(self, grads):
        """Top-level training op: cold SGD (momentum, optionally gradient-
        clipped) for the first `_cold_iter` steps, then the KFAC update.

        Returns (train_op, queue_runner); the queue runner comes from
        apply_gradients_kfac and is None unless async mode is on.
        """
        coldOptim = tf.compat.v1.train.MomentumOptimizer(
            self._cold_lr, self._momentum)

        def coldSGDstart():
            sgd_grads, sgd_var = zip(*grads)

            if self.max_grad_norm != None:
                sgd_grads, sgd_grad_norm = tf.clip_by_global_norm(sgd_grads,self.max_grad_norm)

            sgd_grads = list(zip(sgd_grads,sgd_var))

            sgd_step_op = tf.compat.v1.assign_add(self.sgd_step, 1)
            coldOptim_op = coldOptim.apply_gradients(sgd_grads)
            if KFAC_DEBUG:
                with tf.control_dependencies([sgd_step_op, coldOptim_op]):
                    sgd_step_op = tf.compat.v1.Print(
                        sgd_step_op, [self.sgd_step, tf.convert_to_tensor(value='doing cold sgd step')])
            return tf.group(*[sgd_step_op, coldOptim_op])

        kfacOptim_op, qr = self.apply_gradients_kfac(grads)

        def warmKFACstart():
            return kfacOptim_op

        # switch from cold SGD to KFAC once sgd_step exceeds the cold period
        return tf.cond(pred=tf.greater(self.sgd_step, self._cold_iter), true_fn=warmKFACstart, false_fn=coldSGDstart), qr
923
+
924
+ def minimize(self, loss, loss_sampled, var_list=None):
925
+ grads = self.compute_gradients(loss, var_list=var_list)
926
+ update_stats_op = self.compute_and_apply_stats(
927
+ loss_sampled, var_list=var_list)
928
+ return self.apply_gradients(grads)
baselines/acktr/kfac_utils.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+
3
def gmatmul(a, b, transpose_a=False, transpose_b=False, reduce_dim=None):
    """Generalized matmul contracting a 2-D matrix with an N-D tensor.

    Three supported cases: (2-D a, N-D b) contracts along b's axis
    `reduce_dim`; (N-D a, 2-D b) contracts along a's axis counted
    `reduce_dim` from the last; (2-D, 2-D) is a plain matmul. Any other
    rank combination is an error. `reduce_dim` is mandatory.
    """
    assert reduce_dim is not None

    # weird batch matmul
    if len(a.get_shape()) == 2 and len(b.get_shape()) > 2:
        # reshape reduce_dim to the left most dim in b
        b_shape = b.get_shape()
        if reduce_dim != 0:
            b_dims = list(range(len(b_shape)))
            b_dims.remove(reduce_dim)
            b_dims.insert(0, reduce_dim)
            b = tf.transpose(a=b, perm=b_dims)
        b_t_shape = b.get_shape()
        # flatten all remaining axes so an ordinary matmul applies
        b = tf.reshape(b, [int(b_shape[reduce_dim]), -1])
        result = tf.matmul(a, b, transpose_a=transpose_a,
                           transpose_b=transpose_b)
        result = tf.reshape(result, b_t_shape)
        if reduce_dim != 0:
            # undo the permutation applied above
            b_dims = list(range(len(b_shape)))
            b_dims.remove(0)
            b_dims.insert(reduce_dim, 0)
            result = tf.transpose(a=result, perm=b_dims)
        return result

    elif len(a.get_shape()) > 2 and len(b.get_shape()) == 2:
        # reshape reduce_dim to the right most dim in a
        a_shape = a.get_shape()
        outter_dim = len(a_shape) - 1
        # NOTE: reduce_dim is counted from the last axis in this branch
        reduce_dim = len(a_shape) - reduce_dim - 1
        if reduce_dim != outter_dim:
            a_dims = list(range(len(a_shape)))
            a_dims.remove(reduce_dim)
            a_dims.insert(outter_dim, reduce_dim)
            a = tf.transpose(a=a, perm=a_dims)
        a_t_shape = a.get_shape()
        a = tf.reshape(a, [-1, int(a_shape[reduce_dim])])
        result = tf.matmul(a, b, transpose_a=transpose_a,
                           transpose_b=transpose_b)
        result = tf.reshape(result, a_t_shape)
        if reduce_dim != outter_dim:
            # undo the permutation applied above
            a_dims = list(range(len(a_shape)))
            a_dims.remove(outter_dim)
            a_dims.insert(reduce_dim, outter_dim)
            result = tf.transpose(a=result, perm=a_dims)
        return result

    elif len(a.get_shape()) == 2 and len(b.get_shape()) == 2:
        return tf.matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b)

    assert False, 'something went wrong'
53
+
54
+
55
def clipoutNeg(vec, threshold=1e-6):
    """Zero out every entry of `vec` that is not strictly above `threshold`
    (used to discard negative / near-zero eigenvalues)."""
    keep = tf.cast(vec > threshold, tf.float32)
    return vec * keep
58
+
59
+
60
def detectMinVal(input_mat, var, threshold=1e-6, name='', debug=False):
    """Clip near-zero/negative eigenvalues of `input_mat` to zero via
    clipoutNeg; `var` is only used for its name in the debug printout.
    """
    eigen_min = tf.reduce_min(input_tensor=input_mat)
    eigen_max = tf.reduce_max(input_tensor=input_mat)
    eigen_ratio = eigen_max / eigen_min
    input_mat_clipped = clipoutNeg(input_mat, threshold)

    if debug:
        # NOTE(review): this prints a warning only when eigen_ratio falls in
        # [-500, 0] — i.e. when the min/max eigenvalues have opposite signs
        # within a bound. The guard looks suspicious (greater(ratio, 0.)
        # silences all positive-definite cases); verify intent upstream.
        input_mat_clipped = tf.cond(pred=tf.logical_or(tf.greater(eigen_ratio, 0.), tf.less(eigen_ratio, -500)), true_fn=lambda: input_mat_clipped, false_fn=lambda: tf.compat.v1.Print(
            input_mat_clipped, [tf.convert_to_tensor(value='screwed ratio ' + name + ' eigen values!!!'), tf.convert_to_tensor(value=var.name), eigen_min, eigen_max, eigen_ratio]))

    return input_mat_clipped
71
+
72
+
73
def factorReshape(Q, e, grad, facIndx=0, ftype='act'):
    """Reshape eigenvalue vector `e` so it broadcasts against `grad` along
    the factor's axis; `Q` is returned unchanged.

    'act' factors align with axis `facIndx` counted from the left of
    `grad`; 'grad' factors align with axis `facIndx` counted from the
    right.
    """
    grad_shape = grad.get_shape()
    rank = len(grad_shape)

    if ftype == 'act':
        axis = facIndx
        assert e.get_shape()[0] == grad_shape[axis]
        bcast_shape = [1] * rank
        bcast_shape[axis] = -1
        e = tf.reshape(e, bcast_shape)
    if ftype == 'grad':
        axis = rank - facIndx - 1
        assert e.get_shape()[0] == grad_shape[axis]
        bcast_shape = [1] * rank
        bcast_shape[axis] = -1
        e = tf.reshape(e, bcast_shape)

    return Q, e
baselines/acktr/utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+
3
def dense(x, size, name, weight_init=None, bias_init=0, weight_loss_dict=None, reuse=None):
    """Fully-connected layer: returns x @ w + b under variable scope `name`.

    Args:
        x: 2-D input tensor of shape [batch, features].
        size: number of output units.
        name: variable scope name for the layer's parameters.
        weight_init: initializer for the weight matrix `w`.
        bias_init: constant initializer value for the bias `b`.
        weight_loss_dict: if given, an L2 weight-decay loss for `w` is added
            to the '<top_scope>_losses' collection and the decay coefficients
            are recorded in the dict (0.0 for the bias).
        reuse: passed through to variable_scope for variable reuse.
    """
    with tf.compat.v1.variable_scope(name, reuse=reuse):
        # this layer is expected to live exactly one scope below the top level
        assert (len(tf.compat.v1.get_variable_scope().name.split('/')) == 2)

        w = tf.compat.v1.get_variable("w", [x.get_shape()[1], size], initializer=weight_init)
        b = tf.compat.v1.get_variable("b", [size], initializer=tf.compat.v1.constant_initializer(bias_init))
        weight_decay_fc = 3e-4

        if weight_loss_dict is not None:
            weight_decay = tf.multiply(tf.nn.l2_loss(w), weight_decay_fc, name='weight_decay_loss')
            # (fix: the original re-checked `weight_loss_dict is not None`
            # here, which was redundant inside this branch)
            weight_loss_dict[w] = weight_decay_fc
            weight_loss_dict[b] = 0.0

            tf.compat.v1.add_to_collection(tf.compat.v1.get_variable_scope().name.split('/')[0] + '_' + 'losses', weight_decay)

        return tf.nn.bias_add(tf.matmul(x, w), b)
20
+
21
def kl_div(action_dist1, action_dist2, action_size):
    """KL divergence between two diagonal Gaussian action distributions.

    Each distribution is packed as [means | stds] along the last axis, with
    `action_size` marking the split point. Returns the KL summed over the
    action dimensions.
    """
    mean1 = action_dist1[:, :action_size]
    std1 = action_dist1[:, action_size:]
    mean2 = action_dist2[:, :action_size]
    std2 = action_dist2[:, action_size:]

    # standard closed form: (|mu1-mu2|^2 + s1^2 - s2^2) / (2 s2^2) + log(s2/s1)
    ratio_term = (tf.square(mean1 - mean2) + tf.square(std1) - tf.square(std2)) \
        / (2 * tf.square(std2) + 1e-8)
    log_term = tf.math.log(std2) - tf.math.log(std1)
    return tf.reduce_sum(input_tensor=ratio_term + log_term, axis=-1)
baselines/bench/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # flake8: noqa F403
2
+ from baselines.bench.benchmarks import *
3
+ from baselines.bench.monitor import *
baselines/bench/benchmarks.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
# directory containing this module (handy for locating bundled data files)
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# the 7 Atari games used in Mnih et al. (2013)
_atari7 = ['BeamRider', 'Breakout', 'Enduro', 'Pong', 'Qbert', 'Seaquest', 'SpaceInvaders']
# 7 Atari games emphasizing exploration
_atariexpl7 = ['Freeway', 'Gravitar', 'MontezumaRevenge', 'Pitfall', 'PrivateEye', 'Solaris', 'Venture']

# global registry, populated via register_benchmark()
_BENCHMARKS = []

# strips trailing gym version suffixes such as '-v2' from env ids
remove_version_re = re.compile(r'-v\d+$')
+
13
def register_benchmark(benchmark):
    """Add `benchmark` to the global registry.

    Raises ValueError if a benchmark with the same name already exists.
    Tasks lacking a 'desc' entry get one derived from their env id with the
    gym version suffix stripped.
    """
    new_name = benchmark['name']
    if any(existing['name'] == new_name for existing in _BENCHMARKS):
        raise ValueError('Benchmark with name %s already registered!' % new_name)

    # automatically add a description if it is not present
    if 'tasks' in benchmark:
        for task in benchmark['tasks']:
            task.setdefault('desc', remove_version_re.sub('', task.get('env_id', task.get('id'))))
    _BENCHMARKS.append(benchmark)
24
+
25
+
26
def list_benchmarks():
    """Return the names of all registered benchmarks."""
    return [bench['name'] for bench in _BENCHMARKS]
28
+
29
+
30
def get_benchmark(benchmark_name):
    """Look up a registered benchmark by name; raise ValueError if unknown."""
    match = next((bench for bench in _BENCHMARKS if bench['name'] == benchmark_name), None)
    if match is None:
        raise ValueError('%s not found! Known benchmarks: %s' % (benchmark_name, list_benchmarks()))
    return match
35
+
36
+
37
def get_task(benchmark, env_id):
    """Get a task by env_id. Return None if the benchmark doesn't have the env"""
    for task in benchmark['tasks']:
        if task['env_id'] == env_id:
            return task
    return None
40
+
41
+
42
def find_task_for_env_id_in_any_benchmark(env_id):
    """Search every registered benchmark for a task with this env_id.

    Returns (benchmark, task) for the first match, else (None, None).
    """
    matches = ((bench, task)
               for bench in _BENCHMARKS
               for task in bench["tasks"]
               if task["env_id"] == env_id)
    return next(matches, (None, None))
48
+
49
+
50
_ATARI_SUFFIX = 'NoFrameskip-v4'


def _atari_tasks(games, trials, **budget):
    """One pixel-observation task spec per game, with the given trial count and
    a compute budget passed as num_timesteps=... or num_seconds=...."""
    return [dict(desc=game, env_id=game + _ATARI_SUFFIX, trials=trials, **budget)
            for game in games]


register_benchmark({
    'name': 'Atari50M',
    'description': '7 Atari games from Mnih et al. (2013), with pixel observations, 50M timesteps',
    'tasks': _atari_tasks(_atari7, trials=2, num_timesteps=int(50e6)),
})

register_benchmark({
    'name': 'Atari10M',
    'description': '7 Atari games from Mnih et al. (2013), with pixel observations, 10M timesteps',
    'tasks': _atari_tasks(_atari7, trials=6, num_timesteps=int(10e6)),
})

register_benchmark({
    'name': 'Atari1Hr',
    'description': '7 Atari games from Mnih et al. (2013), with pixel observations, 1 hour of walltime',
    'tasks': _atari_tasks(_atari7, trials=2, num_seconds=60 * 60),
})

register_benchmark({
    'name': 'AtariExploration10M',
    'description': '7 Atari games emphasizing exploration, with pixel observations, 10M timesteps',
    'tasks': _atari_tasks(_atariexpl7, trials=2, num_timesteps=int(10e6)),
})


# MuJoCo

_mujocosmall = [
    'InvertedDoublePendulum-v2', 'InvertedPendulum-v2',
    'HalfCheetah-v2', 'Hopper-v2', 'Walker2d-v2',
    'Reacher-v2', 'Swimmer-v2']

register_benchmark({
    'name': 'Mujoco1M',
    'description': 'Some small 2D MuJoCo tasks, run for 1M timesteps',
    'tasks': [{'env_id': env_name, 'trials': 6, 'num_timesteps': int(1e6)} for env_name in _mujocosmall],
})

register_benchmark({
    'name': 'MujocoWalkers',
    'description': 'MuJoCo forward walkers, run for 8M, humanoid 100M',
    'tasks': [
        {'env_id': "Hopper-v1", 'trials': 4, 'num_timesteps': 8 * 1000000},
        {'env_id': "Walker2d-v1", 'trials': 4, 'num_timesteps': 8 * 1000000},
        {'env_id': "Humanoid-v1", 'trials': 4, 'num_timesteps': 100 * 1000000},
    ],
})

# Bullet
_bulletsmall = [
    name + 'BulletEnv-v0'
    for name in ('InvertedDoublePendulum', 'InvertedPendulum', 'HalfCheetah',
                 'Reacher', 'Walker2D', 'Hopper', 'Ant')
]

register_benchmark({
    'name': 'Bullet1M',
    'description': '6 mujoco-like tasks from bullet, 1M steps',
    'tasks': [{'env_id': env_name, 'trials': 6, 'num_timesteps': int(1e6)} for env_name in _bulletsmall],
})


# Roboschool

register_benchmark({
    'name': 'Roboschool8M',
    'description': 'Small 2D tasks, up to 30 minutes to complete on 8 cores',
    'tasks': [
        {'env_id': "RoboschoolReacher-v1", 'trials': 4, 'num_timesteps': 2 * 1000000},
        {'env_id': "RoboschoolAnt-v1", 'trials': 4, 'num_timesteps': 8 * 1000000},
        {'env_id': "RoboschoolHalfCheetah-v1", 'trials': 4, 'num_timesteps': 8 * 1000000},
        {'env_id': "RoboschoolHopper-v1", 'trials': 4, 'num_timesteps': 8 * 1000000},
        {'env_id': "RoboschoolWalker2d-v1", 'trials': 4, 'num_timesteps': 8 * 1000000},
    ],
})

register_benchmark({
    'name': 'RoboschoolHarder',
    'description': 'Test your might!!! Up to 12 hours on 32 cores',
    'tasks': [
        {'env_id': "RoboschoolHumanoid-v1", 'trials': 4, 'num_timesteps': 100 * 1000000},
        {'env_id': "RoboschoolHumanoidFlagrun-v1", 'trials': 4, 'num_timesteps': 200 * 1000000},
        {'env_id': "RoboschoolHumanoidFlagrunHarder-v1", 'trials': 4, 'num_timesteps': 400 * 1000000},
    ],
})

# Other

_atari50 = [  # actually 47
    'Alien', 'Amidar', 'Assault', 'Asterix', 'Asteroids',
    'Atlantis', 'BankHeist', 'BattleZone', 'BeamRider', 'Bowling',
    'Breakout', 'Centipede', 'ChopperCommand', 'CrazyClimber',
    'DemonAttack', 'DoubleDunk', 'Enduro', 'FishingDerby', 'Freeway',
    'Frostbite', 'Gopher', 'Gravitar', 'IceHockey', 'Jamesbond',
    'Kangaroo', 'Krull', 'KungFuMaster', 'MontezumaRevenge', 'MsPacman',
    'NameThisGame', 'Pitfall', 'Pong', 'PrivateEye', 'Qbert',
    'RoadRunner', 'Robotank', 'Seaquest', 'SpaceInvaders', 'StarGunner',
    'Tennis', 'TimePilot', 'Tutankham', 'UpNDown', 'Venture',
    'VideoPinball', 'WizardOfWor', 'Zaxxon',
]

register_benchmark({
    'name': 'Atari50_10M',
    'description': '47 Atari games from Mnih et al. (2013), with pixel observations, 10M timesteps',
    'tasks': _atari_tasks(_atari50, trials=2, num_timesteps=int(10e6)),
})

# HER DDPG

_fetch_tasks = ['FetchReach-v1', 'FetchPush-v1', 'FetchSlide-v1']

register_benchmark({
    'name': 'Fetch1M',
    'description': 'Fetch* benchmarks for 1M timesteps',
    'tasks': [{'trials': 6, 'env_id': env_name, 'num_timesteps': int(1e6)} for env_name in _fetch_tasks],
})
164
+
baselines/bench/monitor.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = ['Monitor', 'get_monitor_files', 'load_results']
2
+
3
+ from gym.core import Wrapper
4
+ import time
5
+ from glob import glob
6
+ import csv
7
+ import os.path as osp
8
+ import json
9
+
10
class Monitor(Wrapper):
    """Gym env wrapper that records per-episode reward, length and wall-clock
    time, exposes them via info['episode'], and optionally appends each episode
    to a '*.monitor.csv' log file.

    Parameters
    ----------
    env : gym.Env
        environment to wrap
    filename : str or None
        path (or directory) for the monitor csv; None disables file logging
    allow_early_resets : bool
        if False, calling reset() before the episode is done raises RuntimeError
    reset_keywords : tuple of str
        kwargs that must be passed to reset(); their values are logged per episode
    info_keywords : tuple of str
        keys copied from the final step's info dict into the episode record
    """
    EXT = "monitor.csv"
    # legacy class-level handle; kept so close() remains backward compatible
    f = None

    def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
        Wrapper.__init__(self, env=env)
        self.tstart = time.time()
        if filename:
            self.results_writer = ResultsWriter(filename,
                header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id},
                extra_keys=reset_keywords + info_keywords
            )
        else:
            self.results_writer = None
        self.reset_keywords = reset_keywords
        self.info_keywords = info_keywords
        self.allow_early_resets = allow_early_resets
        self.rewards = None
        self.needs_reset = True
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_times = []
        self.total_steps = 0
        self.current_reset_info = {} # extra info about the current episode, that was passed in during reset()

    def reset(self, **kwargs):
        """Reset the env; required reset_keywords must be present in kwargs."""
        self.reset_state()
        for k in self.reset_keywords:
            v = kwargs.get(k)
            if v is None:
                raise ValueError('Expected you to pass kwarg %s into reset'%k)
            self.current_reset_info[k] = v
        return self.env.reset(**kwargs)

    def reset_state(self):
        # guard against resetting mid-episode unless explicitly allowed
        if not self.allow_early_resets and not self.needs_reset:
            raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, wrap your env with Monitor(env, path, allow_early_resets=True)")
        self.rewards = []
        self.needs_reset = False

    def step(self, action):
        if self.needs_reset:
            raise RuntimeError("Tried to step environment that needs reset")
        ob, rew, done, info = self.env.step(action)
        self.update(ob, rew, done, info)
        return (ob, rew, done, info)

    def update(self, ob, rew, done, info):
        """Accumulate the step reward; on episode end, record + log the stats."""
        self.rewards.append(rew)
        if done:
            self.needs_reset = True
            eprew = sum(self.rewards)
            eplen = len(self.rewards)
            epinfo = {"r": round(eprew, 6), "l": eplen, "t": round(time.time() - self.tstart, 6)}
            for k in self.info_keywords:
                epinfo[k] = info[k]
            self.episode_rewards.append(eprew)
            self.episode_lengths.append(eplen)
            self.episode_times.append(time.time() - self.tstart)
            epinfo.update(self.current_reset_info)
            if self.results_writer:
                self.results_writer.write_row(epinfo)
            assert isinstance(info, dict)
            # BUGFIX: removed the dead `if isinstance(info, dict)` branch that
            # immediately followed the identical assertion above.
            info['episode'] = epinfo
        self.total_steps += 1

    def close(self):
        """Close the monitor's log file(s).

        BUGFIX: the file opened by ResultsWriter was previously never closed,
        leaking one file handle per Monitor; close it here as well.
        """
        if self.f is not None:
            self.f.close()
        if self.results_writer is not None:
            self.results_writer.f.close()

    def get_total_steps(self):
        return self.total_steps

    def get_episode_rewards(self):
        return self.episode_rewards

    def get_episode_lengths(self):
        return self.episode_lengths

    def get_episode_times(self):
        return self.episode_times
94
+
95
class LoadMonitorResultsError(Exception):
    """Raised by load_results() when a directory contains no monitor files."""
    pass
97
+
98
+
99
class ResultsWriter(object):
    """Writes episode records to a '*.monitor.csv' file: one '#'-prefixed JSON
    header line, then csv rows with fields r (reward), l (length), t (time)
    plus any extra_keys."""

    def __init__(self, filename, header='', extra_keys=()):
        self.extra_keys = extra_keys
        assert filename is not None
        # accept a directory or a bare prefix; normalize to <...>.monitor.csv
        if not filename.endswith(Monitor.EXT):
            if osp.isdir(filename):
                filename = osp.join(filename, Monitor.EXT)
            else:
                filename = filename + "." + Monitor.EXT
        self.f = open(filename, "wt")
        if isinstance(header, dict):
            header = '# {} \n'.format(json.dumps(header))
        self.f.write(header)
        self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+tuple(extra_keys))
        self.logger.writeheader()
        self.f.flush()

    def write_row(self, epinfo):
        """Append one episode record and flush so readers see it immediately."""
        if self.logger:
            self.logger.writerow(epinfo)
            self.f.flush()

    def close(self):
        """Close the underlying file.

        New method: previously there was no way to release the handle, so it
        leaked for the lifetime of the process.
        """
        self.f.close()
120
+
121
+
122
def get_monitor_files(dir):
    """Return the paths of all '*monitor.csv' files directly inside *dir*."""
    return glob(osp.join(dir, "*" + Monitor.EXT))

def load_results(dir):
    """Load every monitor file in *dir* into a single pandas DataFrame.

    Handles both the current csv format (JSON header line starting with '#')
    and the deprecated json-lines format. Rows are sorted by time and the 't'
    column is shifted so the earliest t_start across all files becomes 0.
    Raises LoadMonitorResultsError when the directory has no monitor files.
    """
    import pandas
    monitor_files = (
        glob(osp.join(dir, "*monitor.json")) +
        glob(osp.join(dir, "*monitor.csv"))) # get both csv and (old) json files
    if not monitor_files:
        raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, dir))
    dfs = []
    headers = []
    for fname in monitor_files:
        with open(fname, 'rt') as fh:
            if fname.endswith('csv'):
                firstline = fh.readline()
                if not firstline:
                    continue  # empty file: nothing to load
                assert firstline[0] == '#'
                header = json.loads(firstline[1:])
                df = pandas.read_csv(fh, index_col=None)
                headers.append(header)
            elif fname.endswith('json'): # Deprecated json format
                lines = fh.readlines()
                header = json.loads(lines[0])
                headers.append(header)
                df = pandas.DataFrame([json.loads(line) for line in lines[1:]])
            else:
                assert 0, 'unreachable'
            # shift episode times to absolute wall-clock using this file's start
            df['t'] += header['t_start']
        dfs.append(df)
    df = pandas.concat(dfs)
    df.sort_values('t', inplace=True)
    df.reset_index(inplace=True)
    df['t'] -= min(h['t_start'] for h in headers)
    df.headers = headers # HACK to preserve backwards compatibility
    return df
baselines/bench/test_monitor.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .monitor import Monitor
2
+ import gym
3
+ import json
4
+
5
def test_monitor():
    """Smoke-test Monitor: run CartPole episodes, then validate the csv's
    JSON metadata header and the logged column names."""
    import pandas
    import os
    import uuid

    env = gym.make("CartPole-v1")
    env.seed(0)
    mon_file = "/tmp/baselines-test-%s.monitor.csv" % uuid.uuid4()
    menv = Monitor(env, mon_file)
    menv.reset()
    for _ in range(1000):
        _, _, done, _ = menv.step(0)
        if done:
            menv.reset()

    # BUGFIX: use a context manager so the handle is closed even when an
    # assertion below fails (the original leaked the open file in that case).
    with open(mon_file, 'rt') as f:
        firstline = f.readline()
        assert firstline.startswith('#')
        metadata = json.loads(firstline[1:])
        assert metadata['env_id'] == "CartPole-v1"
        assert set(metadata.keys()) == {'env_id', 't_start'}, "Incorrect keys in monitor metadata"

        last_logline = pandas.read_csv(f, index_col=None)
        assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline"
    os.remove(mon_file)
baselines/common/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # flake8: noqa F403
2
+ from baselines.common.console_util import *
3
+ from baselines.common.dataset import Dataset
4
+ from baselines.common.math_util import *
5
+ from baselines.common.misc_util import *
baselines/common/atari_wrappers.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ os.environ.setdefault('PATH', '')
4
+ from collections import deque
5
+ import gym
6
+ from gym import spaces
7
+ import cv2
8
+ cv2.ocl.setUseOpenCL(False)
9
+ from .wrappers import TimeLimit
10
+
11
+
12
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self, **kwargs):
        """ Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset(**kwargs)
        noops = self.override_num_noops
        if noops is None:
            noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(self.noop_action)
            if done:
                obs = self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        return self.env.step(ac)

class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        meanings = env.unwrapped.get_action_meanings()
        assert meanings[1] == 'FIRE'
        assert len(meanings) >= 3

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        # press FIRE (action 1) then action 2, re-resetting if either ends the episode
        for fire_action in (1, 2):
            obs, _, done, _ = self.env.step(fire_action)
            if done:
                self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        return self.env.step(ac)
60
+
61
class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # make any loss of life terminal, then track the new life count
        remaining = self.env.unwrapped.ale.lives()
        if 0 < remaining < self.lives:
            # for Qbert sometimes we stay in lives == 0 condition for a few frames
            # so it's important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            done = True
        self.lives = remaining
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset(**kwargs)
        else:
            # no-op step to advance from terminal/lost life state
            obs = self.env.step(0)[0]
        self.lives = self.env.unwrapped.ale.lives()
        return obs

class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # two most recent raw frames, for pixel-wise max pooling across time
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for frame_idx in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            # stash the last two frames in slots 0 and 1 of the buffer
            if frame_idx >= self._skip - 2:
                self._obs_buffer[frame_idx - (self._skip - 2)] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame doesn't matter
        max_frame = self._obs_buffer.max(axis=0)
        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
124
+
125
class ClipRewardEnv(gym.RewardWrapper):
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None):
        """
        Warp frames to 84x84 as done in the Nature paper and later work.

        If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
        observation should be warped.
        """
        super().__init__(env)
        self._width = width
        self._height = height
        self._grayscale = grayscale
        self._key = dict_space_key
        num_colors = 1 if self._grayscale else 3

        new_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(self._height, self._width, num_colors),
            dtype=np.uint8,
        )
        if self._key is None:
            original_space = self.observation_space
            self.observation_space = new_space
        else:
            original_space = self.observation_space.spaces[self._key]
            self.observation_space.spaces[self._key] = new_space
        assert original_space.dtype == np.uint8 and len(original_space.shape) == 3

    def observation(self, obs):
        frame = obs if self._key is None else obs[self._key]

        if self._grayscale:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(
            frame, (self._width, self._height), interpolation=cv2.INTER_AREA
        )
        if self._grayscale:
            frame = np.expand_dims(frame, -1)

        if self._key is None:
            return frame
        obs = obs.copy()
        obs[self._key] = frame
        return obs
186
+
187
+
188
class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.

        Returns lazy array, which is much more memory efficient.

        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        stacked_shape = shp[:-1] + (shp[-1] * k,)
        self.observation_space = spaces.Box(low=0, high=255, shape=stacked_shape, dtype=env.observation_space.dtype)

    def reset(self):
        ob = self.env.reset()
        # seed the buffer with k copies of the first frame
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))

class ScaledFloatFrame(gym.ObservationWrapper):
    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)

    def observation(self, observation):
        # careful! This undoes the memory optimization, use
        # with smaller replay buffers only.
        return np.array(observation).astype(np.float32) / 255.0
228
+
229
class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between the observations are only stored once.
        It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
        buffers.

        This object should only be converted to numpy array before being passed to the model.

        You'd not believe how complex the previous solution was."""
        self._frames = frames
        self._out = None

    def _force(self):
        # concatenate lazily, exactly once, then drop the per-frame references
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=-1)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        stacked = self._force()
        return stacked if dtype is None else stacked.astype(dtype)

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]

    def count(self):
        stacked = self._force()
        return stacked.shape[stacked.ndim - 1]

    def frame(self, i):
        return self._force()[..., i]
265
+
266
def make_atari(env_id, max_episode_steps=None):
    """Build a raw NoFrameskip Atari env with noop resets and 4-frame max+skip."""
    env = gym.make(env_id)
    assert 'NoFrameskip' in env.spec.id
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if max_episode_steps is not None:
        env = TimeLimit(env, max_episode_steps=max_episode_steps)
    return env

def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
    """Configure environment for DeepMind-style Atari.
    """
    if episode_life:
        env = EpisodicLifeEnv(env)
    # envs that need a FIRE press to start get the reset-time fire wrapper
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if scale:
        env = ScaledFloatFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    if frame_stack:
        env = FrameStack(env, 4)
    return env
290
+
baselines/common/cg.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10):
    """
    Conjugate-gradient solve of A x = b, with A supplied implicitly via the
    matrix-vector product f_Ax. (Demmel p 312.)
    """
    x = np.zeros_like(b)
    r = b.copy()  # residual b - A x, with x starting at 0
    p = b.copy()  # current search direction
    rdotr = r.dot(r)

    fmtstr = "%10i %10.3g %10.3g"
    titlestr = "%10s %10s %10s"
    if verbose:
        print(titlestr % ("iter", "residual norm", "soln norm"))

    for i in range(cg_iters):
        if callback is not None:
            callback(x)
        if verbose:
            print(fmtstr % (i, rdotr, np.linalg.norm(x)))
        z = f_Ax(p)
        v = rdotr / p.dot(z)
        x += v * p
        r -= v * z
        newrdotr = r.dot(r)
        p = r + (newrdotr / rdotr) * p
        rdotr = newrdotr
        if rdotr < residual_tol:
            break

    if callback is not None:
        callback(x)
    if verbose:
        print(fmtstr % (i + 1, rdotr, np.linalg.norm(x)))  # pylint: disable=W0631
    return x
baselines/common/cmd_util.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helpers for scripts like run_atari.py.
3
+ """
4
+
5
+ import os
6
+ try:
7
+ from mpi4py import MPI
8
+ except ImportError:
9
+ MPI = None
10
+
11
+ import gym
12
+ from gym.wrappers import FlattenObservation, FilterObservation
13
+ from baselines import logger
14
+ from baselines.bench import Monitor
15
+ from baselines.common import set_global_seeds
16
+ from baselines.common.atari_wrappers import make_atari, wrap_deepmind
17
+ from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
18
+ from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
19
+ from baselines.common import retro_wrappers
20
+ from baselines.common.wrappers import ClipActionsWrapper
21
+
22
def make_vec_env(env_id, env_type, num_env, seed,
                 wrapper_kwargs=None,
                 env_kwargs=None,
                 start_index=0,
                 reward_scale=1.0,
                 flatten_dict_observations=True,
                 gamestate=None,
                 initializer=None,
                 force_dummy=False):
    """
    Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
    """
    wrapper_kwargs = wrapper_kwargs or {}
    env_kwargs = env_kwargs or {}
    mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    # offset the seed per MPI rank so workers do not share random streams
    seed = seed + 10000 * mpi_rank if seed is not None else None
    logger_dir = logger.get_dir()

    def make_thunk(rank, initializer=None):
        return lambda: make_env(
            env_id=env_id,
            env_type=env_type,
            mpi_rank=mpi_rank,
            subrank=rank,
            seed=seed,
            reward_scale=reward_scale,
            gamestate=gamestate,
            flatten_dict_observations=flatten_dict_observations,
            wrapper_kwargs=wrapper_kwargs,
            env_kwargs=env_kwargs,
            logger_dir=logger_dir,
            initializer=initializer
        )

    set_global_seeds(seed)
    thunk_ranks = range(start_index, start_index + num_env)
    if not force_dummy and num_env > 1:
        return SubprocVecEnv([make_thunk(rank, initializer=initializer) for rank in thunk_ranks])
    # dummy envs run in-process; the initializer is skipped for them
    return DummyVecEnv([make_thunk(rank, initializer=None) for rank in thunk_ranks])
60
+
61
+
62
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None, initializer=None):
    """Build, seed and monitor a single environment of the given type
    (atari / retro / generic gym)."""
    if initializer is not None:
        initializer(mpi_rank=mpi_rank, subrank=subrank)

    wrapper_kwargs = wrapper_kwargs or {}
    env_kwargs = env_kwargs or {}
    if ':' in env_id:
        # 'module:EnvId' form — import the module so its envs get registered
        import re
        import importlib
        module_name = re.sub(':.*', '', env_id)
        env_id = re.sub('.*:', '', env_id)
        importlib.import_module(module_name)

    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
    else:
        env = gym.make(env_id, **env_kwargs)

    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
        env = FlattenObservation(env)

    env.seed(seed + subrank if seed is not None else None)
    monitor_path = logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank))
    env = Monitor(env, monitor_path, allow_early_resets=True)

    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        wrapper_kwargs.setdefault('frame_stack', 1)
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if isinstance(env.action_space, gym.spaces.Box):
        env = ClipActionsWrapper(env)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
106
+
107
+
108
def make_mujoco_env(env_id, seed, reward_scale=1.0):
    """
    Create a wrapped, monitored gym.Env for MuJoCo.

    Each MPI rank derives its own seed (seed + 1000 * rank) so parallel
    workers do not share random streams.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    myseed = seed + 1000 * rank if seed is not None else None
    set_global_seeds(myseed)
    env = gym.make(env_id)
    logger_path = None if logger.get_dir() is None else os.path.join(logger.get_dir(), str(rank))
    env = Monitor(env, logger_path, allow_early_resets=True)
    # BUGFIX: seed the env with the rank-adjusted seed. It previously used the
    # raw `seed`, giving every MPI worker an identical environment seed even
    # though set_global_seeds was called with the per-rank value.
    env.seed(myseed)
    if reward_scale != 1.0:
        from baselines.common.retro_wrappers import RewardScaler
        env = RewardScaler(env, reward_scale)
    return env
123
+
124
def make_robotics_env(env_id, seed, rank=0):
    """
    Create a wrapped, monitored gym.Env for MuJoCo robotics tasks.
    """
    set_global_seeds(seed)
    env = gym.make(env_id)
    # expose only observation + desired_goal, flattened into a single vector
    env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
    env = Monitor(
        env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
        info_keywords=('is_success',))
    env.seed(seed)
    return env

def arg_parser():
    """
    Create an empty argparse.ArgumentParser that shows defaults in --help.
    """
    import argparse
    return argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

def atari_arg_parser():
    """Deprecated alias for common_arg_parser()."""
    print('Obsolete - use common_arg_parser instead')
    return common_arg_parser()

def mujoco_arg_parser():
    """Deprecated alias for common_arg_parser()."""
    print('Obsolete - use common_arg_parser instead')
    return common_arg_parser()
154
+
155
def common_arg_parser():
    """
    Create an argparse.ArgumentParser for run_mujoco.py.
    """
    parser = arg_parser()
    parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
    parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
    parser.add_argument('--seed', help='RNG seed', type=int, default=None)
    parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
    # float (not int) so scientific notation like 1e6 parses on the CLI.
    # BUGFIX: removed a stray trailing comma that turned this statement into a
    # useless 1-tuple expression.
    parser.add_argument('--num_timesteps', type=float, default=1e6)
    parser.add_argument('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default=None)
    parser.add_argument('--gamestate', help='game state to load (so far only used in retro games)', default=None)
    parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int)
    parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float)
    parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str)
    parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
    parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
    parser.add_argument('--log_path', help='Directory to save learning curve data.', default=None, type=str)
    parser.add_argument('--play', default=False, action='store_true')
    return parser
175
+
176
def robotics_arg_parser():
    """
    Create an argparse.ArgumentParser for run_mujoco.py.
    """
    parser = arg_parser()
    parser.add_argument('--env', help='environment ID', type=str, default='FetchReach-v0')
    parser.add_argument('--seed', help='RNG seed', type=int, default=None)
    parser.add_argument('--num-timesteps', type=int, default=int(1e6))
    return parser


def parse_unknown_args(args):
    """
    Parse arguments not consumed by arg parser into a dictionary
    """
    retval = {}
    expecting_value = False
    key = None
    for token in args:
        if token.startswith('--'):
            if '=' in token:
                pieces = token.split('=')
                key = pieces[0][2:]
                retval[key] = pieces[1]
            else:
                key = token[2:]
                expecting_value = True
        elif expecting_value:
            retval[key] = token
            expecting_value = False
    return retval
baselines/common/console_util.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ from contextlib import contextmanager
3
+ import numpy as np
4
+ import time
5
+ import shlex
6
+ import subprocess
7
+
8
+ # ================================================================
9
+ # Misc
10
+ # ================================================================
11
+
12
def fmt_row(width, row, header=False):
    """Format the cells of `row` right-aligned to `width`, joined with ' | '.

    With header=True a dashed underline of matching length is appended.
    """
    line = " | ".join(fmt_item(cell, width) for cell in row)
    if header:
        line = line + "\n" + "-" * len(line)
    return line

def fmt_item(x, l):
    """Right-align x in a field of width l.

    Floats use '%7.5f', or '%7.2e' for very small/large magnitudes;
    0-d numpy arrays are unwrapped first.
    """
    if isinstance(x, np.ndarray):
        assert x.ndim == 0
        x = x.item()
    if isinstance(x, (float, np.float32, np.float64)):
        magnitude = abs(x)
        use_exp = (magnitude < 1e-4 or magnitude > 1e+4) and magnitude > 0
        rep = ("%7.2e" % x) if use_exp else ("%7.5f" % x)
    else:
        rep = str(x)
    return " " * (l - len(rep)) + rep
29
+
30
# ANSI foreground color codes (add 10 for the background/highlight variant).
color2num = dict(
    gray=30,
    red=31,
    green=32,
    yellow=33,
    blue=34,
    magenta=35,
    cyan=36,
    white=37,
    crimson=38
)

def colorize(string, color='green', bold=False, highlight=False):
    """Wrap `string` in ANSI escape codes for color, bold and highlight."""
    num = color2num[color]
    if highlight:
        num += 10
    codes = [str(num)]
    if bold:
        codes.append('1')
    return '\x1b[%sm%s\x1b[0m' % (';'.join(codes), string)
49
+
50
def print_cmd(cmd, dry=False):
    """Echo a command (colorized); dry=True marks it as not-to-be-executed."""
    if not isinstance(cmd, str):  # str commands are assumed shell-ready (shell=True)
        cmd = ' '.join(shlex.quote(arg) for arg in cmd)
    print(colorize(('DRY: ' if dry else 'CMD: ') + cmd))
56
+
57
+
58
def get_git_commit(cwd=None):
    """Short hash of HEAD in the git repo at `cwd` (or the current directory)."""
    out = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=cwd)
    return out.decode('utf8')

def get_git_commit_message(cwd=None):
    """Full commit message of HEAD in the git repo at `cwd`."""
    out = subprocess.check_output(['git', 'show', '-s', '--format=%B', 'HEAD'], cwd=cwd)
    return out.decode('utf8')
63
+
64
def ccap(cmd, dry=False, env=None, **kwargs):
    """Print `cmd`, then run it via subprocess.check_call unless dry is set."""
    print_cmd(cmd, dry)
    if dry:
        return
    subprocess.check_call(cmd, env=env, **kwargs)
68
+
69
+
70
# Nesting depth of active timed() blocks, used only for print indentation.
MESSAGE_DEPTH = 0

@contextmanager
def timed(msg):
    """Print `msg` on entry and the elapsed seconds on exit, indenting nested uses."""
    global MESSAGE_DEPTH #pylint: disable=W0603
    print(colorize('\t' * MESSAGE_DEPTH + '=: ' + msg, color='magenta'))
    start = time.time()
    MESSAGE_DEPTH += 1
    yield
    MESSAGE_DEPTH -= 1
    elapsed = time.time() - start
    print(colorize('\t' * MESSAGE_DEPTH + "done in %.3f seconds" % elapsed, color='magenta'))
baselines/common/dataset.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
class Dataset(object):
    """In-memory dataset over a dict of equally-sized arrays, supporting
    optional row shuffling and sequential minibatch iteration."""

    def __init__(self, data_map, deterministic=False, shuffle=True):
        self.data_map = data_map
        self.deterministic = deterministic  # if True, shuffle() is a no-op
        self.enable_shuffle = shuffle       # reshuffle at epoch boundaries
        self.n = next(iter(data_map.values())).shape[0]
        self._next_id = 0
        self.shuffle()

    def shuffle(self):
        """Apply one shared random permutation to every array and rewind."""
        if self.deterministic:
            return
        perm = np.arange(self.n)
        np.random.shuffle(perm)

        # Update in place so external holders of data_map see the new order.
        for key in list(self.data_map):
            self.data_map[key] = self.data_map[key][perm]

        self._next_id = 0

    def next_batch(self, batch_size):
        """Return up to batch_size consecutive rows; reshuffle at epoch end."""
        if self._next_id >= self.n and self.enable_shuffle:
            self.shuffle()

        start = self._next_id
        count = min(batch_size, self.n - start)
        self._next_id = start + count

        return {key: arr[start:start + count] for key, arr in self.data_map.items()}

    def iterate_once(self, batch_size):
        """Yield full batches for one pass over the data, then rewind."""
        if self.enable_shuffle:
            self.shuffle()

        while self._next_id <= self.n - batch_size:
            yield self.next_batch(batch_size)
        self._next_id = 0

    def subset(self, num_elements, deterministic=True):
        """New Dataset over the first num_elements rows of every array."""
        sliced = {key: arr[:num_elements] for key, arr in self.data_map.items()}
        return Dataset(sliced, deterministic)
48
+
49
+
50
def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True):
    """Yield minibatches (tuples of aligned row-slices) over the given arrays.

    Exactly one of num_batches / batch_size must be given. With shuffle=True
    the row order is randomized once before batching.
    """
    assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both'
    arrays = tuple(map(np.asarray, arrays))
    n = arrays[0].shape[0]
    assert all(a.shape[0] == n for a in arrays[1:])
    indices = np.arange(n)
    if shuffle:
        np.random.shuffle(indices)
    if num_batches is None:
        sections = np.arange(0, n, batch_size)[1:]
    else:
        sections = num_batches
    for batch_indices in np.array_split(indices, sections):
        if include_final_partial_batch or len(batch_indices) == batch_size:
            yield tuple(a[batch_indices] for a in arrays)
baselines/common/distributions.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ import baselines.common.tf_util as U
4
+ from baselines.a2c.utils import fc
5
+ from tensorflow.python.ops import math_ops
6
+
7
class Pd(object):
    """
    A particular probability distribution

    Abstract base: subclasses implement the distribution math (mode, neglogp,
    kl, entropy, sample) on top of a flat parameter tensor.
    """
    def flatparam(self):
        # Flat tensor of distribution parameters (e.g. logits, or mean||logstd).
        raise NotImplementedError
    def mode(self):
        raise NotImplementedError
    def neglogp(self, x):
        # Usually it's easier to define the negative logprob
        raise NotImplementedError
    def kl(self, other):
        raise NotImplementedError
    def entropy(self):
        raise NotImplementedError
    def sample(self):
        raise NotImplementedError
    def logp(self, x):
        # Log-probability, derived from neglogp.
        return - self.neglogp(x)
    def get_shape(self):
        return self.flatparam().shape
    @property
    def shape(self):
        return self.get_shape()
    def __getitem__(self, idx):
        # Index into a batch of distributions by slicing the flat parameters.
        return self.__class__(self.flatparam()[idx])
33
+
34
class PdType(object):
    """
    Parametrized family of probability distributions
    """
    def pdclass(self):
        # Concrete Pd subclass this family instantiates.
        raise NotImplementedError
    def pdfromflat(self, flat):
        # Build a Pd from a flat parameter tensor.
        return self.pdclass()(flat)
    def pdfromlatent(self, latent_vector, init_scale, init_bias):
        # Build (pd, flat_params) from a latent feature tensor.
        raise NotImplementedError
    def param_shape(self):
        raise NotImplementedError
    def sample_shape(self):
        raise NotImplementedError
    def sample_dtype(self):
        raise NotImplementedError

    def param_placeholder(self, prepend_shape, name=None):
        # Placeholder for distribution parameters, with extra leading dims.
        return tf.compat.v1.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name)
    def sample_placeholder(self, prepend_shape, name=None):
        # Placeholder for samples of this distribution family.
        return tf.compat.v1.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name)

    def __eq__(self, other):
        # Families compare equal iff same type and same constructor attributes.
        return (type(self) == type(other)) and (self.__dict__ == other.__dict__)
58
+
59
class CategoricalPdType(PdType):
    """Family of categorical (softmax) distributions over `ncat` discrete actions."""
    def __init__(self, ncat):
        self.ncat = ncat  # number of categories
    def pdclass(self):
        return CategoricalPd
    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        # Project the latent to per-category logits (identity if widths match).
        pdparam = _matching_fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam

    def param_shape(self):
        return [self.ncat]
    def sample_shape(self):
        # Samples are scalar category indices.
        return []
    def sample_dtype(self):
        return tf.int32
74
+
75
+
76
class MultiCategoricalPdType(PdType):
    """Family of products of independent categoricals; nvec[i] is the size of component i."""
    def __init__(self, nvec):
        self.ncats = nvec.astype('int32')
        assert (self.ncats > 0).all()
    def pdclass(self):
        return MultiCategoricalPd
    def pdfromflat(self, flat):
        # flat holds the concatenated logits of all component categoricals.
        return MultiCategoricalPd(self.ncats, flat)

    def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
        pdparam = _matching_fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam

    def param_shape(self):
        return [sum(self.ncats)]
    def sample_shape(self):
        # One integer index per component.
        return [len(self.ncats)]
    def sample_dtype(self):
        return tf.int32
95
+
96
class DiagGaussianPdType(PdType):
    """Family of diagonal-covariance Gaussians over `size`-dim continuous actions."""
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return DiagGaussianPd

    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        mean = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
        # logstd is a state-independent trainable variable (shape [1, size]).
        logstd = tf.compat.v1.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.compat.v1.zeros_initializer())
        # 'mean * 0.0 + logstd' broadcasts logstd to the batch shape of mean.
        pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
        return self.pdfromflat(pdparam), mean

    def param_shape(self):
        # mean and logstd concatenated along the last axis.
        return [2*self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.float32
114
+
115
class BernoulliPdType(PdType):
    """Family of `size` independent Bernoulli variables (multi-binary actions)."""
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return BernoulliPd
    def param_shape(self):
        return [self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        # Samples are 0/1 per component.
        return tf.int32
    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        pdparam = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam
129
+
130
+ # WRONG SECOND DERIVATIVES
131
+ # class CategoricalPd(Pd):
132
+ # def __init__(self, logits):
133
+ # self.logits = logits
134
+ # self.ps = tf.nn.softmax(logits)
135
+ # @classmethod
136
+ # def fromflat(cls, flat):
137
+ # return cls(flat)
138
+ # def flatparam(self):
139
+ # return self.logits
140
+ # def mode(self):
141
+ # return U.argmax(self.logits, axis=-1)
142
+ # def logp(self, x):
143
+ # return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)
144
+ # def kl(self, other):
145
+ # return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \
146
+ # - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
147
+ # def entropy(self):
148
+ # return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
149
+ # def sample(self):
150
+ # u = tf.random_uniform(tf.shape(self.logits))
151
+ # return U.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
152
+
153
class CategoricalPd(Pd):
    """Categorical distribution parameterized by unnormalized logits."""
    def __init__(self, logits):
        self.logits = logits
    def flatparam(self):
        return self.logits
    def mode(self):
        return tf.argmax(input=self.logits, axis=-1)

    @property
    def mean(self):
        # Probability vector (softmax of the logits).
        return tf.nn.softmax(self.logits)
    def neglogp(self, x):
        # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
        # Note: we can't use sparse_softmax_cross_entropy_with_logits because
        # the implementation does not allow second-order derivatives...
        if x.dtype in {tf.uint8, tf.int32, tf.int64}:
            # one-hot encoding
            x_shape_list = x.shape.as_list()
            logits_shape_list = self.logits.get_shape().as_list()[:-1]
            # Only statically-known dimensions can be checked here.
            for xs, ls in zip(x_shape_list, logits_shape_list):
                if xs is not None and ls is not None:
                    assert xs == ls, 'shape mismatch: {} in x vs {} in logits'.format(xs, ls)

            x = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
        else:
            # already encoded
            assert x.shape.as_list() == self.logits.shape.as_list()

        return tf.nn.softmax_cross_entropy_with_logits(
            logits=self.logits,
            labels=x)
    def kl(self, other):
        # Subtracting the per-row max keeps the exponentials numerically stable.
        a0 = self.logits - tf.reduce_max(input_tensor=self.logits, axis=-1, keepdims=True)
        a1 = other.logits - tf.reduce_max(input_tensor=other.logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = tf.reduce_sum(input_tensor=ea0, axis=-1, keepdims=True)
        z1 = tf.reduce_sum(input_tensor=ea1, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(input_tensor=p0 * (a0 - tf.math.log(z0) - a1 + tf.math.log(z1)), axis=-1)
    def entropy(self):
        a0 = self.logits - tf.reduce_max(input_tensor=self.logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        z0 = tf.reduce_sum(input_tensor=ea0, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(input_tensor=p0 * (tf.math.log(z0) - a0), axis=-1)
    def sample(self):
        # Gumbel-max trick: argmax of logits plus Gumbel noise samples the categorical.
        u = tf.random.uniform(tf.shape(input=self.logits), dtype=self.logits.dtype)
        return tf.argmax(input=self.logits - tf.math.log(-tf.math.log(u)), axis=-1)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
205
+
206
class MultiCategoricalPd(Pd):
    """Product of independent categoricals built by splitting one flat logits tensor."""
    def __init__(self, nvec, flat):
        self.flat = flat
        # Split the concatenated logits into one CategoricalPd per component.
        self.categoricals = list(map(CategoricalPd,
            tf.split(flat, np.array(nvec, dtype=np.int32), axis=-1)))
    def flatparam(self):
        return self.flat
    def mode(self):
        return tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)
    def neglogp(self, x):
        # Components are independent, so neglogps add.
        return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x, axis=-1))])
    def kl(self, other):
        return tf.add_n([p.kl(q) for p, q in zip(self.categoricals, other.categoricals)])
    def entropy(self):
        return tf.add_n([p.entropy() for p in self.categoricals])
    def sample(self):
        return tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)
    @classmethod
    def fromflat(cls, flat):
        # Cannot reconstruct without nvec; use MultiCategoricalPdType.pdfromflat.
        raise NotImplementedError
226
+
227
class DiagGaussianPd(Pd):
    """Diagonal-covariance Gaussian; flat params are mean||logstd on the last axis."""
    def __init__(self, flat):
        self.flat = flat
        mean, logstd = tf.split(axis=len(flat.shape)-1, num_or_size_splits=2, value=flat)
        self.mean = mean
        self.logstd = logstd
        self.std = tf.exp(logstd)
    def flatparam(self):
        return self.flat
    def mode(self):
        # Mode of a Gaussian is its mean.
        return self.mean
    def neglogp(self, x):
        # Standard diagonal-Gaussian negative log-likelihood.
        return 0.5 * tf.reduce_sum(input_tensor=tf.square((x - self.mean) / self.std), axis=-1) \
               + 0.5 * np.log(2.0 * np.pi) * tf.cast(tf.shape(input=x)[-1], dtype=tf.float32) \
               + tf.reduce_sum(input_tensor=self.logstd, axis=-1)
    def kl(self, other):
        assert isinstance(other, DiagGaussianPd)
        # Closed-form KL between diagonal Gaussians, summed over dimensions.
        return tf.reduce_sum(input_tensor=other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=-1)
    def entropy(self):
        return tf.reduce_sum(input_tensor=self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1)
    def sample(self):
        # Reparameterized sample: mean + std * standard normal noise.
        return self.mean + self.std * tf.random.normal(tf.shape(input=self.mean))
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
252
+
253
+
254
class BernoulliPd(Pd):
    """Independent Bernoulli variables parameterized by per-component logits."""
    def __init__(self, logits):
        self.logits = logits
        self.ps = tf.sigmoid(logits)  # per-component success probabilities
    def flatparam(self):
        return self.logits
    @property
    def mean(self):
        return self.ps
    def mode(self):
        # Rounding the probabilities picks the most likely value per component.
        return tf.round(self.ps)
    def neglogp(self, x):
        return tf.reduce_sum(input_tensor=tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.cast(x, dtype=tf.float32)), axis=-1)
    def kl(self, other):
        # KL(p, q) = cross_entropy(p, q) - entropy(p), summed over components.
        return tf.reduce_sum(input_tensor=tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=-1) - tf.reduce_sum(input_tensor=tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
    def entropy(self):
        return tf.reduce_sum(input_tensor=tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
    def sample(self):
        # Inverse-CDF sampling: component is 1 where uniform noise < p.
        u = tf.random.uniform(tf.shape(input=self.ps))
        return tf.cast(math_ops.less(u, self.ps), dtype=tf.float32)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
277
+
278
def make_pdtype(ac_space):
    """Map a gym action space to the matching probability-distribution family."""
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        # Only flat (1-D) continuous action spaces are supported.
        assert len(ac_space.shape) == 1
        return DiagGaussianPdType(ac_space.shape[0])
    elif isinstance(ac_space, spaces.Discrete):
        return CategoricalPdType(ac_space.n)
    elif isinstance(ac_space, spaces.MultiDiscrete):
        return MultiCategoricalPdType(ac_space.nvec)
    elif isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    else:
        raise NotImplementedError
291
+
292
def shape_el(v, i):
    # Dimension i of tensor v: the static value when known at graph-construction
    # time, otherwise the dynamic tf.shape element.
    maybe = v.get_shape()[i]
    if maybe is not None:
        return maybe
    else:
        return tf.shape(input=v)[i]
298
+
299
@U.in_session
def test_probtypes():
    """Run validate_probtype's statistical checks on every distribution family."""
    np.random.seed(0)

    # flat params = mean||logstd, so size // 2 action dimensions
    pdparam_diag_gauss = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8])
    diag_gauss = DiagGaussianPdType(pdparam_diag_gauss.size // 2) #pylint: disable=E1101
    validate_probtype(diag_gauss, pdparam_diag_gauss)

    pdparam_categorical = np.array([-.2, .3, .5])
    categorical = CategoricalPdType(pdparam_categorical.size) #pylint: disable=E1101
    validate_probtype(categorical, pdparam_categorical)

    # components of sizes 1, 2 and 3; flat params are the concatenated logits
    nvec = [1,2,3]
    pdparam_multicategorical = np.array([-.2, .3, .5, .1, 1, -.1])
    multicategorical = MultiCategoricalPdType(nvec) #pylint: disable=E1101
    validate_probtype(multicategorical, pdparam_multicategorical)

    pdparam_bernoulli = np.array([-.2, .3, .5])
    bernoulli = BernoulliPdType(pdparam_bernoulli.size) #pylint: disable=E1101
    validate_probtype(bernoulli, pdparam_bernoulli)
319
+
320
+
321
def validate_probtype(probtype, pdparam):
    """Monte-Carlo sanity checks of a distribution family at parameters pdparam.

    Verifies (within 3 standard errors over N samples) that:
    1. mean negative log likelihood of samples == entropy
    2. KL[p, q] == -entropy(p) - E_p[log q]
    """
    N = 100000
    # Check to see if mean negative log likelihood == differential entropy
    Mval = np.repeat(pdparam[None, :], N, axis=0)
    M = probtype.param_placeholder([N])
    X = probtype.sample_placeholder([N])
    pd = probtype.pdfromflat(M)
    calcloglik = U.function([X, M], pd.logp(X))
    calcent = U.function([M], pd.entropy())
    Xval = tf.compat.v1.get_default_session().run(pd.sample(), feed_dict={M:Mval})
    logliks = calcloglik(Xval, Mval)
    entval_ll = - logliks.mean() #pylint: disable=E1101
    entval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
    entval = calcent(Mval).mean() #pylint: disable=E1101
    assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr # within 3 sigmas

    # Check to see if kldiv[p,q] = - ent[p] - E_p[log q]
    M2 = probtype.param_placeholder([N])
    pd2 = probtype.pdfromflat(M2)
    # q is a randomly perturbed copy of p's parameters
    q = pdparam + np.random.randn(pdparam.size) * 0.1
    Mval2 = np.repeat(q[None, :], N, axis=0)
    calckl = U.function([M, M2], pd.kl(pd2))
    klval = calckl(Mval, Mval2).mean() #pylint: disable=E1101
    logliks = calcloglik(Xval, Mval2)
    klval_ll = - entval - logliks.mean() #pylint: disable=E1101
    klval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
    assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
    print('ok on', probtype, pdparam)
349
+
350
+
351
def _matching_fc(tensor, name, size, init_scale, init_bias):
    # Pass the latent through unchanged when it already has the requested
    # width; otherwise project it with a fully-connected layer named `name`.
    if tensor.shape[-1] == size:
        return tensor
    else:
        return fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias)
baselines/common/input.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ tf.compat.v1.disable_eager_execution()
4
+ from gym.spaces import Discrete, Box, MultiDiscrete
5
+
6
def observation_placeholder(ob_space, batch_size=None, name='Ob'):
    '''
    Create placeholder to feed observations into of the size appropriate to the observation space

    Parameters:
    ----------

    ob_space: gym.Space     observation space

    batch_size: int         size of the batch to be fed into input. Can be left None in most cases.

    name: str               name of the placeholder

    Returns:
    -------

    tensorflow placeholder tensor
    '''

    assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
        'Can only deal with Discrete and Box observation spaces for now'

    dtype = ob_space.dtype
    if dtype == np.int8:
        # int8 spaces are fed as uint8 — presumably image-like pixel data; TODO confirm
        dtype = np.uint8

    return tf.compat.v1.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
33
+
34
+
35
def observation_input(ob_space, batch_size=None, name='Ob'):
    '''
    Create placeholder to feed observations into of the size appropriate to the observation space, and add input
    encoder of the appropriate type.

    Returns (placeholder, encoded_float_tensor).
    '''

    placeholder = observation_placeholder(ob_space, batch_size, name)
    return placeholder, encode_observation(ob_space, placeholder)
43
+
44
def encode_observation(ob_space, placeholder):
    '''
    Encode input in the way that is appropriate to the observation space

    Parameters:
    ----------

    ob_space: gym.Space             observation space

    placeholder: tf.placeholder     observation input placeholder
    '''
    if isinstance(ob_space, Discrete):
        # One-hot encode the scalar category index.
        return tf.cast(tf.one_hot(placeholder, ob_space.n), dtype=tf.float32)
    elif isinstance(ob_space, Box):
        # Continuous observations are just cast to float32.
        return tf.cast(placeholder, dtype=tf.float32)
    elif isinstance(ob_space, MultiDiscrete):
        # One-hot encode each component separately, then concatenate.
        placeholder = tf.cast(placeholder, tf.int32)
        one_hots = [tf.cast(tf.one_hot(placeholder[..., i], ob_space.nvec[i]), dtype=tf.float32) for i in range(placeholder.shape[-1])]
        return tf.concat(one_hots, axis=-1)
    else:
        raise NotImplementedError
65
+
baselines/common/math_util.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy.signal
3
+
4
+
5
def discount(x, gamma):
    """
    Discounted cumulative sums along the 0th dimension of x.

    inputs
    ------
    x: ndarray
    gamma: float

    outputs
    -------
    y: ndarray with same shape as x, satisfying
        y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k],
        where k = len(x) - t - 1

    Implemented as an IIR filter over the time-reversed array.
    """
    assert x.ndim >= 1
    reversed_x = x[::-1]
    acc = scipy.signal.lfilter([1], [1, -gamma], reversed_x, axis=0)
    return acc[::-1]
24
+
25
def explained_variance(ypred, y):
    """
    Fraction of the variance of y that ypred explains:
    1 - Var[y - ypred] / Var[y].

    interpretation:
        ev=0  =>  might as well have predicted zero
        ev=1  =>  perfect prediction
        ev<0  =>  worse than just predicting zero

    Returns NaN when Var[y] == 0.
    """
    assert y.ndim == 1 and ypred.ndim == 1
    vary = np.var(y)
    if vary == 0:
        return np.nan
    return 1 - np.var(y - ypred) / vary
39
+
40
def explained_variance_2d(ypred, y):
    """
    Column-wise explained variance: for each column j returns
    1 - Var[y[:, j] - ypred[:, j]] / Var[y[:, j]].

    Columns where y has (near-)zero variance yield 0.
    """
    assert y.ndim == 2 and ypred.ndim == 2
    vary = np.var(y, axis=0)
    # axis=0 computes the residual variance per column; without it the scalar
    # overall variance was incorrectly broadcast against the per-column vary.
    out = 1 - np.var(y - ypred, axis=0) / vary
    out[vary < 1e-10] = 0
    return out
46
+
47
def ncc(ypred, y):
    """Pearson correlation coefficient between ypred and y."""
    corr = np.corrcoef(ypred, y)
    return corr[1, 0]
49
+
50
def flatten_arrays(arrs):
    """Concatenate the flattened contents of all arrays into one 1-D array."""
    return np.concatenate([np.ravel(arr) for arr in arrs])
52
+
53
def unflatten_vector(vec, shapes):
    """Inverse of flatten_arrays: slice vec into consecutive arrays of the given shapes."""
    arrays = []
    offset = 0
    for shape in shapes:
        count = int(np.prod(shape))
        arrays.append(vec[offset:offset + count].reshape(shape))
        offset += count
    return arrays
62
+
63
def discount_with_boundaries(X, New, gamma):
    """
    Discounted sums over time that reset at episode boundaries.

    X: 2d array of floats, time x features
    New: 2d array of bools, indicating when a new episode has started;
         discounting does not propagate across a boundary.
    """
    T = X.shape[0]
    Y = np.zeros_like(X)
    Y[-1] = X[-1]
    for t in reversed(range(T - 1)):
        # (1 - New[t+1]) zeroes the carried-over return at episode starts.
        Y[t] = X[t] + gamma * (1 - New[t + 1]) * Y[t + 1]
    return Y
74
+
75
def test_discount_with_boundaries():
    """Check discount_with_boundaries against a hand-computed 4-step example."""
    gamma = 0.9
    x = np.array([1.0, 2.0, 3.0, 4.0], 'float32')
    starts = [1.0, 0.0, 0.0, 1.0]
    y = discount_with_boundaries(x, starts, gamma)
    expected = [
        1 + gamma * 2 + gamma**2 * 3,
        2 + gamma * 3,
        3,
        4,
    ]
    assert np.allclose(y, expected)
baselines/common/misc_util.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import numpy as np
3
+ import os
4
+ import pickle
5
+ import random
6
+ import tempfile
7
+ import zipfile
8
+
9
+
10
def zipsame(*seqs):
    """zip() that asserts all sequences have the same length."""
    first_len = len(seqs[0])
    assert all(len(seq) == first_len for seq in seqs[1:])
    return zip(*seqs)
14
+
15
+
16
class EzPickle(object):
    """Objects that are pickled and unpickled via their constructor arguments.

    Example usage:

        class Dog(Animal, EzPickle):
            def __init__(self, furcolor, tailkind="bushy"):
                Animal.__init__()
                EzPickle.__init__(furcolor, tailkind)
            ...

    On unpickling, a fresh instance is built by calling the class with the
    stored constructor arguments. This is mostly needed for environments
    wrapping C/C++ state (e.g. MuJoCo, Atari) that cannot be pickled directly.
    """

    def __init__(self, *args, **kwargs):
        self._ezpickle_args = args
        self._ezpickle_kwargs = kwargs

    def __getstate__(self):
        return {"_ezpickle_args": self._ezpickle_args,
                "_ezpickle_kwargs": self._ezpickle_kwargs}

    def __setstate__(self, d):
        # Rebuild via the constructor, then adopt the fresh instance's state.
        rebuilt = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
        self.__dict__.update(rebuilt.__dict__)
46
+
47
+
48
def set_global_seeds(i):
    """Seed tensorflow (if importable), numpy and random from i.

    Under MPI each rank derives its own seed (i + 1000 * rank) so workers do
    not generate identical streams. Passing None leaves seeding nondeterministic.
    """
    try:
        import MPI
        rank = MPI.COMM_WORLD.Get_rank()
    except ImportError:
        rank = 0

    myseed = None if i is None else i + 1000 * rank
    try:
        import tensorflow as tf
        tf.compat.v1.set_random_seed(myseed)
    except ImportError:
        pass
    np.random.seed(myseed)
    random.seed(myseed)
63
+
64
+
65
def pretty_eta(seconds_left):
    """Convert a number of seconds into a short human-readable ETA string.

    Examples:
        2 days
        2 hours and 37 minutes
        less than a minute

    Parameters
    ----------
    seconds_left: int
        Number of seconds to be converted to the ETA
    Returns
    -------
    eta: str
        String representing the pretty ETA.
    """
    minutes_left, _ = divmod(seconds_left, 60)
    hours_left, minutes_left = divmod(minutes_left, 60)
    days_left, hours_left = divmod(hours_left, 24)

    def fmt(cnt, unit):
        # pluralize only for counts > 1
        return "{} {}{}".format(str(cnt), unit, 's' if cnt > 1 else '')

    if days_left > 0:
        msg = fmt(days_left, 'day')
        if hours_left > 0:
            msg += ' and ' + fmt(hours_left, 'hour')
        return msg
    if hours_left > 0:
        msg = fmt(hours_left, 'hour')
        if minutes_left > 0:
            msg += ' and ' + fmt(minutes_left, 'minute')
        return msg
    if minutes_left > 0:
        return fmt(minutes_left, 'minute')
    return 'less than a minute'
105
+
106
+
107
class RunningAvg(object):
    """Exponentially-weighted running estimate of a scalar quantity —
    like a mean, but more sensitive to recent changes."""

    def __init__(self, gamma, init_value=None):
        """
        Parameters
        ----------
        gamma: float
            Between 0 and 1; 0 makes the estimate track the most recent value.
        init_value: float or None
            Initial estimate; if None, it is set on the first update.
        """
        self._gamma = gamma
        self._value = init_value

    def update(self, new_val):
        """Fold a newly observed value into the estimate."""
        if self._value is None:
            self._value = new_val
            return
        self._value = self._gamma * self._value + (1.0 - self._gamma) * new_val

    def __float__(self):
        """Current estimate."""
        return self._value
139
+
140
def boolean_flag(parser, name, default=False, help=None):
    """Register paired --<name> / --no-<name> boolean flags on an argparse parser.

    Parameters
    ----------
    parser: argparse.Parser
        parser to add the flags to
    name: str
        --<name> enables the flag, --no-<name> disables it; the destination
        attribute is `name` with dashes replaced by underscores
    default: bool or None
        default value of the flag
    help: str
        help string for the enabling flag
    """
    dest = name.replace('-', '_')
    parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help)
    parser.add_argument("--no-" + name, action="store_false", dest=dest)
157
+
158
+
159
def get_wrapper_by_name(env, classname):
    """Walk the wrapper chain of a possibly multiply-wrapped gym env and
    return the wrapper whose class_name() equals classname.

    Parameters
    ----------
    env: gym.Env or gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname

    Raises ValueError if no such wrapper was applied.
    """
    current = env
    while True:
        if current.class_name() == classname:
            return current
        if not isinstance(current, gym.Wrapper):
            raise ValueError("Couldn't find wrapper named %s" % classname)
        current = current.env
183
+
184
+
185
def relatively_safe_pickle_dump(obj, path, compression=False):
    """This is just like regular pickle dump, except that the failure cases
    differ:

    - It's never possible that we end up with a pickle in corrupted state.
    - If there was a different file at the path, that file will remain
      unchanged in the event of failure (provided that filesystem rename is
      atomic).
    - It is sometimes possible that we end up with a useless temp file which
      needs to be deleted manually (it will be removed automatically on the
      next function call)

    The intended use case is periodic checkpoints of experiment state, such
    that we never corrupt previous checkpoints if the current one fails.

    Parameters
    ----------
    obj: object
        object to pickle
    path: str
        path to the output file
    compression: bool
        if true the pickle is stored zip-compressed (zipfile, entry "data")
    """
    temp_storage = path + ".relatively_safe"
    if compression:
        # Using gzip here would be simpler, but the size is limited to 2GB
        with tempfile.NamedTemporaryFile() as uncompressed_file:
            pickle.dump(obj, uncompressed_file)
            uncompressed_file.file.flush()
            with zipfile.ZipFile(temp_storage, "w", compression=zipfile.ZIP_DEFLATED) as myzip:
                myzip.write(uncompressed_file.name, "data")
    else:
        with open(temp_storage, "wb") as f:
            pickle.dump(obj, f)
    # os.replace (unlike os.rename) atomically overwrites an existing
    # destination on every platform; os.rename raises on Windows when the
    # destination already exists, breaking repeated checkpointing.
    os.replace(temp_storage, path)
219
+
220
+
221
def pickle_load(path, compression=False):
    """Unpickle a possibly compressed pickle.

    Parameters
    ----------
    path: str
        path to the file to read
    compression: bool
        if true, assumes the pickle was written zip-compressed (entry "data",
        as produced by relatively_safe_pickle_dump) and decompresses it

    Returns
    -------
    obj: object
        the unpickled object
    """
    if not compression:
        with open(path, "rb") as f:
            return pickle.load(f)
    with zipfile.ZipFile(path, "r", compression=zipfile.ZIP_DEFLATED) as myzip:
        with myzip.open("data") as f:
            return pickle.load(f)
baselines/common/models.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from baselines.a2c import utils
6
+ from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch
7
+ from baselines.common.mpi_running_mean_std import RunningMeanStd
8
+ from keras import layers
9
+ from itertools import combinations
10
+
11
+ mapping = {}
12
+
13
+
14
def register(name):
    """Decorator factory: record the decorated function in the module-level
    ``mapping`` registry under ``name`` and return it unchanged."""
    def _decorator(func):
        mapping[name] = func
        return func
    return _decorator
20
+
21
+
22
def nature_cnn(unscaled_images, **conv_kwargs):
    """
    CNN from Nature paper (Mnih et al. DQN architecture): three conv layers
    followed by a 512-unit fully-connected layer, all ReLU-activated.
    """
    relu = tf.nn.relu
    x = tf.cast(unscaled_images, tf.float32) / 255.
    x = relu(conv(x, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), **conv_kwargs))
    x = relu(conv(x, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
    x = relu(conv(x, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2), **conv_kwargs))
    x = conv_to_fc(x)
    return relu(fc(x, 'fc1', nh=512, init_scale=np.sqrt(2)))
34
+
35
+
36
def build_impala_cnn(unscaled_images, depths=[16, 32, 32], **conv_kwargs):
    """
    Model used in the paper "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures" https://arxiv.org/abs/1802.01561

    Parameters
    ----------
    unscaled_images: uint8 image tensor; scaled to [0, 1] internally
    depths: output channel count for each of the conv sequences

    Returns
    -------
    a (batch, 256) float32 feature tensor
    """

    layer_num = 0

    def get_layer_num_str():
        # Unique suffix so every conv/dense layer gets a distinct variable name.
        nonlocal layer_num
        num_str = str(layer_num)
        layer_num += 1
        return num_str

    def conv_layer(out, depth):
        return tf.compat.v1.layers.conv2d(out, depth, 3, padding='same', name='layer_' + get_layer_num_str())

    def residual_block(inputs):
        # TF1 shape entries are Dimension objects (need .value); TF2 returns
        # plain ints.  Fix: catch only AttributeError instead of a bare except,
        # which previously swallowed any error raised while computing the shape.
        try:
            depth = inputs.get_shape()[-1].value
        except AttributeError:
            depth = inputs.get_shape()[-1]

        out = tf.nn.relu(inputs)

        out = conv_layer(out, depth)
        out = tf.nn.relu(out)
        out = conv_layer(out, depth)
        return out + inputs  # residual connection

    def conv_sequence(inputs, depth):
        out = conv_layer(inputs, depth)
        out = tf.compat.v1.layers.max_pooling2d(out, pool_size=3, strides=2, padding='same')
        out = residual_block(out)
        out = residual_block(out)
        return out

    out = tf.cast(unscaled_images, tf.float32) / 255.

    for depth in depths:
        out = conv_sequence(out, depth)

    out = tf.compat.v1.layers.flatten(out)
    out = tf.nn.relu(out)
    out = tf.compat.v1.layers.dense(out, 256, activation=tf.nn.relu, name='layer_' + get_layer_num_str())

    return out
83
+
84
+
85
def build_skill_impala_cnn(unscaled_images, depths=[16, 32, 32], emb_dim=256, num_embeddings=8, seed=0,
                           **conv_kwargs):
    """
    Modified impala cnn model by adding the skill module.

    Parameters
    ----------
    unscaled_images: uint8 image tensor; scaled to [0, 1] internally
    depths: output channel count for each conv sequence
    emb_dim: width of the final dense feature layer
    num_embeddings: codebook size of the vector quantizer
    seed: seed for the codebook initializer

    Returns
    -------
    (out, skill_out, pure_out, vq_out, pure_vq_out, embeddings, encoding_indices)
    where `out` is `pure_out` shifted by the normalized code indices.
    """

    layer_num = 0  # counter used to give every layer a unique name

    def get_layer_num_str():
        nonlocal layer_num
        num_str = str(layer_num)
        layer_num += 1
        return num_str

    def conv_layer(out, depth):
        return tf.compat.v1.layers.conv2d(out, depth, 3, padding='same', name='layer_' + get_layer_num_str())

    def residual_block(inputs):
        # depth = inputs.get_shape()[-1].value
        depth = inputs.get_shape()[-1]  # TF2-style: shape entries are plain ints

        out = tf.nn.relu(inputs)

        out = conv_layer(out, depth)
        out = tf.nn.relu(out)
        out = conv_layer(out, depth)
        return out + inputs  # residual connection

    def conv_sequence(inputs, depth):
        out = conv_layer(inputs, depth)
        out = tf.compat.v1.layers.max_pooling2d(out, pool_size=3, strides=2, padding='same')
        out = residual_block(out)
        out = residual_block(out)
        return out

    out = tf.cast(unscaled_images, tf.float32) / 255.

    for depth in depths:
        out = conv_sequence(out, depth)

    out = tf.compat.v1.layers.flatten(out)
    out = tf.nn.relu(out)
    pure_out = tf.compat.v1.layers.dense(out, emb_dim, activation=tf.nn.relu, name='layer_' + get_layer_num_str())

    # skill module: project features down to 2-D and quantize against a
    # learned codebook of `num_embeddings` entries.
    skill_out = tf.compat.v1.layers.dense(pure_out, emb_dim // 2, activation=None, name='layer_' + get_layer_num_str())
    skill_out = tf.compat.v1.layers.dense(skill_out, 2, activation=None, name='layer_' + get_layer_num_str())
    vq_layer = VectorQuantizer(num_embeddings, 2, seed=seed, name="vector_quantizer")
    vq_out, pure_vq_out, encoding_indices = vq_layer(skill_out)

    # Code indices normalized by the codebook size and tiled to emb_dim columns
    # (presumably to make the skill identity visible in every feature dim — confirm).
    encoding_indices_ = tf.cast(
        tf.tile(encoding_indices / vq_layer.num_embeddings, tf.constant([1, emb_dim], tf.int32)), tf.float32)

    # add the normalized skill indices to features
    out = tf.math.add(pure_out, encoding_indices_)

    return out, skill_out, pure_out, vq_out, pure_vq_out, vq_layer.embeddings, encoding_indices
142
+
143
+
144
@register("mlp")
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
    """
    Stack of fully-connected layers for use in a policy / q-function approximator.

    Parameters
    ----------
    num_layers: int     number of fully-connected layers (default: 2)
    num_hidden: int     size of fully-connected layers (default: 64)
    activation:         activation function (default: tf.tanh)
    layer_norm: bool    apply layer normalization after each fc layer

    Returns
    -------
    function that builds the fully connected network given an input tensor / placeholder
    """
    def network_fn(X):
        out = tf.compat.v1.layers.flatten(X)
        for idx in range(num_layers):
            out = fc(out, 'mlp_fc{}'.format(idx), nh=num_hidden, init_scale=np.sqrt(2))
            if layer_norm:
                out = tf.contrib.layers.layer_norm(out, center=True, scale=True)
            out = activation(out)
        return out

    return network_fn
175
+
176
+
177
@register("cnn")
def cnn(**conv_kwargs):
    """Builder for the Nature DQN CNN (see ``nature_cnn``)."""
    def network_fn(images):
        return nature_cnn(images, **conv_kwargs)
    return network_fn
183
+
184
+
185
@register("impala_cnn")
def impala_cnn(**conv_kwargs):
    """Builder for the IMPALA CNN.

    Fix: forward **conv_kwargs to build_impala_cnn — previously they were
    accepted but silently dropped, unlike the other registered builders.
    Calls with no kwargs behave exactly as before.
    """
    def network_fn(X):
        return build_impala_cnn(X, **conv_kwargs)

    return network_fn
191
+
192
+
193
@register("cnn_small")
def cnn_small(**conv_kwargs):
    """Small two-conv-layer CNN feature extractor (128-unit output)."""
    def network_fn(X):
        relu = tf.nn.relu
        out = tf.cast(X, tf.float32) / 255.
        out = relu(conv(out, 'c1', nf=8, rf=8, stride=4, init_scale=np.sqrt(2), **conv_kwargs))
        out = relu(conv(out, 'c2', nf=16, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
        out = conv_to_fc(out)
        return relu(fc(out, 'fc1', nh=128, init_scale=np.sqrt(2)))

    return network_fn
206
+
207
+
208
@register("lstm")
def lstm(nlstm=128, layer_norm=False):
    """
    Builds LSTM (Long-Short Term Memory) network to be used in a policy.
    Note that the resulting function returns not only the output of the LSTM
    (i.e. hidden state of lstm for each step in the sequence), but also a dictionary
    with auxiliary tensors to be set as policy attributes.

    Specifically,
        S is a placeholder to feed current state (LSTM state has to be managed outside policy)
        M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
        initial_state is a numpy array containing initial lstm state (usually zeros)
        state is the output LSTM state (to be fed into S at the next call)


    An example of usage of lstm-based policy can be found here: common/tests/test_doc_examples.py/test_lstm_example

    Parameters:
    ----------

    nlstm: int          LSTM hidden state size

    layer_norm: bool    if True, layer-normalized version of LSTM is used

    Returns:
    -------

    function that builds LSTM with a given input tensor / placeholder
    """

    def network_fn(X, nenv=1):
        # X holds nenv * nsteps observations stacked along the batch axis.
        nbatch = X.shape[0]
        nsteps = nbatch // nenv  # assumes nbatch is a multiple of nenv

        h = tf.compat.v1.layers.flatten(X)

        # Placeholders fed by the caller each step.
        M = tf.compat.v1.placeholder(tf.float32, [nbatch])  # mask (done t-1)
        S = tf.compat.v1.placeholder(tf.float32, [nenv, 2 * nlstm])  # states (2*nlstm wide — presumably cell+hidden concatenated; confirm in utils.lstm)

        # Rearrange flat batch into per-env sequences of length nsteps.
        xs = batch_to_seq(h, nenv, nsteps)
        ms = batch_to_seq(M, nenv, nsteps)

        if layer_norm:
            h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
        else:
            h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)

        # Back to the flat (nbatch, nlstm) layout expected by the policy head.
        h = seq_to_batch(h5)
        initial_state = np.zeros(S.shape.as_list(), dtype=float)

        return h, {'S': S, 'M': M, 'state': snew, 'initial_state': initial_state}

    return network_fn
261
+
262
+
263
@register("cnn_lstm")
def cnn_lstm(nlstm=128, layer_norm=False, conv_fn=nature_cnn, **conv_kwargs):
    """
    CNN feature extractor followed by an LSTM.

    Same contract as ``lstm`` — network_fn returns (features,
    {'S', 'M', 'state', 'initial_state'}) — except observations are first
    passed through ``conv_fn`` (default: nature_cnn).
    """
    def network_fn(X, nenv=1):
        nbatch = X.shape[0]
        nsteps = nbatch // nenv  # assumes nbatch is a multiple of nenv

        # Convolutional features, one row per observation.
        h = conv_fn(X, **conv_kwargs)

        M = tf.compat.v1.placeholder(tf.float32, [nbatch])  # mask (done t-1)
        S = tf.compat.v1.placeholder(tf.float32, [nenv, 2 * nlstm])  # states

        # Rearrange flat batch into per-env sequences of length nsteps.
        xs = batch_to_seq(h, nenv, nsteps)
        ms = batch_to_seq(M, nenv, nsteps)

        if layer_norm:
            h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
        else:
            h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)

        h = seq_to_batch(h5)
        initial_state = np.zeros(S.shape.as_list(), dtype=float)

        return h, {'S': S, 'M': M, 'state': snew, 'initial_state': initial_state}

    return network_fn
288
+
289
+
290
@register("impala_cnn_lstm")
def impala_cnn_lstm():
    """IMPALA CNN features feeding a 256-unit LSTM."""
    return cnn_lstm(conv_fn=build_impala_cnn, nlstm=256)
293
+
294
+
295
@register("cnn_lnlstm")
def cnn_lnlstm(nlstm=128, **conv_kwargs):
    """CNN + layer-normalized LSTM; same contract as ``cnn_lstm``."""
    return cnn_lstm(nlstm=nlstm, layer_norm=True, **conv_kwargs)
298
+
299
+
300
@register("conv_only")
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
    '''
    convolutions-only net

    Parameters:
    ----------

    convs: list of triples (filter_number, filter_size, stride) — one entry per conv layer.

    Returns:

    function that takes a tensorflow tensor as input and returns the output of the last convolutional layer

    '''
    def network_fn(X):
        out = tf.cast(X, tf.float32) / 255.
        with tf.compat.v1.variable_scope("convnet"):
            for n_filters, ksize, stride in convs:
                out = tf.contrib.layers.convolution2d(out,
                                                      num_outputs=n_filters,
                                                      kernel_size=ksize,
                                                      stride=stride,
                                                      activation_fn=tf.nn.relu,
                                                      **conv_kwargs)
        return out

    return network_fn
330
+
331
+
332
def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
    """Normalize x by a running mean/std and clip to clip_range.

    Returns (normalized_tensor, RunningMeanStd instance)."""
    rms = RunningMeanStd(shape=x.shape[1:])
    lo, hi = min(clip_range), max(clip_range)
    norm_x = tf.clip_by_value((x - rms.mean) / rms.std, lo, hi)
    return norm_x, rms
336
+
337
+
338
def get_network_builder(name):
    """
    Look up a registered network builder by name (or pass a callable through).

    If you want to register your own network outside models.py:

    Usage Example:
    -------------
    from baselines.common.models import register
    @register("your_network_name")
    def your_network_define(**net_kwargs):
        ...
        return network_fn

    """
    # A callable is already a builder — return it unchanged.
    if callable(name):
        return name
    builder = mapping.get(name)
    if builder is None:
        raise ValueError('Unknown network type: {}'.format(name))
    return builder
357
+
358
+
359
class VectorQuantizer(layers.Layer):
    """Vector-quantization layer (VQ-VAE style) with a learnable codebook.

    Maps each ``embedding_dim``-sized input vector to its nearest codebook
    entry and passes gradients through with the straight-through estimator.
    """

    def __init__(self, num_embeddings, embedding_dim, seed=0, **kwargs):
        """
        num_embeddings: number of codebook vectors
        embedding_dim: dimensionality of each codebook vector
        seed: seed for the uniform codebook initializer
        """
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        # Initialize the embeddings which we will quantize.
        w_init = tf.compat.v1.random_uniform_initializer(minval=-1 / num_embeddings, maxval=1 / num_embeddings,
                                                         seed=seed)

        # Codebook matrix of shape (embedding_dim, num_embeddings).
        self.embeddings = tf.compat.v1.get_variable(
            initializer=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        """Quantize x; returns (straight-through output, raw quantized, code indices)."""
        # Calculate the input shape of the inputs and
        # then flatten the inputs keeping `embedding_dim` intact.
        input_shape = tf.shape(input=x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization: nearest-code lookup, then one-hot matmul to gather codes.
        encoding_indices = self.get_code_indices(flattened)
        encoding_indices = tf.reshape(encoding_indices, [input_shape[0], -1])
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Straight-through estimator: forward pass yields `quantized`,
        # backward pass routes gradients directly to `x`.
        quantized_ = x + tf.stop_gradient(quantized - x)

        return quantized_, quantized, encoding_indices

    def get_code_indices(self, flattened_inputs):
        """Return, for each row of flattened_inputs, the index of the nearest code."""
        # Squared Euclidean distance expanded as |a|^2 + |b|^2 - 2*a.b.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(input_tensor=flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(input_tensor=self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(input=distances, axis=1)
        return encoding_indices
baselines/common/mpi_adam.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import baselines.common.tf_util as U
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ try:
5
+ from mpi4py import MPI
6
+ except ImportError:
7
+ MPI = None
8
+
9
+
10
class MpiAdam(object):
    """Adam optimizer over a flat parameter vector, with gradients summed
    across MPI workers (falls back to single-process when mpi4py is absent)."""

    def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
        # var_list: TF variables to optimize; they are read/written as one flat vector.
        self.var_list = var_list
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.scale_grad_by_procs = scale_grad_by_procs
        size = sum(U.numel(v) for v in var_list)
        self.m = np.zeros(size, 'float32')  # first-moment (mean) estimate
        self.v = np.zeros(size, 'float32')  # second-moment (uncentered variance) estimate
        self.t = 0  # timestep, used for Adam bias correction
        self.setfromflat = U.SetFromFlat(var_list)
        self.getflat = U.GetFlat(var_list)
        self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm

    def update(self, localg, stepsize):
        """Apply one Adam step using this worker's flat gradient `localg`."""
        if self.t % 100 == 0:
            # Periodically verify all workers still hold identical parameters.
            self.check_synced()
        localg = localg.astype('float32')
        if self.comm is not None:
            # Sum gradients over workers; optionally scale to get the mean.
            globalg = np.zeros_like(localg)
            self.comm.Allreduce(localg, globalg, op=MPI.SUM)
            if self.scale_grad_by_procs:
                globalg /= self.comm.Get_size()
        else:
            globalg = np.copy(localg)

        self.t += 1
        # Bias-corrected effective step size (standard Adam).
        a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t)
        self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
        self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
        step = (- a) * self.m / (np.sqrt(self.v) + self.epsilon)
        self.setfromflat(self.getflat() + step)

    def sync(self):
        """Broadcast rank-0 parameters to all workers."""
        if self.comm is None:
            return
        theta = self.getflat()
        self.comm.Bcast(theta, root=0)
        self.setfromflat(theta)

    def check_synced(self):
        """Assert that every worker's parameters match rank 0's."""
        if self.comm is None:
            return
        if self.comm.Get_rank() == 0: # this is root
            theta = self.getflat()
            self.comm.Bcast(theta, root=0)
        else:
            thetalocal = self.getflat()
            thetaroot = np.empty_like(thetalocal)
            self.comm.Bcast(thetaroot, root=0)
            assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
62
+
63
@U.in_session
def test_MpiAdam():
    """Check that MpiAdam matches tf's AdamOptimizer on a toy loss."""
    np.random.seed(0)
    tf.compat.v1.set_random_seed(0)

    a = tf.Variable(np.random.randn(3).astype('float32'))
    b = tf.Variable(np.random.randn(2,5).astype('float32'))
    loss = tf.reduce_sum(input_tensor=tf.square(a)) + tf.reduce_sum(input_tensor=tf.sin(b))

    stepsize = 1e-2
    # Reference trajectory: plain TF Adam.
    update_op = tf.compat.v1.train.AdamOptimizer(stepsize).minimize(loss)
    do_update = U.function([], loss, updates=[update_op])

    tf.compat.v1.get_default_session().run(tf.compat.v1.global_variables_initializer())
    losslist_ref = []
    for i in range(10):
        l = do_update()
        print(i, l)
        losslist_ref.append(l)

    # Re-initialize identically, then optimize with MpiAdam instead.
    tf.compat.v1.set_random_seed(0)
    tf.compat.v1.get_default_session().run(tf.compat.v1.global_variables_initializer())

    var_list = [a,b]
    lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
    adam = MpiAdam(var_list)

    losslist_test = []
    for i in range(10):
        l,g = lossandgrad()
        adam.update(g, stepsize)
        print(i,l)
        losslist_test.append(l)

    # Both optimizers should produce (nearly) identical loss trajectories.
    np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4)


if __name__ == '__main__':
    test_MpiAdam()
baselines/common/mpi_adam_optimizer.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from baselines.common import tf_util as U
4
+ from baselines.common.tests.test_with_mpi import with_mpi
5
+ from baselines import logger
6
+ try:
7
+ from mpi4py import MPI
8
+ except ImportError:
9
+ MPI = None
10
+
11
class MpiAdamOptimizer(tf.compat.v1.train.AdamOptimizer):
    """Adam optimizer that averages gradients across mpi processes."""

    def __init__(self, comm, grad_clip=None, mpi_rank_weight=1, **kwargs):
        # comm: MPI communicator used for the gradient Allreduce
        # grad_clip: if not None, gradients with norm > 1 are rescaled to unit norm
        # mpi_rank_weight: relative weight of this worker in the weighted average
        self.comm = comm
        self.grad_clip = grad_clip
        self.mpi_rank_weight = mpi_rank_weight
        tf.compat.v1.train.AdamOptimizer.__init__(self, **kwargs)

    def compute_gradients(self, loss, var_list, **kwargs):
        """Compute local gradients, then average them across workers via Allreduce."""
        grads_and_vars = tf.compat.v1.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs)
        grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
        # Flatten all gradients into one vector, pre-scaled by this rank's weight.
        flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0) * self.mpi_rank_weight
        shapes = [v.shape.as_list() for g, v in grads_and_vars]
        sizes = [int(np.prod(s)) for s in shapes]

        # Total weight over all ranks, used to turn the weighted sum into a mean.
        total_weight = np.zeros(1, np.float32)
        self.comm.Allreduce(np.array([self.mpi_rank_weight], dtype=np.float32), total_weight, op=MPI.SUM)
        total_weight = total_weight[0]

        buf = np.zeros(sum(sizes), np.float32)
        countholder = [0] # Counts how many times _collect_grads has been called
        stat = tf.reduce_sum(input_tensor=grads_and_vars[0][1]) # sum of first variable

        def _collect_grads(flat_grad, np_stat):
            # Runs as a py_func: clip, then average gradients across workers in numpy.
            if self.grad_clip is not None:
                gradnorm = np.linalg.norm(flat_grad)
                if gradnorm > 1:
                    flat_grad /= gradnorm
                logger.logkv_mean('gradnorm', gradnorm)
                logger.logkv_mean('gradclipfrac', float(gradnorm > 1))
            self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
            np.divide(buf, float(total_weight), out=buf)
            if countholder[0] % 100 == 0:
                # Periodically verify the workers' variables are still in sync.
                check_synced(np_stat, self.comm)
            countholder[0] += 1
            return buf

        avg_flat_grad = tf.compat.v1.py_func(_collect_grads, [flat_grad, stat], tf.float32)
        avg_flat_grad.set_shape(flat_grad.shape)
        # Unflatten the averaged vector back into per-variable gradients.
        avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
        avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
                              for g, (_, v) in zip(avg_grads, grads_and_vars)]
        return avg_grads_and_vars
52
+
53
def check_synced(localval, comm=None):
    """
    It's common to forget to initialize your variables to the same values, or
    (less commonly) if you update them in some other way than adam, to get them out of sync.
    This function checks that variables on all MPI workers are the same, and raises
    an AssertionError otherwise

    Arguments:
        comm: MPI communicator
        localval: list of local variables (list of variables on current worker to be compared with the other workers)
    """
    comm = comm or MPI.COMM_WORLD
    vals = comm.gather(localval)
    # Only the root receives the gathered list; other ranks have nothing to check.
    if comm.rank == 0:
        reference = vals[0]
        assert all(val == reference for val in vals[1:]),\
            'MpiAdamOptimizer detected that different workers have different weights: {}'.format(vals)
69
+
70
@with_mpi(timeout=5)
def test_nonfreeze():
    """Smoke test: MpiAdamOptimizer trains a toy loss under MPI without deadlocking."""
    np.random.seed(0)
    tf.compat.v1.set_random_seed(0)

    a = tf.Variable(np.random.randn(3).astype('float32'))
    b = tf.Variable(np.random.randn(2,5).astype('float32'))
    loss = tf.reduce_sum(input_tensor=tf.square(a)) + tf.reduce_sum(input_tensor=tf.sin(b))

    stepsize = 1e-2
    # for some reason the session config with inter_op_parallelism_threads was causing
    # nested sess.run calls to freeze
    config = tf.compat.v1.ConfigProto(inter_op_parallelism_threads=1)
    sess = U.get_session(config=config)
    update_op = MpiAdamOptimizer(comm=MPI.COMM_WORLD, learning_rate=stepsize).minimize(loss)
    sess.run(tf.compat.v1.global_variables_initializer())
    losslist_ref = []
    for i in range(100):
        # Each step triggers the py_func Allreduce; the loop hangs if MPI deadlocks.
        l,_ = sess.run([loss, update_op])
        print(i, l)
        losslist_ref.append(l)
baselines/common/mpi_fork.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, subprocess, sys
2
+
3
def mpi_fork(n, bind_to_core=False):
    """Re-launches the current script with workers
    Returns "parent" for original parent, "child" for MPI children
    """
    # n <= 1: nothing to fork; IN_MPI set: we ARE one of the relaunched children.
    if n <= 1 or os.getenv("IN_MPI") is not None:
        return "child"
    env = os.environ.copy()
    env.update(
        MKL_NUM_THREADS="1",
        OMP_NUM_THREADS="1",
        IN_MPI="1"
    )
    cmd = ["mpirun", "-np", str(n)]
    if bind_to_core:
        cmd += ["-bind-to", "core"]
    cmd += [sys.executable] + sys.argv
    subprocess.check_call(cmd, env=env)
    return "parent"
baselines/common/mpi_moments.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mpi4py import MPI
2
+ import numpy as np
3
+ from baselines.common import zipsame
4
+
5
+
6
def mpi_mean(x, axis=0, comm=None, keepdims=False):
    """Mean of `x` over `axis`, pooled across all MPI workers.

    Returns (pooled_mean, total_count)."""
    x = np.asarray(x)
    assert x.ndim > 0
    if comm is None:
        comm = MPI.COMM_WORLD
    local_sum = x.sum(axis=axis, keepdims=keepdims)
    n = local_sum.size
    # Pack [per-element sums, count] into one vector for a single allreduce.
    packed = np.zeros(n + 1, x.dtype)
    packed[:n] = local_sum.ravel()
    packed[n] = x.shape[axis]
    total = comm.allreduce(packed, op=MPI.SUM)
    return total[:n].reshape(local_sum.shape) / total[n], total[n]
19
+
20
def mpi_moments(x, axis=0, comm=None, keepdims=False):
    """Mean, std and pooled count of `x` over `axis` across MPI workers."""
    x = np.asarray(x)
    assert x.ndim > 0
    mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True)
    sqdiffs = np.square(x - mean)
    meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True)
    assert count1 == count
    std = np.sqrt(meansqdiff)
    if not keepdims:
        # Drop the reduced axis that keepdims=True preserved above.
        squeezed = mean.shape[:axis] + mean.shape[axis+1:]
        mean = mean.reshape(squeezed)
        std = std.reshape(squeezed)
    return mean, std, count
33
+
34
+
35
def test_runningmeanstd():
    """Spawn three MPI workers that each run _helper_runningmeanstd."""
    import subprocess
    cmd = ['mpirun', '-np', '3',
           'python', '-c',
           'from baselines.common.mpi_moments import _helper_runningmeanstd; _helper_runningmeanstd()']
    subprocess.check_call(cmd)
40
+
41
def _helper_runningmeanstd():
    """Run on 3 MPI ranks: compare mpi_moments against direct numpy moments."""
    comm = MPI.COMM_WORLD
    np.random.seed(0)
    for (triple,axis) in [
        ((np.random.randn(3), np.random.randn(4), np.random.randn(5)),0),
        ((np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),0),
        ((np.random.randn(2,3), np.random.randn(2,4), np.random.randn(2,4)),1),
        ]:

        # Ground truth: moments of the concatenation of all three shards.
        x = np.concatenate(triple, axis=axis)
        ms1 = [x.mean(axis=axis), x.std(axis=axis), x.shape[axis]]

        # Each rank contributes exactly one shard to the distributed computation.
        ms2 = mpi_moments(triple[comm.Get_rank()],axis=axis)

        for (a1,a2) in zipsame(ms1, ms2):
            print(a1, a2)
            assert np.allclose(a1, a2)
        print("ok!")
61
+
baselines/common/mpi_running_mean_std.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from mpi4py import MPI
3
+ except ImportError:
4
+ MPI = None
5
+
6
+ import tensorflow as tf, baselines.common.tf_util as U, numpy as np
7
+
8
class RunningMeanStd(object):
    """Running mean/std of a data stream, aggregated across MPI workers.

    Based on the parallel variance algorithm:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

    Running totals live in non-trainable float64 TF variables; `mean` and
    `std` are derived tensors.
    """
    def __init__(self, epsilon=1e-2, shape=()):
        # epsilon: initial value of the count and sum-of-squares variables,
        # which keeps early divisions away from zero.
        self._sum = tf.compat.v1.get_variable(
            dtype=tf.float64,
            shape=shape,
            initializer=tf.compat.v1.constant_initializer(0.0),
            name="runningsum", trainable=False)
        self._sumsq = tf.compat.v1.get_variable(
            dtype=tf.float64,
            shape=shape,
            initializer=tf.compat.v1.constant_initializer(epsilon),
            name="runningsumsq", trainable=False)
        self._count = tf.compat.v1.get_variable(
            dtype=tf.float64,
            shape=(),
            initializer=tf.compat.v1.constant_initializer(epsilon),
            name="count", trainable=False)
        self.shape = shape

        self.mean = tf.cast(self._sum / self._count, dtype=tf.float32)
        # Variance floored at 1e-2 so std never collapses to zero.
        self.std = tf.sqrt( tf.maximum( tf.cast(self._sumsq / self._count, dtype=tf.float32) - tf.square(self.mean) , 1e-2 ))

        newsum = tf.compat.v1.placeholder(shape=self.shape, dtype=tf.float64, name='sum')
        newsumsq = tf.compat.v1.placeholder(shape=self.shape, dtype=tf.float64, name='var')
        newcount = tf.compat.v1.placeholder(shape=[], dtype=tf.float64, name='count')
        # Callable that adds (sum, sumsq, count) increments to the variables.
        self.incfiltparams = U.function([newsum, newsumsq, newcount], [],
            updates=[tf.compat.v1.assign_add(self._sum, newsum),
                     tf.compat.v1.assign_add(self._sumsq, newsumsq),
                     tf.compat.v1.assign_add(self._count, newcount)])


    def update(self, x):
        """Fold a batch `x` (leading axis = batch) into the running statistics,
        summed over all MPI workers when mpi4py is available."""
        x = x.astype('float64')
        n = int(np.prod(self.shape))
        totalvec = np.zeros(n*2+1, 'float64')
        # Pack [sum, sum-of-squares, count] into one vector for a single Allreduce.
        addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')])
        if MPI is not None:
            MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
        else:
            # Fix: previously totalvec stayed all-zero without mpi4py, making
            # update() a silent no-op. Single process: local stats ARE the total.
            totalvec = addvec
        self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n])
49
+
50
@U.in_session
def test_runningmeanstd():
    """Compare RunningMeanStd against direct numpy mean/std (single process)."""
    for (x1, x2, x3) in [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),
        ]:

        rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:])
        U.initialize()

        # Ground truth from the concatenated data.
        x = np.concatenate([x1, x2, x3], axis=0)
        ms1 = [x.mean(axis=0), x.std(axis=0)]
        # Feed the same data in three incremental batches.
        rms.update(x1)
        rms.update(x2)
        rms.update(x3)
        ms2 = [rms.mean.eval(), rms.std.eval()]

        assert np.allclose(ms1, ms2)
68
+
69
@U.in_session
def test_dist():
    """Two-rank MPI test: distributed RunningMeanStd must match pooled numpy stats."""
    np.random.seed(0)
    p1,p2,p3=(np.random.randn(3,1), np.random.randn(4,1), np.random.randn(5,1))
    q1,q2,q3=(np.random.randn(6,1), np.random.randn(7,1), np.random.randn(8,1))

    # p1,p2,p3=(np.random.randn(3), np.random.randn(4), np.random.randn(5))
    # q1,q2,q3=(np.random.randn(6), np.random.randn(7), np.random.randn(8))

    comm = MPI.COMM_WORLD
    assert comm.Get_size()==2
    # Rank 0 gets the p shards, rank 1 the q shards.
    if comm.Get_rank()==0:
        x1,x2,x3 = p1,p2,p3
    elif comm.Get_rank()==1:
        x1,x2,x3 = q1,q2,q3
    else:
        assert False

    rms = RunningMeanStd(epsilon=0.0, shape=(1,))
    U.initialize()

    rms.update(x1)
    rms.update(x2)
    rms.update(x3)

    # Ground truth: all six shards pooled together.
    bigvec = np.concatenate([p1,p2,p3,q1,q2,q3])

    def checkallclose(x,y):
        print(x,y)
        return np.allclose(x,y)

    assert checkallclose(
        bigvec.mean(axis=0),
        rms.mean.eval(),
    )
    assert checkallclose(
        bigvec.std(axis=0),
        rms.std.eval(),
    )


if __name__ == "__main__":
    # Run with mpirun -np 2 python <filename>
    test_dist()
baselines/common/mpi_util.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import os, numpy as np
3
+ import platform
4
+ import shutil
5
+ import subprocess
6
+ import warnings
7
+ import sys
8
+
9
+ try:
10
+ from mpi4py import MPI
11
+ except ImportError:
12
+ MPI = None
13
+
14
+
15
def sync_from_root(sess, variables, comm=None):
    """
    Send the root node's parameters to every worker.
    Arguments:
        sess: the TensorFlow session.
        variables: all parameter variables including optimizer's
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    import tensorflow as tf
    # Root broadcasts its current variable values; everyone assigns them.
    broadcast_values = comm.bcast(sess.run(variables))
    assign_ops = [tf.compat.v1.assign(var, val)
                  for (var, val) in zip(variables, broadcast_values)]
    sess.run(assign_ops)
27
+
28
def gpu_count():
    """
    Count the GPUs on this machine.
    """
    # No nvidia-smi on PATH means no NVIDIA driver, hence no GPUs.
    if shutil.which('nvidia-smi') is None:
        return 0
    lines = subprocess.check_output(['nvidia-smi', '--query-gpu=gpu_name', '--format=csv']).split(b'\n')
    # Subtract the CSV header row and the trailing empty line.
    return max(0, len(lines) - 2)
36
+
37
def setup_mpi_gpus():
    """
    Set CUDA_VISIBLE_DEVICES to MPI rank if not already set
    """
    # Respect an explicit CUDA_VISIBLE_DEVICES from the caller/environment.
    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
        if sys.platform == 'darwin': # This Assumes if you're on OSX you're just
            ids = []                 # doing a smoke test and don't want GPUs
        else:
            # One GPU per process on this machine, indexed by local rank.
            lrank, _lsize = get_local_rank_size(MPI.COMM_WORLD)
            ids = [lrank]
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, ids))
48
+
49
def get_local_rank_size(comm):
    """
    Returns (local_rank, local_size): the rank of this process among the
    processes running on the same machine, numbered 0..N-1, and N, the
    number of processes on this machine.

    Useful if you want to assign one gpu per machine.
    """
    hostname = platform.node()
    all_rank_nodes = comm.allgather((comm.Get_rank(), hostname))
    per_node_counter = defaultdict(int)
    my_local_rank = None
    # Walk ranks in gathered order; our local rank is how many processes on
    # our node were seen before us.
    for rank, node in all_rank_nodes:
        if rank == comm.Get_rank():
            my_local_rank = per_node_counter[node]
        per_node_counter[node] += 1
    assert my_local_rank is not None
    return my_local_rank, per_node_counter[hostname]
68
+
69
def share_file(comm, path):
    """
    Copies the file from rank 0 to all other ranks
    Puts it in the same place on all machines
    """
    localrank, _ = get_local_rank_size(comm)
    if comm.Get_rank() == 0:
        # Root reads the file and broadcasts its raw bytes.
        with open(path, 'rb') as fh:
            data = fh.read()
        comm.bcast(data)
    else:
        data = comm.bcast(None)
        # Only one rank per machine writes, avoiding concurrent writes to the same path.
        if localrank == 0:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, 'wb') as fh:
                fh.write(data)
    # Nobody proceeds until the file exists everywhere.
    comm.Barrier()
86
+
87
def dict_gather(comm, d, op='mean', assert_all_have_data=True):
    """
    Reduce a dict of values across MPI workers, key by key.
    op is 'mean' or 'sum'; with comm=None the dict is returned unchanged.
    """
    if comm is None:
        return d
    gathered = comm.allgather(d)
    world_size = comm.size
    # Collect every worker's value for each key.
    per_key = defaultdict(list)
    for worker_dict in gathered:
        for key, value in worker_dict.items():
            per_key[key].append(value)
    result = {}
    for key, values in per_key.items():
        if assert_all_have_data:
            assert len(values) == world_size, "only %i out of %i MPI workers have sent '%s'" % (len(values), world_size, key)
        if op == 'mean':
            result[key] = np.mean(values, axis=0)
        elif op == 'sum':
            result[key] = np.sum(values, axis=0)
        else:
            assert 0, op
    return result
109
+
110
def mpi_weighted_mean(comm, local_name2valcount):
    """
    Perform a weighted average over dicts that are each on a different node
    Input: local_name2valcount: dict mapping key -> (value, count)
    Returns: key -> mean on rank 0; {} on every other rank.
    """
    all_name2valcount = comm.gather(local_name2valcount)
    if comm.rank != 0:
        return {}
    sums = defaultdict(float)
    counts = defaultdict(float)
    for worker_dict in all_name2valcount:
        for name, (val, count) in worker_dict.items():
            try:
                numeric = float(val)
            except ValueError:
                # Non-numeric value: warn once (on root) and skip it.
                if comm.rank == 0:
                    warnings.warn('WARNING: tried to compute mean on non-float {}={}'.format(name, val))
            else:
                sums[name] += numeric * count
                counts[name] += count
    return {name: sums[name] / counts[name] for name in sums}
133
+
baselines/common/plot_util.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import os.path as osp
3
+ import json
4
+ import os
5
+ import numpy as np
6
+ import pandas
7
+ from collections import defaultdict, namedtuple
8
+ from baselines.bench import monitor
9
+ from baselines.logger import read_json, read_csv
10
+
11
def smooth(y, radius, mode='two_sided', valid_only=False):
    '''
    Smooth signal y, where radius determines the size of the window

    mode='two_sided':
        average over the window [max(index - radius, 0), min(index + radius, len(y)-1)]
    mode='causal':
        average over the window [max(index - radius, 0), index]

    valid_only: put nan in entries where the full-sized window is not available

    '''
    assert mode in ('two_sided', 'causal')
    if len(y) < 2*radius+1:
        # Not enough data for a full window: collapse to the signal mean.
        return np.ones_like(y) * y.mean()
    elif mode == 'two_sided':
        convkernel = np.ones(2 * radius+1)
        # Normalizing by a convolution of ones makes edge windows average over
        # only the samples that actually exist.
        out = np.convolve(y, convkernel, mode='same') / np.convolve(np.ones_like(y), convkernel, mode='same')
        if valid_only:
            out[:radius] = out[-radius:] = np.nan
    elif mode == 'causal':
        convkernel = np.ones(radius)
        out = np.convolve(y, convkernel, mode='full') / np.convolve(np.ones_like(y), convkernel, mode='full')
        # 'full' convolution yields len(y) + radius - 1 samples; keep the causal
        # prefix aligned with y. (The previous `out[:-radius+1]` produced an
        # empty array when radius == 1, since the slice became out[:0].)
        out = out[:len(y)]
        if valid_only:
            out[:radius] = np.nan
    return out
38
+
39
def one_sided_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1., low_counts_threshold=1e-8):
    '''
    perform one-sided (causal) EMA (exponential moving average)
    smoothing and resampling to an even grid with n points.
    Does not do extrapolation, so we assume
    xolds[0] <= low && high <= xolds[-1]

    Arguments:

    xolds: array or list - x values of data. Needs to be sorted in ascending order
    yolds: array of list - y values of data. Has to have the same length as xolds

    low: float - min value of the new x grid. By default equals to xolds[0]
    high: float - max value of the new x grid. By default equals to xolds[-1]

    n: int - number of points in new x grid

    decay_steps: float - EMA decay factor, expressed in new x grid steps.

    low_counts_threshold: float or int
        - y values with counts less than this value will be set to NaN

    Returns:
        tuple sum_ys, count_ys where
            xs - array with new x grid
            ys - array of EMA of y at each point of the new x grid
            count_ys - array of EMA of y counts at each point of the new x grid

    '''

    low = xolds[0] if low is None else low
    high = xolds[-1] if high is None else high

    assert xolds[0] <= low, 'low = {} < xolds[0] = {} - extrapolation not permitted!'.format(low, xolds[0])
    assert xolds[-1] >= high, 'high = {} > xolds[-1] = {} - extrapolation not permitted!'.format(high, xolds[-1])
    assert len(xolds) == len(yolds), 'length of xolds ({}) and yolds ({}) do not match!'.format(len(xolds), len(yolds))

    # Work in float64 so the exponential weights don't lose precision.
    xolds = xolds.astype('float64')
    yolds = yolds.astype('float64')

    luoi = 0  # last unused old index: next data point not yet folded into the EMA
    sum_y = 0.    # running exponentially-decayed sum of y values
    count_y = 0.  # running exponentially-decayed count (normalizer)
    xnews = np.linspace(low, high, n)
    # Decay time constant expressed in x units: one grid step times decay_steps.
    decay_period = (high - low) / (n - 1) * decay_steps
    # Multiplicative decay applied when advancing one grid step.
    interstep_decay = np.exp(- 1. / decay_steps)
    sum_ys = np.zeros_like(xnews)
    count_ys = np.zeros_like(xnews)
    for i in range(n):
        xnew = xnews[i]
        # Age the accumulated sum/count by one grid step.
        sum_y *= interstep_decay
        count_y *= interstep_decay
        # Fold in every old data point that falls at or before this grid point,
        # each decayed by its distance to xnew.
        while True:
            if luoi >= len(xolds):
                break
            xold = xolds[luoi]
            if xold <= xnew:
                decay = np.exp(- (xnew - xold) / decay_period)
                sum_y += decay * yolds[luoi]
                count_y += decay
                luoi += 1
            else:
                break
        sum_ys[i] = sum_y
        count_ys[i] = count_y

    ys = sum_ys / count_ys
    # Grid points supported by too little data are marked NaN rather than
    # reported as unreliable averages.
    ys[count_ys < low_counts_threshold] = np.nan

    return xnews, ys, count_ys
110
+
111
def symmetric_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1., low_counts_threshold=1e-8):
    '''
    Symmetric EMA smoothing with resampling onto an even n-point grid.

    Runs a causal EMA forward over the data and a second causal EMA backward
    (by negating and reversing the x axis), then combines the two passes with
    weights given by their respective counts. No extrapolation is performed,
    so xolds[0] <= low and high <= xolds[-1] must hold.

    Arguments and return value match one_sided_ema: returns
    (xs, ys, count_ys) where grid points whose combined count falls below
    low_counts_threshold are set to NaN.
    '''
    # Forward (causal) pass.
    xs, ys_fwd, counts_fwd = one_sided_ema(xolds, yolds, low, high, n, decay_steps, low_counts_threshold=0)
    # Backward pass: causal EMA on the mirrored axis, flipped back afterwards.
    _, ys_bwd, counts_bwd = one_sided_ema(-xolds[::-1], yolds[::-1], -high, -low, n, decay_steps, low_counts_threshold=0)
    ys_bwd = ys_bwd[::-1]
    counts_bwd = counts_bwd[::-1]
    total_counts = counts_fwd + counts_bwd
    # Count-weighted blend of the two passes.
    ys = (ys_fwd * counts_fwd + ys_bwd * counts_bwd) / total_counts
    ys[total_counts < low_counts_threshold] = np.nan
    return xs, ys, total_counts
148
+
149
# Container for a single run's loaded data; every field defaults to None so
# partially-populated results (e.g. monitor only) can be constructed.
Result = namedtuple('Result', 'monitor progress dirname metadata')
Result.__new__.__defaults__ = (None,) * len(Result._fields)
151
+
152
def load_results(root_dir_or_dirs, enable_progress=True, enable_monitor=True, verbose=False):
    '''
    load summaries of runs from a list of directories (including subdirectories)
    Arguments:

    root_dir_or_dirs: str or list of str - root directory (or list of them) to walk for run data

    enable_progress: bool - if True, will attempt to load data from progress.csv files (data saved by logger). Default: True

    enable_monitor: bool - if True, will attempt to load data from monitor.csv files (data saved by Monitor environment wrapper). Default: True

    verbose: bool - if True, will print out list of directories from which the data is loaded. Default: False


    Returns:
    List of Result objects with the following fields:
         - dirname - path to the directory data was loaded from
         - metadata - run metadata (such as command-line arguments and anything else in metadata.json file
         - monitor - if enable_monitor is True, this field contains pandas dataframe with loaded monitor.csv file (or aggregate of all *.monitor.csv files in the directory)
         - progress - if enable_progress is True, this field contains pandas dataframe with loaded progress.csv file
    '''
    import re
    if isinstance(root_dir_or_dirs, str):
        rootdirs = [osp.expanduser(root_dir_or_dirs)]
    else:
        rootdirs = [osp.expanduser(d) for d in root_dir_or_dirs]
    allresults = []
    for rootdir in rootdirs:
        assert osp.exists(rootdir), "%s doesn't exist"%rootdir
        for dirname, dirs, files in os.walk(rootdir):
            # Skip per-process scratch directories entirely.
            if '-proc' in dirname:
                files[:] = []
                continue
            monitor_re = re.compile(r'(\d+\.)?(\d+\.)?monitor\.csv')
            if set(['metadata.json', 'monitor.json', 'progress.json', 'progress.csv']).intersection(files) or \
               any([f for f in files if monitor_re.match(f)]): # also match monitor files like 0.1.monitor.csv
                # used to be uncommented, which means do not go deeper than current directory if any of the data files
                # are found
                # dirs[:] = []
                result = {'dirname' : dirname}
                if "metadata.json" in files:
                    with open(osp.join(dirname, "metadata.json"), "r") as fh:
                        result['metadata'] = json.load(fh)
                progjson = osp.join(dirname, "progress.json")
                progcsv = osp.join(dirname, "progress.csv")
                if enable_progress:
                    # Prefer the json log when both formats are present.
                    if osp.exists(progjson):
                        result['progress'] = pandas.DataFrame(read_json(progjson))
                    elif osp.exists(progcsv):
                        try:
                            result['progress'] = read_csv(progcsv)
                        except pandas.errors.EmptyDataError:
                            print('skipping progress file in ', dirname, 'empty data')
                    else:
                        if verbose: print('skipping %s: no progress file'%dirname)

                if enable_monitor:
                    try:
                        result['monitor'] = pandas.DataFrame(monitor.load_results(dirname))
                    except monitor.LoadMonitorResultsError:
                        print('skipping %s: no monitor files'%dirname)
                    except Exception as e:
                        # Best-effort loading: a corrupt monitor file should not
                        # abort the whole walk.
                        print('exception loading monitor file in %s: %s'%(dirname, e))

                # Only keep directories that yielded at least one data frame.
                if result.get('monitor') is not None or result.get('progress') is not None:
                    allresults.append(Result(**result))
                    if verbose:
                        print('successfully loaded %s'%dirname)

    if verbose: print('loaded %i results'%len(allresults))
    return allresults
221
+
222
# Palette of visually distinct matplotlib color names; plot_results cycles
# through these when assigning a color to each curve group.
COLORS = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'pink',
        'brown', 'orange', 'teal',  'lightblue', 'lime', 'lavender', 'turquoise',
        'darkgreen', 'tan', 'salmon', 'gold',  'darkred', 'darkblue']
225
+
226
+
227
def default_xy_fn(r):
    """Default curve extractor: x is cumulative episode length (timesteps),
    y is the episode reward smoothed with a radius-10 window."""
    xs = np.cumsum(r.monitor.l)
    ys = smooth(r.monitor.r, radius=10)
    return xs, ys
231
+
232
def default_split_fn(r):
    """Default split key: the portion of r.dirname between the last slash and
    an optional trailing -<digits> seed suffix (None if nothing matches)."""
    import re
    # match name between slash and -<digits> at the end of the string
    # (slash in the beginning or -<digits> in the end or either may be missing)
    m = re.search(r'[^/-]+(?=(-\d+)?\Z)', r.dirname)
    return m.group(0) if m else None
239
+
240
def plot_results(
    allresults, *,
    xy_fn=default_xy_fn,
    split_fn=default_split_fn,
    group_fn=default_split_fn,
    average_group=False,
    shaded_std=True,
    shaded_err=True,
    figsize=None,
    legend_outside=False,
    resample=0,
    smooth_step=1.0,
    tiling='vertical',
    xlabel=None,
    ylabel=None
):
    '''
    Plot multiple Results objects

    xy_fn: function Result -> x,y           - function that converts results objects into tuple of x and y values.
                                              By default, x is cumsum of episode lengths, and y is episode rewards

    split_fn: function Result -> hashable   - function that converts results objects into keys to split curves into sub-panels by.
                                              That is, the results r for which split_fn(r) is different will be put on different sub-panels.
                                              By default, the portion of r.dirname between last / and -<digits> is returned. The sub-panels are
                                              stacked vertically in the figure.

    group_fn: function Result -> hashable   - function that converts results objects into keys to group curves by.
                                              That is, the results r for which group_fn(r) is the same will be put into the same group.
                                              Curves in the same group have the same color (if average_group is False), or averaged over
                                              (if average_group is True). The default value is the same as default value for split_fn

    average_group: bool                     - if True, will average the curves in the same group and plot the mean. Enables resampling
                                              (if resample = 0, will use 512 steps)

    shaded_std: bool                        - if True (default), the shaded region corresponding to standard deviation of the group of curves will be
                                              shown (only applicable if average_group = True)

    shaded_err: bool                        - if True (default), the shaded region corresponding to error in mean estimate of the group of curves
                                              (that is, standard deviation divided by square root of number of curves) will be
                                              shown (only applicable if average_group = True)

    figsize: tuple or None                  - size of the resulting figure (including sub-panels). By default, width is 6 and height is 6 times number of
                                              sub-panels.


    legend_outside: bool                    - if True, will place the legend outside of the sub-panels.

    resample: int                           - if not zero, size of the uniform grid in x direction to resample onto. Resampling is performed via symmetric
                                              EMA smoothing (see the docstring for symmetric_ema).
                                              Default is zero (no resampling). Note that if average_group is True, resampling is necessary; in that case, default
                                              value is 512.

    smooth_step: float                      - when resampling (i.e. when resample > 0 or average_group is True), use this EMA decay parameter (in units of the new grid step).
                                              See docstrings for decay_steps in symmetric_ema or one_sided_ema functions.

    tiling: str                             - 'vertical' (default), 'horizontal', or 'symmetric' layout of sub-panels

    xlabel, ylabel: str or None             - axis labels, applied to the bottom row / left column only

    Returns:
        (figure, 2-D array of Axes) as produced by plt.subplots
    '''

    if split_fn is None: split_fn = lambda _ : ''
    if group_fn is None: group_fn = lambda _ : ''
    sk2r = defaultdict(list) # splitkey2results
    for result in allresults:
        splitkey = split_fn(result)
        sk2r[splitkey].append(result)
    assert len(sk2r) > 0
    assert isinstance(resample, int), "0: don't resample. <integer>: that many samples"
    # Decide the sub-panel grid shape from the tiling mode.
    if tiling == 'vertical' or tiling is None:
        nrows = len(sk2r)
        ncols = 1
    elif tiling == 'horizontal':
        ncols = len(sk2r)
        nrows = 1
    elif tiling == 'symmetric':
        # Most-square grid: largest divisor of N not exceeding sqrt(N) gives ncols.
        import math
        N = len(sk2r)
        largest_divisor = 1
        for i in range(1, int(math.sqrt(N))+1):
            if N % i == 0:
                largest_divisor = i
        ncols = largest_divisor
        nrows = N // ncols
    figsize = figsize or (6 * ncols, 6 * nrows)

    f, axarr = plt.subplots(nrows, ncols, sharex=False, squeeze=False, figsize=figsize)

    groups = list(set(group_fn(result) for result in allresults))

    default_samples = 512
    if average_group:
        # Averaging requires a common x grid, so resampling is forced on.
        resample = resample or default_samples

    for (isplit, sk) in enumerate(sorted(sk2r.keys())):
        g2l = {}                    # group -> legend line handle
        g2c = defaultdict(int)      # group -> number of curves in the group
        sresults = sk2r[sk]
        gresults = defaultdict(list)
        idx_row = isplit // ncols
        idx_col = isplit % ncols
        ax = axarr[idx_row][idx_col]
        for result in sresults:
            group = group_fn(result)
            g2c[group] += 1
            x, y = xy_fn(result)
            if x is None: x = np.arange(len(y))
            x, y = map(np.asarray, (x, y))
            if average_group:
                # Defer plotting: curves are collected and averaged below.
                gresults[group].append((x,y))
            else:
                if resample:
                    x, y, counts = symmetric_ema(x, y, x[0], x[-1], resample, decay_steps=smooth_step)
                l, = ax.plot(x, y, color=COLORS[groups.index(group) % len(COLORS)])
                g2l[group] = l
        if average_group:
            for group in sorted(groups):
                xys = gresults[group]
                if not any(xys):
                    continue
                color = COLORS[groups.index(group) % len(COLORS)]
                origxs = [xy[0] for xy in xys]
                minxlen = min(map(len, origxs))
                def allequal(qs):
                    return all((q==qs[0]).all() for q in qs[1:])
                if resample:
                    # Resample every curve onto the overlap [max of starts, min of ends].
                    low = max(x[0] for x in origxs)
                    high = min(x[-1] for x in origxs)
                    usex = np.linspace(low, high, resample)
                    ys = []
                    for (x, y) in xys:
                        ys.append(symmetric_ema(x, y, low, high, resample, decay_steps=smooth_step)[1])
                else:
                    # Without resampling, curves must share an identical x grid.
                    assert allequal([x[:minxlen] for x in origxs]),\
                        'If you want to average unevenly sampled data, set resample=<number of samples you want>'
                    usex = origxs[0]
                    ys = [xy[1][:minxlen] for xy in xys]
                ymean = np.mean(ys, axis=0)
                ystd = np.std(ys, axis=0)
                ystderr = ystd / np.sqrt(len(ys))
                l, = axarr[idx_row][idx_col].plot(usex, ymean, color=color)
                g2l[group] = l
                if shaded_err:
                    ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.4)
                if shaded_std:
                    ax.fill_between(usex, ymean - ystd, ymean + ystd, color=color, alpha=.2)


        # https://matplotlib.org/users/legend_guide.html
        plt.tight_layout()
        if any(g2l.keys()):
            ax.legend(
                g2l.values(),
                ['%s (%i)'%(g, g2c[g]) for g in g2l] if average_group else g2l.keys(),
                loc=2 if legend_outside else None,
                bbox_to_anchor=(1,1) if legend_outside else None)
        ax.set_title(sk)
    # add xlabels, but only to the bottom row
    if xlabel is not None:
        for ax in axarr[-1]:
            plt.sca(ax)
            plt.xlabel(xlabel)
    # add ylabels, but only to left column
    if ylabel is not None:
        for ax in axarr[:,0]:
            plt.sca(ax)
            plt.ylabel(ylabel)

    return f, axarr
406
+
407
def regression_analysis(df):
    """Fit an OLS regression of the 'score' column on all other columns of df
    and print the statsmodels fit summary."""
    import statsmodels.api as sm
    predictors = list(df.columns.copy())
    predictors.remove('score')
    model = sm.OLS(df[['score']], sm.add_constant(df[predictors]), hasconst=False)
    fit = model.fit()
    print(fit.summary())
415
+
416
def test_smooth():
    """Visual sanity check for symmetric_ema: smooth a noisy sine at three grid
    resolutions (upsampled, downsampled, same size) and show all curves."""
    norig = 100
    nup = 300
    ndown = 30
    # Irregularly spaced x values over roughly [0, 10].
    xs = np.cumsum(np.random.rand(norig) * 10 / norig)
    yclean = np.sin(xs)
    ys = yclean + .1 * np.random.randn(yclean.size)
    xup, yup, _ = symmetric_ema(xs, ys, xs.min(), xs.max(), nup, decay_steps=nup/ndown)
    xdown, ydown, _ = symmetric_ema(xs, ys, xs.min(), xs.max(), ndown, decay_steps=ndown/ndown)
    xsame, ysame, _ = symmetric_ema(xs, ys, xs.min(), xs.max(), norig, decay_steps=norig/ndown)
    plt.plot(xs, ys, label='orig', marker='x')
    plt.plot(xup, yup, label='up', marker='x')
    plt.plot(xdown, ydown, label='down', marker='x')
    plt.plot(xsame, ysame, label='same', marker='x')
    plt.plot(xs, yclean, label='clean', marker='x')
    plt.legend()
    plt.show()
433
+
434
+
baselines/common/policies.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from baselines.common import tf_util
3
+ from baselines.a2c.utils import fc
4
+ from baselines.common.distributions import make_pdtype
5
+ from baselines.common.input import observation_placeholder, encode_observation
6
+ from baselines.common.tf_util import adjust_shape
7
+ from baselines.common.mpi_running_mean_std import RunningMeanStd
8
+ from baselines.common.models import get_network_builder
9
+
10
+ import gym
11
+
12
+
13
class PolicyWithValue(object):
    """
    Encapsulates fields and methods for RL policy and value function estimation with shared parameters
    """

    def __init__(self, env, observations, latent, estimate_q=False, vf_latent=None, sess=None, **tensors):
        """
        Parameters:
        ----------
        env             RL environment

        observations    tensorflow placeholder in which the observations will be fed

        latent          latent state from which policy distribution parameters should be inferred

        estimate_q      if True, build a per-action Q head instead of a scalar value head
                        (requires a discrete action space)

        vf_latent       latent state from which value function should be inferred (if None, then latent is used)

        sess            tensorflow session to run calculations in (if None, default session is used)

        **tensors       tensorflow tensors for additional attributes such as state or mask

        """

        self.X = observations
        # Empty-state sentinel for non-recurrent policies; step() maps it to None.
        self.state = tf.constant([])
        self.initial_state = None
        # Extra tensors (e.g. recurrent state/mask) become attributes so that
        # _evaluate can feed them by keyword name.
        self.__dict__.update(tensors)

        vf_latent = vf_latent if vf_latent is not None else latent

        vf_latent = tf.compat.v1.layers.flatten(vf_latent)
        latent = tf.compat.v1.layers.flatten(latent)

        # Based on the action space, will select what probability distribution type
        self.pdtype = make_pdtype(env.action_space)

        self.pd, self.pi = self.pdtype.pdfromlatent(latent, init_scale=0.01)

        # Take an action
        self.action = self.pd.sample()

        # Calculate the neg log of our probability
        self.neglogp = self.pd.neglogp(self.action)
        self.sess = sess or tf.compat.v1.get_default_session()

        if estimate_q:
            assert isinstance(env.action_space, gym.spaces.Discrete)
            self.q = fc(vf_latent, 'q', env.action_space.n)
            self.vf = self.q
        else:
            self.vf = fc(vf_latent, 'vf', 1)
            # Drop the trailing singleton dimension: vf is a per-batch scalar.
            self.vf = self.vf[:,0]

    def _evaluate(self, variables, observation, **extra_feed):
        # Run the requested graph nodes, feeding the observation plus any
        # extra placeholders (matched by attribute name) from extra_feed.
        sess = self.sess
        feed_dict = {self.X: adjust_shape(self.X, observation)}
        for inpt_name, data in extra_feed.items():
            if inpt_name in self.__dict__.keys():
                inpt = self.__dict__[inpt_name]
                # Only feed genuine placeholders; other tensors are computed.
                if isinstance(inpt, tf.Tensor) and inpt._op.type == 'Placeholder':
                    feed_dict[inpt] = adjust_shape(inpt, data)

        return sess.run(variables, feed_dict)

    def step(self, observation, **extra_feed):
        """
        Compute next action(s) given the observation(s)

        Parameters:
        ----------

        observation     observation data (either single or a batch)

        **extra_feed    additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__)

        Returns:
        -------
        (action, value estimate, next state, negative log likelihood of the action under current policy parameters) tuple
        """

        a, v, state, neglogp = self._evaluate([self.action, self.vf, self.state, self.neglogp], observation, **extra_feed)
        # Non-recurrent policies carry the empty sentinel state; report None.
        if state.size == 0:
            state = None
        return a, v, state, neglogp

    def value(self, ob, *args, **kwargs):
        """
        Compute value estimate(s) given the observation(s)

        Parameters:
        ----------

        ob              observation data (either single or a batch)

        **kwargs        additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__)

        Returns:
        -------
        value estimate
        """
        return self._evaluate(self.vf, ob, *args, **kwargs)

    def save(self, save_path):
        # Persist all session variables to save_path.
        tf_util.save_state(save_path, sess=self.sess)

    def load(self, load_path):
        # Restore session variables previously written by save().
        tf_util.load_state(load_path, sess=self.sess)
120
+
121
def build_policy(env, policy_network, value_network=None, normalize_observations=False, estimate_q=False, **policy_kwargs):
    """
    Build a policy constructor for the given environment.

    Parameters:
    ----------
    env                      RL environment (provides observation_space / action_space)
    policy_network           network builder callable, or the name of a registered network
                             (resolved via get_network_builder)
    value_network            None or 'shared' to share the policy latent; 'copy' to build a
                             separate network with the same architecture; or a callable builder
    normalize_observations   if True and observations are float32, normalize them with a
                             running mean/std and clip the result
    estimate_q               if True, the value head estimates per-action Q values
    **policy_kwargs          forwarded to the network builder when policy_network is a name

    Returns:
        policy_fn(nbatch, nsteps, sess, observ_placeholder) -> PolicyWithValue
    """
    if isinstance(policy_network, str):
        network_type = policy_network
        policy_network = get_network_builder(network_type)(**policy_kwargs)

    def policy_fn(nbatch=None, nsteps=None, sess=None, observ_placeholder=None):
        ob_space = env.observation_space

        # Reuse a caller-provided placeholder if given, otherwise create one.
        X = observ_placeholder if observ_placeholder is not None else observation_placeholder(ob_space, batch_size=nbatch)

        extra_tensors = {}

        if normalize_observations and X.dtype == tf.float32:
            encoded_x, rms = _normalize_clip_observation(X)
            extra_tensors['rms'] = rms
        else:
            encoded_x = X

        encoded_x = encode_observation(ob_space, encoded_x)

        with tf.compat.v1.variable_scope('pi', reuse=tf.compat.v1.AUTO_REUSE):
            policy_latent = policy_network(encoded_x)
            # A tuple return signals a recurrent network: (latent, recurrent_tensors).
            if isinstance(policy_latent, tuple):
                policy_latent, recurrent_tensors = policy_latent

                if recurrent_tensors is not None:
                    # recurrent architecture, need a few more steps
                    nenv = nbatch // nsteps
                    assert nenv > 0, 'Bad input for recurrent policy: batch size {} smaller than nsteps {}'.format(nbatch, nsteps)
                    # Rebuild with the environment count so state tensors get the right shape.
                    policy_latent, recurrent_tensors = policy_network(encoded_x, nenv)
                    extra_tensors.update(recurrent_tensors)


        _v_net = value_network

        if _v_net is None or _v_net == 'shared':
            vf_latent = policy_latent
        else:
            if _v_net == 'copy':
                _v_net = policy_network
            else:
                assert callable(_v_net)

            with tf.compat.v1.variable_scope('vf', reuse=tf.compat.v1.AUTO_REUSE):
                # TODO recurrent architectures are not supported with value_network=copy yet
                vf_latent = _v_net(encoded_x)

        policy = PolicyWithValue(
            env=env,
            observations=X,
            latent=policy_latent,
            vf_latent=vf_latent,
            sess=sess,
            estimate_q=estimate_q,
            **extra_tensors
        )
        return policy

    return policy_fn
180
+
181
+
182
def _normalize_clip_observation(x, clip_range=(-5.0, 5.0)):
    """
    Normalize observation tensor x with a running mean/std estimator and clip
    the result to clip_range.

    Returns (normalized_and_clipped_tensor, RunningMeanStd instance).
    """
    # Tuple default avoids the shared mutable-default-argument pitfall of the
    # original list literal; callers passing a list still work.
    rms = RunningMeanStd(shape=x.shape[1:])
    norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
    return norm_x, rms
185
+ return norm_x, rms
186
+
baselines/common/retro_wrappers.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ import cv2
3
+ cv2.ocl.setUseOpenCL(False)
4
+ from .atari_wrappers import WarpFrame, ClipRewardEnv, FrameStack, ScaledFloatFrame
5
+ from .wrappers import TimeLimit
6
+ import numpy as np
7
+ import gym
8
+
9
+
10
class StochasticFrameSkip(gym.Wrapper):
    """Frame-skip wrapper with sticky actions: each chosen action is repeated
    for n substeps, and with probability stickprob the previous action persists
    for one extra substep before the new one takes effect."""
    def __init__(self, env, n, stickprob):
        gym.Wrapper.__init__(self, env)
        self.n = n                  # number of substeps per agent action
        self.stickprob = stickprob  # probability the old action "sticks" on substep 0
        self.curac = None           # action currently being executed
        self.rng = np.random.RandomState()
        # Some retro envs can skip rendering on intermediate substeps.
        self.supports_want_render = hasattr(env, "supports_want_render")

    def reset(self, **kwargs):
        self.curac = None
        return self.env.reset(**kwargs)

    def step(self, ac):
        done = False
        totrew = 0  # accumulate reward over all substeps
        for i in range(self.n):
            # First step after reset, use action
            if self.curac is None:
                self.curac = ac
            # First substep, delay with probability=stickprob
            elif i==0:
                if self.rng.rand() > self.stickprob:
                    self.curac = ac
            # Second substep, new action definitely kicks in
            elif i==1:
                self.curac = ac
            if self.supports_want_render and i<self.n-1:
                # Skip rendering on intermediate substeps when supported.
                ob, rew, done, info = self.env.step(self.curac, want_render=False)
            else:
                ob, rew, done, info = self.env.step(self.curac)
            totrew += rew
            if done: break
        return ob, totrew, done, info

    def seed(self, s):
        # Seed the sticky-action RNG (not the underlying environment).
        self.rng.seed(s)
47
+
48
class PartialFrameStack(gym.Wrapper):
    """Frame stack that keeps the full current frame but only one channel
    (the `channel` keyword) from each of the k-1 previous frames."""
    def __init__(self, env, k, channel=1):
        """
        Stack one channel (channel keyword) from previous frames
        """
        gym.Wrapper.__init__(self, env)
        shp = env.observation_space.shape
        self.channel = channel
        # Output depth: full current frame plus one channel per previous frame.
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=(shp[0], shp[1], shp[2] + k - 1),
            dtype=env.observation_space.dtype)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape

    def reset(self):
        ob = self.env.reset()
        assert ob.shape[2] > self.channel
        # Fill the history with copies of the initial frame.
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, ac):
        ob, reward, done, info = self.env.step(ac)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        # Newest frame keeps all channels; older ones contribute one channel each.
        return np.concatenate([frame if i==self.k-1 else frame[:,:,self.channel:self.channel+1]
                               for (i, frame) in enumerate(self.frames)], axis=2)
78
+ for (i, frame) in enumerate(self.frames)], axis=2)
79
+
80
class Downsample(gym.ObservationWrapper):
    def __init__(self, env, ratio):
        """
        Downsample images by a factor of ratio
        """
        gym.ObservationWrapper.__init__(self, env)
        (oldh, oldw, oldc) = env.observation_space.shape
        newshape = (oldh//ratio, oldw//ratio, oldc)
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=newshape, dtype=np.uint8)

    def observation(self, frame):
        height, width, _ = self.observation_space.shape
        # INTER_AREA is the recommended interpolation for shrinking images.
        frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_AREA)
        # cv2.resize drops the channel axis for single-channel images; restore it.
        if frame.ndim == 2:
            frame = frame[:,:,None]
        return frame
97
+
98
class Rgb2gray(gym.ObservationWrapper):
    def __init__(self, env):
        """
        Convert RGB frames to single-channel grayscale frames.
        """
        # NOTE(review): original docstring said "Downsample images by a factor
        # of ratio" — a copy-paste from Downsample; this wrapper only converts
        # color space, it does not resize.
        gym.ObservationWrapper.__init__(self, env)
        (oldh, oldw, _oldc) = env.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=(oldh, oldw, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # Keep an explicit channel axis so downstream wrappers see HxWx1.
        return frame[:,:,None]
111
+
112
+
113
class MovieRecord(gym.Wrapper):
    """Record a movie of every k-th episode by pointing the unwrapped retro
    environment's movie_path at savedir (and disabling it otherwise)."""
    def __init__(self, env, savedir, k):
        gym.Wrapper.__init__(self, env)
        self.savedir = savedir
        self.k = k            # record one episode out of every k
        self.epcount = 0      # episodes started so far
    def reset(self):
        if self.epcount % self.k == 0:
            self.env.unwrapped.movie_path = self.savedir
        else:
            # Disable recording for this episode.
            self.env.unwrapped.movie_path = None
            self.env.unwrapped.movie = None
        self.epcount += 1
        return self.env.reset()
127
+
128
class AppendTimeout(gym.Wrapper):
    """Augment observations with 'value_estimation_timeout': the fraction of
    the episode's step budget still remaining, so a value function can account
    for time limits. Dict observation spaces get an extra key; other spaces
    are wrapped into {'original': ob, 'value_estimation_timeout': frac}."""
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.action_space = env.action_space
        self.timeout_space = gym.spaces.Box(low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32)
        self.original_os = env.observation_space
        if isinstance(self.original_os, gym.spaces.Dict):
            import copy
            ordered_dict = copy.deepcopy(self.original_os.spaces)
            ordered_dict['value_estimation_timeout'] = self.timeout_space
            self.observation_space = gym.spaces.Dict(ordered_dict)
            self.dict_mode = True
        else:
            self.observation_space = gym.spaces.Dict({
                'original': self.original_os,
                'value_estimation_timeout': self.timeout_space
            })
            self.dict_mode = False
        self.ac_count = None  # steps taken this episode; set by reset()
        # Walk down the wrapper chain to find the TimeLimit wrapper that
        # carries _max_episode_steps. NOTE(review): this loop raises
        # AttributeError (via env.env) if no such wrapper exists.
        while 1:
            if not hasattr(env, "_max_episode_steps"): # Looking for TimeLimit wrapper that has this field
                env = env.env
                continue
            break
        self.timeout = env._max_episode_steps

    def step(self, ac):
        self.ac_count += 1
        ob, rew, done, info = self.env.step(ac)
        return self._process(ob), rew, done, info

    def reset(self):
        self.ac_count = 0
        return self._process(self.env.reset())

    def _process(self, ob):
        # Fraction of the episode budget still remaining.
        fracmissing = 1 - self.ac_count / self.timeout
        if self.dict_mode:
            ob['value_estimation_timeout'] = fracmissing
            # BUG FIX: the original fell through here and returned None in
            # dict mode even though it had already annotated ob in place.
            return ob
        else:
            return { 'original': ob, 'value_estimation_timeout': fracmissing }
169
+
170
class StartDoingRandomActionsWrapper(gym.Wrapper):
    """
    Warning: can eat info dicts, not good if you depend on them

    Executes a random number (up to max_random_steps) of random actions at
    startup and/or after each episode, so the agent starts from varied states.
    reset() returns the observation reached by those random steps rather than
    calling env.reset() again.
    """
    def __init__(self, env, max_random_steps, on_startup=True, every_episode=False):
        gym.Wrapper.__init__(self, env)
        self.on_startup = on_startup
        self.every_episode = every_episode
        self.random_steps = max_random_steps
        self.last_obs = None  # most recent observation; served by reset()
        if on_startup:
            self.some_random_steps()

    def some_random_steps(self):
        # Reset, then take a random number of random actions; re-reset if an
        # episode terminates mid-burn-in.
        self.last_obs = self.env.reset()
        n = np.random.randint(self.random_steps)
        #print("running for random %i frames" % n)
        for _ in range(n):
            self.last_obs, _, done, _ = self.env.step(self.env.action_space.sample())
            if done: self.last_obs = self.env.reset()

    def reset(self):
        # Return the state reached by the random burn-in instead of resetting.
        return self.last_obs

    def step(self, a):
        self.last_obs, rew, done, info = self.env.step(a)
        if done:
            self.last_obs = self.env.reset()
            if self.every_episode:
                self.some_random_steps()
        return self.last_obs, rew, done, info
201
+
202
def make_retro(*, game, state=None, max_episode_steps=4500, **kwargs):
    """Create a gym-retro environment with stochastic frame skip and an
    optional step limit. Extra kwargs are forwarded to retro.make."""
    import retro
    if state is None:
        state = retro.State.DEFAULT
    env = retro.make(game, state, **kwargs)
    # Repeat each action 4 frames with 25% sticky-action probability.
    env = StochasticFrameSkip(env, n=4, stickprob=0.25)
    if max_episode_steps is not None:
        env = TimeLimit(env, max_episode_steps=max_episode_steps)
    return env
211
+
212
def wrap_deepmind_retro(env, scale=True, frame_stack=4):
    """
    Configure environment for retro games, using config similar to DeepMind-style Atari in wrap_deepmind
    """
    env = WarpFrame(env)          # resize + grayscale
    env = ClipRewardEnv(env)      # clip rewards to {-1, 0, +1}
    if frame_stack > 1:
        env = FrameStack(env, frame_stack)
    if scale:
        env = ScaledFloatFrame(env)  # convert uint8 pixels to floats in [0, 1]
    return env
223
+
224
class SonicDiscretizer(gym.ActionWrapper):
    """
    Wrap a gym-retro environment and make it use discrete
    actions for the Sonic game.
    """
    def __init__(self, env):
        super(SonicDiscretizer, self).__init__(env)
        # Genesis controller button layout expected by gym-retro.
        buttons = ["B", "A", "MODE", "START", "UP", "DOWN", "LEFT", "RIGHT", "C", "Y", "X", "Z"]
        # The 7 button combinations useful for playing Sonic.
        actions = [['LEFT'], ['RIGHT'], ['LEFT', 'DOWN'], ['RIGHT', 'DOWN'], ['DOWN'],
                   ['DOWN', 'B'], ['B']]
        self._actions = []
        # Precompute each combo as a 12-element boolean button mask.
        for action in actions:
            arr = np.array([False] * 12)
            for button in action:
                arr[buttons.index(button)] = True
            self._actions.append(arr)
        self.action_space = gym.spaces.Discrete(len(self._actions))

    def action(self, a): # pylint: disable=W0221
        # Copy so callers cannot mutate the cached mask.
        return self._actions[a].copy()
244
+
245
class RewardScaler(gym.RewardWrapper):
    """
    Multiply every reward by a constant factor to bring it into a range that
    works well for PPO. This matters a lot in practice: reward magnitude
    directly affects training performance.
    """
    def __init__(self, env, scale=0.01):
        super(RewardScaler, self).__init__(env)
        self.scale = scale

    def reward(self, reward):
        return self.scale * reward
257
+
258
class AllowBacktracking(gym.Wrapper):
    """
    Use deltas in max(X) as the reward, rather than deltas
    in X. This way, agents are not discouraged too heavily
    from exploring backwards if there is no way to advance
    head-on in the level.
    """
    def __init__(self, env):
        super(AllowBacktracking, self).__init__(env)
        self._cur_x = 0  # cumulative raw reward (proxy for x position)
        self._max_x = 0  # best cumulative value seen this episode

    def reset(self, **kwargs): # pylint: disable=E0202
        self._cur_x = 0
        self._max_x = 0
        return self.env.reset(**kwargs)

    def step(self, action): # pylint: disable=E0202
        obs, rew, done, info = self.env.step(action)
        self._cur_x += rew
        # Reward only new progress beyond the previous maximum; moving
        # backwards yields 0 rather than a penalty.
        rew = max(0, self._cur_x - self._max_x)
        self._max_x = max(self._max_x, self._cur_x)
        return obs, rew, done, info
baselines/common/runners.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from abc import ABC, abstractmethod
3
+
4
class AbstractEnvRunner(ABC):
    """
    Base class for experience-collection runners.

    Stores the environment, the model and the rollout length, and
    initializes the observation / state / done buffers that concrete
    subclasses advance in ``run``.
    """

    def __init__(self, *, env, model, nsteps):
        self.env = env
        self.model = model
        # Vectorized environments expose num_envs; fall back to a single env.
        nenv = env.num_envs if hasattr(env, 'num_envs') else 1
        self.nenv = nenv
        ob_space = env.observation_space
        self.batch_ob_shape = (nenv * nsteps,) + ob_space.shape
        # Observation buffer, seeded with the initial reset observation.
        self.obs = np.zeros((nenv,) + ob_space.shape, dtype=ob_space.dtype.name)
        self.obs[:] = env.reset()
        self.nsteps = nsteps
        self.states = model.initial_state
        self.dones = [False] * nenv

    @abstractmethod
    def run(self):
        """Collect one rollout; implemented by concrete subclasses."""
        raise NotImplementedError
19
+
baselines/common/running_mean_std.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ from baselines.common.tf_util import get_session
4
+
5
class RunningMeanStd(object):
    # Streaming mean/variance via the parallel-variance merge rule:
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        # epsilon acts as a tiny prior sample count so the first update is stable.
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        """Fold a batch of samples (batch axis first) into the running moments."""
        self.update_from_moments(np.mean(x, axis=0), np.var(x, axis=0), x.shape[0])

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge externally computed batch moments into the running moments."""
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    """Combine two (mean, var, count) moment triples into one.

    Implements Chan et al.'s parallel-variance formula and returns the
    merged ``(mean, var, count)`` triple.
    """
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    # M2 accumulates the combined sum of squared deviations.
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count

    return new_mean, M2 / tot_count, tot_count
34
+
35
+
36
class TfRunningMeanStd(object):
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    '''
    TensorFlow variables-based implementation of computing running mean and std.

    Benefit of this implementation is that it can be saved / loaded
    together with the tensorflow model, since the statistics live in
    TF variables rather than plain numpy attributes.
    '''
    def __init__(self, epsilon=1e-4, shape=(), scope=''):
        sess = get_session()

        # Placeholders used to push freshly merged moments into the variables.
        self._new_mean = tf.compat.v1.placeholder(shape=shape, dtype=tf.float64)
        self._new_var = tf.compat.v1.placeholder(shape=shape, dtype=tf.float64)
        self._new_count = tf.compat.v1.placeholder(shape=(), dtype=tf.float64)


        with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
            # NOTE(review): the variance variable is named 'std' — presumably a
            # historical naming quirk kept for checkpoint compatibility; confirm
            # before renaming.
            self._mean = tf.compat.v1.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64)
            self._var = tf.compat.v1.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64)
            self._count = tf.compat.v1.get_variable('count', initializer=np.full((), epsilon, 'float64'), dtype=tf.float64)

        # Single grouped op that overwrites all three statistics at once.
        self.update_ops = tf.group([
            self._var.assign(self._new_var),
            self._mean.assign(self._new_mean),
            self._count.assign(self._new_count)
        ])

        sess.run(tf.compat.v1.variables_initializer([self._mean, self._var, self._count]))
        self.sess = sess
        self._set_mean_var_count()

    def _set_mean_var_count(self):
        # Cache the current variable values as plain numpy attributes
        # (mean / var / count) so reads don't require a session run.
        self.mean, self.var, self.count = self.sess.run([self._mean, self._var, self._count])

    def update(self, x):
        """Fold a batch of samples (batch axis first) into the running moments."""
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]

        # Merge the moments on the host, then write the result back into TF.
        new_mean, new_var, new_count = update_mean_var_count_from_moments(self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

        self.sess.run(self.update_ops, feed_dict={
            self._new_mean: new_mean,
            self._new_var: new_var,
            self._new_count: new_count
        })

        self._set_mean_var_count()
82
+
83
+
84
+
85
def test_runningmeanstd():
    """Check RunningMeanStd against full-batch moments on 1-D and 2-D data."""
    cases = [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
    ]
    for chunks in cases:
        rms = RunningMeanStd(epsilon=0.0, shape=chunks[0].shape[1:])

        full = np.concatenate(chunks, axis=0)
        expected = [full.mean(axis=0), full.var(axis=0)]
        for chunk in chunks:
            rms.update(chunk)
        actual = [rms.mean, rms.var]

        np.testing.assert_allclose(expected, actual)
101
+
102
def test_tf_runningmeanstd():
    """Check TfRunningMeanStd against full-batch moments on 1-D and 2-D data."""
    cases = [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
    ]
    for chunks in cases:
        # Randomized scope name so repeated runs do not collide on TF variables.
        scope = 'running_mean_std' + str(np.random.randint(0, 128))
        rms = TfRunningMeanStd(epsilon=0.0, shape=chunks[0].shape[1:], scope=scope)

        full = np.concatenate(chunks, axis=0)
        expected = [full.mean(axis=0), full.var(axis=0)]
        for chunk in chunks:
            rms.update(chunk)
        actual = [rms.mean, rms.var]

        np.testing.assert_allclose(expected, actual)
118
+
119
+
120
def profile_tf_runningmeanstd():
    """Micro-benchmark comparing RunningMeanStd and TfRunningMeanStd.

    Times ``update`` calls and ``mean`` attribute reads over ``n_trials``
    repetitions of each implementation and prints the wall-clock results.
    """
    import time
    from baselines.common import tf_util

    # Single-threaded session so timings are not skewed by TF parallelism.
    tf_util.get_session( config=tf.compat.v1.ConfigProto(
        inter_op_parallelism_threads=1,
        intra_op_parallelism_threads=1,
        allow_soft_placement=True
    ))

    x = np.random.random((376,))

    n_trials = 10000
    rms = RunningMeanStd()
    tfrms = TfRunningMeanStd()

    tic1 = time.time()
    for _ in range(n_trials):
        rms.update(x)

    tic2 = time.time()
    for _ in range(n_trials):
        tfrms.update(x)

    tic3 = time.time()

    print('rms update time ({} trials): {} s'.format(n_trials, tic2 - tic1))
    print('tfrms update time ({} trials): {} s'.format(n_trials, tic3 - tic2))


    tic1 = time.time()
    for _ in range(n_trials):
        z1 = rms.mean

    tic2 = time.time()
    for _ in range(n_trials):
        z2 = tfrms.mean

    # Sanity check: both implementations should agree after identical updates.
    assert z1 == z2

    tic3 = time.time()

    print('rms get mean time ({} trials): {} s'.format(n_trials, tic2 - tic1))
    print('tfrms get mean time ({} trials): {} s'.format(n_trials, tic3 - tic2))



    # Disabled chrome-trace profiling snippet, kept for manual use.
    '''
    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) #pylint: disable=E1101
    run_metadata = tf.RunMetadata()
    profile_opts = dict(options=options, run_metadata=run_metadata)



    from tensorflow.python.client import timeline
    fetched_timeline = timeline.Timeline(run_metadata.step_stats) #pylint: disable=E1101
    chrome_trace = fetched_timeline.generate_chrome_trace_format()
    outfile = '/tmp/timeline.json'
    with open(outfile, 'wt') as f:
        f.write(chrome_trace)
    print('Successfully saved profile to {}. Exiting.'.format(outfile))
    exit(0)
    '''
183
+
184
+
185
+
186
if __name__ == '__main__':
    # Run the benchmark when this module is executed directly.
    profile_tf_runningmeanstd()
baselines/common/schedules.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file is used for specifying various schedules that evolve over
2
+ time throughout the execution of the algorithm, such as:
3
+ - learning rate for the optimizer
4
+ - exploration epsilon for the epsilon greedy exploration strategy
5
+ - beta parameter for beta parameter in prioritized replay
6
+
7
+ Each schedule has a function `value(t)` which returns the current value
8
+ of the parameter given the timestep t of the optimization procedure.
9
+ """
10
+
11
+
12
class Schedule(object):
    """Interface for time-dependent parameter schedules."""

    def value(self, t):
        """Return the schedule's value at timestep t."""
        raise NotImplementedError()
16
+
17
+
18
class ConstantSchedule(object):
    """Schedule that outputs the same value at every timestep."""

    def __init__(self, value):
        """Value remains constant over time.

        Parameters
        ----------
        value: float
            Constant value of the schedule
        """
        self._constant = value

    def value(self, t):
        """See Schedule.value"""
        return self._constant
32
+
33
+
34
def linear_interpolation(l, r, alpha):
    """Linearly interpolate between l and r by fraction alpha."""
    return l + alpha * (r - l)


class PiecewiseSchedule(object):
    def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None):
        """Piecewise schedule.

        endpoints: [(int, int)]
            list of pairs `(time, value)` meaning that the schedule should output
            `value` when `t==time`. All the values for time must be sorted in
            an increasing order. When t is between two times, e.g. `(time_a, value_a)`
            and `(time_b, value_b)`, such that `time_a <= t < time_b` then value outputs
            `interpolation(value_a, value_b, alpha)` where alpha is the fraction of
            time passed between `time_a` and `time_b` for time `t`.
        interpolation: lambda float, float, float: float
            a function that takes value to the left and to the right of t according
            to the `endpoints`. Alpha is the fraction of distance from left endpoint to
            right endpoint that t has covered. See linear_interpolation for example.
        outside_value: float
            if the value is requested outside of all the intervals specified in
            `endpoints` this value is returned. If None then an AssertionError is
            raised when an outside value is requested.
        """
        times = [point[0] for point in endpoints]
        assert times == sorted(times)
        self._interpolation = interpolation
        self._outside_value = outside_value
        self._endpoints = endpoints

    def value(self, t):
        """See Schedule.value"""
        # Walk consecutive endpoint pairs; interpolate inside [l_t, r_t).
        for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]):
            if l_t <= t < r_t:
                alpha = float(t - l_t) / (r_t - l_t)
                return self._interpolation(l, r, alpha)

        # t lies outside every segment; fall back to the configured value.
        assert self._outside_value is not None
        return self._outside_value
74
+
75
+
76
class LinearSchedule(object):
    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        """Linear interpolation between initial_p and final_p over
        schedule_timesteps. After this many timesteps pass final_p is
        returned.

        Parameters
        ----------
        schedule_timesteps: int
            Number of timesteps for which to linearly anneal initial_p
            to final_p
        initial_p: float
            initial output value
        final_p: float
            final output value
        """
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        """See Schedule.value"""
        # Clamp progress to [0, 1] so the schedule holds final_p afterwards.
        progress = min(float(t) / self.schedule_timesteps, 1.0)
        span = self.final_p - self.initial_p
        return self.initial_p + progress * span