Upload 254 files
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +32 -0
- LICENSE +21 -0
- README.md +194 -12
- baselines/__init__.py +0 -0
- baselines/a2c/README.md +13 -0
- baselines/a2c/__init__.py +0 -0
- baselines/a2c/a2c.py +232 -0
- baselines/a2c/runner.py +76 -0
- baselines/a2c/utils.py +291 -0
- baselines/acer/README.md +6 -0
- baselines/acer/__init__.py +0 -0
- baselines/acer/acer.py +381 -0
- baselines/acer/buffer.py +156 -0
- baselines/acer/defaults.py +4 -0
- baselines/acer/policies.py +81 -0
- baselines/acer/runner.py +61 -0
- baselines/acktr/README.md +9 -0
- baselines/acktr/__init__.py +0 -0
- baselines/acktr/acktr.py +158 -0
- baselines/acktr/defaults.py +5 -0
- baselines/acktr/kfac.py +928 -0
- baselines/acktr/kfac_utils.py +86 -0
- baselines/acktr/utils.py +28 -0
- baselines/bench/__init__.py +3 -0
- baselines/bench/benchmarks.py +164 -0
- baselines/bench/monitor.py +162 -0
- baselines/bench/test_monitor.py +31 -0
- baselines/common/__init__.py +5 -0
- baselines/common/atari_wrappers.py +290 -0
- baselines/common/cg.py +34 -0
- baselines/common/cmd_util.py +206 -0
- baselines/common/console_util.py +80 -0
- baselines/common/dataset.py +60 -0
- baselines/common/distributions.py +355 -0
- baselines/common/input.py +65 -0
- baselines/common/math_util.py +85 -0
- baselines/common/misc_util.py +243 -0
- baselines/common/models.py +406 -0
- baselines/common/mpi_adam.py +103 -0
- baselines/common/mpi_adam_optimizer.py +90 -0
- baselines/common/mpi_fork.py +23 -0
- baselines/common/mpi_moments.py +61 -0
- baselines/common/mpi_running_mean_std.py +112 -0
- baselines/common/mpi_util.py +133 -0
- baselines/common/plot_util.py +434 -0
- baselines/common/policies.py +186 -0
- baselines/common/retro_wrappers.py +280 -0
- baselines/common/runners.py +19 -0
- baselines/common/running_mean_std.py +187 -0
- baselines/common/schedules.py +99 -0
.gitattributes
CHANGED
|
@@ -36,3 +36,35 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 36 |
checkpoints/sppo-fruitbot_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
checkpoints/sppo-jumper_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
checkpoints/sppo-ninja_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
checkpoints/sppo-fruitbot_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
checkpoints/sppo-jumper_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
checkpoints/sppo-ninja_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
baselines/gail/result/halfcheetah-training.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
baselines/gail/result/hopper-training.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
baselines/gail/result/humanoid-training.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
baselines/gail/result/humanoidstandup-training.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
baselines/gail/result/walker2d-training.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
train_procgen/checkpoints/sppo-fruitbot_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
train_procgen/checkpoints/sppo-jumper_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
train_procgen/checkpoints/sppo-ninja_easy_0_0_2021/checkpoints/25000000 filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
videos/fruitbot_skills/cluster_0.gif filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
videos/fruitbot_skills/cluster_1.gif filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
videos/fruitbot_skills/cluster_2.gif filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
videos/fruitbot_skills/cluster_3.gif filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
videos/fruitbot_skills/cluster_4.gif filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
videos/fruitbot_skills/cluster_5.gif filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
videos/fruitbot_skills/cluster_6.gif filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
videos/fruitbot_skills/cluster_7.gif filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
videos/jumper_skills/cluster_0.gif filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
videos/jumper_skills/cluster_1.gif filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
videos/jumper_skills/cluster_2.gif filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
videos/jumper_skills/cluster_3.gif filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
videos/jumper_skills/cluster_4.gif filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
videos/jumper_skills/cluster_5.gif filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
videos/jumper_skills/cluster_6.gif filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
videos/jumper_skills/cluster_7.gif filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
videos/ninja_skills/cluster_0.gif filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
videos/ninja_skills/cluster_1.gif filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
videos/ninja_skills/cluster_2.gif filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
videos/ninja_skills/cluster_3.gif filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
videos/ninja_skills/cluster_4.gif filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
videos/ninja_skills/cluster_5.gif filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
videos/ninja_skills/cluster_6.gif filter=lfs diff=lfs merge=lfs -text
|
| 70 |
+
videos/ninja_skills/cluster_7.gif filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
The MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2019 OpenAI (http://openai.com)
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in
|
| 13 |
+
all copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
| 21 |
+
THE SOFTWARE.
|
README.md
CHANGED
|
@@ -1,12 +1,194 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
- clustering
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Enhancing Interpretability in Deep Reinforcement Learning through Semantic Clustering
|
| 2 |
+
|
| 3 |
+
**Authors:** Liang Zhang, Justin Lieffers, Adarsh Pyarelal
|
| 4 |
+
**Conference:** NeurIPS 2025 Main Track
|
| 5 |
+
**Paper:** [arXiv:2409.17411](https://arxiv.org/abs/2409.17411)
|
| 6 |
+
|
| 7 |
+
This repository contains the official implementation of our research on enhancing interpretability in deep reinforcement learning through semantic clustering techniques. Our work extends the [OpenAI train-procgen](https://github.com/openai/train-procgen) framework to incorporate semantic clustering methods for improved understanding and visualization of learned policies in procedural environments.
|
| 8 |
+
|
| 9 |
+
## 📋 Abstract
|
| 10 |
+
|
| 11 |
+
This work presents a novel approach to enhancing interpretability in deep reinforcement learning by leveraging semantic clustering techniques. We demonstrate how semantic clustering can provide insights into learned policies, enabling better understanding of agent behavior and decision-making processes in complex procedural environments.
|
| 12 |
+
|
| 13 |
+
## 🚀 Quick Start
|
| 14 |
+
|
| 15 |
+
### Installation
|
| 16 |
+
|
| 17 |
+
Prerequisite: Python 3.8.
|
| 18 |
+
|
| 19 |
+
1. **Clone the repository:**
|
| 20 |
+
```bash
|
| 21 |
+
git clone https://github.com/ualiangzhang/semantic_rl.git
|
| 22 |
+
cd semantic_rl
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
2. **Install dependencies (Python 3.8):**
|
| 26 |
+
```bash
|
| 27 |
+
pip install -r requirements.txt
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
3. **Install Procgen environments:**
|
| 31 |
+
Follow the installation steps in the [Procgen repository](https://github.com/openai/procgen).
|
| 32 |
+
|
| 33 |
+
### Basic Usage
|
| 34 |
+
|
| 35 |
+
**Train a semantic clustering model:**
|
| 36 |
+
```bash
|
| 37 |
+
python -m train_procgen.train_sppo --env_name <ENV_NAME> --num_levels 0 --distribution_mode easy --timesteps_per_proc 25000000 --rand_seed <RAND_SEED>
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
**Train a baseline model:**
|
| 41 |
+
```bash
|
| 42 |
+
python -m train_procgen.train_ppo --env_name <ENV_NAME> --num_levels 0 --distribution_mode easy --timesteps_per_proc 25000000 --rand_seed <RAND_SEED>
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## 📊 Visualization and Analysis
|
| 46 |
+
|
| 47 |
+
### Performance Analysis
|
| 48 |
+
|
| 49 |
+
**Generate generalization figures for a single game:**
|
| 50 |
+
```bash
|
| 51 |
+
cd train_procgen
|
| 52 |
+
python single_graph.py --env_name <ENV_NAME>
|
| 53 |
+
# Example:
|
| 54 |
+
python single_graph.py --env_name coinrun
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
### Semantic Clustering Visualization
|
| 58 |
+
|
| 59 |
+
**Generate embedding space visualizations:**
|
| 60 |
+
```bash
|
| 61 |
+
python -m train_procgen.enjoy_sppo --env_name <ENV_NAME> --mode 1
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
**Generate skill demonstration videos:**
|
| 65 |
+
```bash
|
| 66 |
+
python -m train_procgen.enjoy_sppo --env_name <ENV_NAME> --mode 0
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
**Interactive cluster exploration:**
|
| 70 |
+
```bash
|
| 71 |
+
python -m train_procgen.hover_clusters --env_name <ENV_NAME>
|
| 72 |
+
# Example:
|
| 73 |
+
python -m train_procgen.hover_clusters --env_name fruitbot
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
## 🎮 Supported Environments
|
| 77 |
+
|
| 78 |
+
Our implementation supports four Procgen environments:
|
| 79 |
+
- **CoinRun**
|
| 80 |
+
- **FruitBot**
|
| 81 |
+
- **Jumper**
|
| 82 |
+
- **Ninja**
|
| 83 |
+
|
| 84 |
+
## 🎬 Semantic Clustering Demonstration
|
| 85 |
+
|
| 86 |
+
### Ninja Environment - 8 Semantic Clusters
|
| 87 |
+
|
| 88 |
+
The following videos demonstrate the 8 distinct semantic clusters learned by our model in the Ninja environment. Each cluster represents a different behavioral pattern and skill set:
|
| 89 |
+
|
| 90 |
+
#### 📹 Semantic Cluster Demonstrations
|
| 91 |
+
|
| 92 |
+
<table>
|
| 93 |
+
<tr>
|
| 94 |
+
<td align="center">
|
| 95 |
+
<strong>Cluster 0</strong><br/>
|
| 96 |
+
<img src="videos/ninja_skills/cluster_0.gif" width="180" height="180" />
|
| 97 |
+
</td>
|
| 98 |
+
<td align="center">
|
| 99 |
+
<strong>Cluster 1</strong><br/>
|
| 100 |
+
<img src="videos/ninja_skills/cluster_1.gif" width="180" height="180" />
|
| 101 |
+
</td>
|
| 102 |
+
<td align="center">
|
| 103 |
+
<strong>Cluster 2</strong><br/>
|
| 104 |
+
<img src="videos/ninja_skills/cluster_2.gif" width="180" height="180" />
|
| 105 |
+
</td>
|
| 106 |
+
<td align="center">
|
| 107 |
+
<strong>Cluster 3</strong><br/>
|
| 108 |
+
<img src="videos/ninja_skills/cluster_3.gif" width="180" height="180" />
|
| 109 |
+
</td>
|
| 110 |
+
</tr>
|
| 111 |
+
<tr>
|
| 112 |
+
<td align="center">
|
| 113 |
+
<strong>Cluster 4</strong><br/>
|
| 114 |
+
<img src="videos/ninja_skills/cluster_4.gif" width="180" height="180" />
|
| 115 |
+
</td>
|
| 116 |
+
<td align="center">
|
| 117 |
+
<strong>Cluster 5</strong><br/>
|
| 118 |
+
<img src="videos/ninja_skills/cluster_5.gif" width="180" height="180" />
|
| 119 |
+
</td>
|
| 120 |
+
<td align="center">
|
| 121 |
+
<strong>Cluster 6</strong><br/>
|
| 122 |
+
<img src="videos/ninja_skills/cluster_6.gif" width="180" height="180" />
|
| 123 |
+
</td>
|
| 124 |
+
<td align="center">
|
| 125 |
+
<strong>Cluster 7</strong><br/>
|
| 126 |
+
<img src="videos/ninja_skills/cluster_7.gif" width="180" height="180" />
|
| 127 |
+
</td>
|
| 128 |
+
</tr>
|
| 129 |
+
</table>
|
| 130 |
+
|
| 131 |
+
#### 🧭 Behavior Descriptions (Ninja)
|
| 132 |
+
|
| 133 |
+
| Cluster | Behavior |
|
| 134 |
+
|---------|----------|
|
| 135 |
+
| 0 | The agent starts by walking through the first platform and then performs a high jump to reach a higher ledge. |
|
| 136 |
+
| 1 | The agent makes small jumps in the middle of the scene. |
|
| 137 |
+
| 2 | Two interpretations are present: (1) the agent starts from the leftmost end of the scene and walks to the starting position of Cluster 0; (2) when there are no higher ledges to jump to, the agent begins from the scene, walks over the first platform, and prepares to jump to the subsequent ledge. |
|
| 138 |
+
| 3 | The agent walks on the ledge and prepares to jump to a higher ledge. |
|
| 139 |
+
| 4 | After performing a high jump, the agent loses sight of the ledge below. |
|
| 140 |
+
| 5 | The agent walks on the ledge and prepares to jump onto a ledge at the same height or lower. |
|
| 141 |
+
| 6 | The agent executes a high jump while keeping the ledge below in sight. |
|
| 142 |
+
| 7 | The agent moves towards the right edge of the scene and touches the mushroom. |
|
| 143 |
+
|
| 144 |
+
#### 📊 Alternative: Generate Your Own Videos
|
| 145 |
+
|
| 146 |
+
You can also generate these videos yourself using our code:
|
| 147 |
+
|
| 148 |
+
```bash
|
| 149 |
+
# Generate Ninja skill cluster videos
|
| 150 |
+
python -m train_procgen.enjoy_sppo --env_name ninja --mode 0 --num_embeddings 8
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
**Note:** These videos showcase the distinct behavioral patterns learned by our semantic clustering approach. Each cluster demonstrates different combat strategies, movement patterns, and decision-making processes in the Ninja environment.
|
| 154 |
+
|
| 155 |
+
## 📁 Output Structure
|
| 156 |
+
|
| 157 |
+
```
|
| 158 |
+
baseline/ # Required RL training package
|
| 159 |
+
train_procgen/
|
| 160 |
+
├── checkpoints/ # Trained model checkpoints
|
| 161 |
+
├── figures/ # Generated visualizations and videos
|
| 162 |
+
videos/ # video clips corresponding to the clusters in the paper
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
## 📈 Reproducing Results
|
| 166 |
+
|
| 167 |
+
To reproduce the results from our paper:
|
| 168 |
+
|
| 169 |
+
1. **(Optional) Use existing checkpoints**: We have provided pre-trained checkpoints for Ninja, FruitBot, and Jumper (random seed 2021) in this repository under `train_procgen/checkpoints/`. You can skip training and directly run the visualization scripts. Otherwise, **train models** using the commands above.
|
| 170 |
+
2. **Generate visualizations** using the provided scripts
|
| 171 |
+
3. **Analyze results** using the interactive tools
|
| 172 |
+
|
| 173 |
+
**Note:** Video generation may take 30-60 minutes depending on machine performance, as it ensures comprehensive exploration of all clusters.
|
| 174 |
+
|
| 175 |
+
<!-- ## 🤝 Citation
|
| 176 |
+
|
| 177 |
+
If you use this code in your research, please cite our paper: -->
|
| 178 |
+
|
| 179 |
+
<!-- ```bibtex
|
| 180 |
+
@article{zhang2025enhancing,
|
| 181 |
+
title={Enhancing Interpretability in Deep Reinforcement Learning through Semantic Clustering},
|
| 182 |
+
author={Zhang, Liang and Lieffers, Justin and Pyarelal, Adarsh},
|
| 183 |
+
journal={Advances in Neural Information Processing Systems},
|
| 184 |
+
year={2025}
|
| 185 |
+
}
|
| 186 |
+
``` -->
|
| 187 |
+
|
| 188 |
+
## 📄 License
|
| 189 |
+
|
| 190 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
| 191 |
+
|
| 192 |
+
## 🙏 Acknowledgments
|
| 193 |
+
|
| 194 |
+
This work builds upon the [OpenAI train-procgen](https://github.com/openai/train-procgen) framework. We thank the original authors for their excellent work on procedural generation for reinforcement learning benchmarking.
|
baselines/__init__.py
ADDED
|
File without changes
|
baselines/a2c/README.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# A2C
|
| 2 |
+
|
| 3 |
+
- Original paper: https://arxiv.org/abs/1602.01783
|
| 4 |
+
- Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/
|
| 5 |
+
- `python -m baselines.run --alg=a2c --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options
|
| 6 |
+
- also refer to the repo-wide [README.md](../../README.md#training-models)
|
| 7 |
+
|
| 8 |
+
## Files
|
| 9 |
+
- `run_atari`: file used to run the algorithm.
|
| 10 |
+
- `policies.py`: contains the different versions of the A2C architecture (MlpPolicy, CNNPolicy, LstmPolicy...).
|
| 11 |
+
- `a2c.py`: - Model : class used to initialize the step_model (sampling) and train_model (training)
|
| 12 |
+
- learn : Main entrypoint for A2C algorithm. Train a policy with given network architecture on a given environment using a2c algorithm.
|
| 13 |
+
- `runner.py`: class used to generate a batch of experiences
|
baselines/a2c/__init__.py
ADDED
|
File without changes
|
baselines/a2c/a2c.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import functools
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
|
| 5 |
+
from baselines import logger
|
| 6 |
+
|
| 7 |
+
from baselines.common import set_global_seeds, explained_variance
|
| 8 |
+
from baselines.common import tf_util
|
| 9 |
+
from baselines.common.policies import build_policy
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from baselines.a2c.utils import Scheduler, find_trainable_variables
|
| 13 |
+
from baselines.a2c.runner import Runner
|
| 14 |
+
from baselines.ppo2.ppo2 import safemean
|
| 15 |
+
from collections import deque
|
| 16 |
+
|
| 17 |
+
from tensorflow import losses
|
| 18 |
+
|
| 19 |
+
class Model(object):
|
| 20 |
+
|
| 21 |
+
"""
|
| 22 |
+
We use this class to :
|
| 23 |
+
__init__:
|
| 24 |
+
- Creates the step_model
|
| 25 |
+
- Creates the train_model
|
| 26 |
+
|
| 27 |
+
train():
|
| 28 |
+
- Make the training part (feedforward and retropropagation of gradients)
|
| 29 |
+
|
| 30 |
+
save/load():
|
| 31 |
+
- Save load the model
|
| 32 |
+
"""
|
| 33 |
+
def __init__(self, policy, env, nsteps,
|
| 34 |
+
ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
|
| 35 |
+
alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):
|
| 36 |
+
|
| 37 |
+
sess = tf_util.get_session()
|
| 38 |
+
nenvs = env.num_envs
|
| 39 |
+
nbatch = nenvs*nsteps
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
with tf.compat.v1.variable_scope('a2c_model', reuse=tf.compat.v1.AUTO_REUSE):
|
| 43 |
+
# step_model is used for sampling
|
| 44 |
+
step_model = policy(nenvs, 1, sess)
|
| 45 |
+
|
| 46 |
+
# train_model is used to train our network
|
| 47 |
+
train_model = policy(nbatch, nsteps, sess)
|
| 48 |
+
|
| 49 |
+
A = tf.compat.v1.placeholder(train_model.action.dtype, train_model.action.shape)
|
| 50 |
+
ADV = tf.compat.v1.placeholder(tf.float32, [nbatch])
|
| 51 |
+
R = tf.compat.v1.placeholder(tf.float32, [nbatch])
|
| 52 |
+
LR = tf.compat.v1.placeholder(tf.float32, [])
|
| 53 |
+
|
| 54 |
+
# Calculate the loss
|
| 55 |
+
# Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss
|
| 56 |
+
|
| 57 |
+
# Policy loss
|
| 58 |
+
neglogpac = train_model.pd.neglogp(A)
|
| 59 |
+
# L = A(s,a) * -logpi(a|s)
|
| 60 |
+
pg_loss = tf.reduce_mean(input_tensor=ADV * neglogpac)
|
| 61 |
+
|
| 62 |
+
# Entropy is used to improve exploration by limiting the premature convergence to suboptimal policy.
|
| 63 |
+
entropy = tf.reduce_mean(input_tensor=train_model.pd.entropy())
|
| 64 |
+
|
| 65 |
+
# Value loss
|
| 66 |
+
vf_loss = losses.mean_squared_error(tf.squeeze(train_model.vf), R)
|
| 67 |
+
|
| 68 |
+
loss = pg_loss - entropy*ent_coef + vf_loss * vf_coef
|
| 69 |
+
|
| 70 |
+
# Update parameters using loss
|
| 71 |
+
# 1. Get the model parameters
|
| 72 |
+
params = find_trainable_variables("a2c_model")
|
| 73 |
+
|
| 74 |
+
# 2. Calculate the gradients
|
| 75 |
+
grads = tf.gradients(ys=loss, xs=params)
|
| 76 |
+
if max_grad_norm is not None:
|
| 77 |
+
# Clip the gradients (normalize)
|
| 78 |
+
grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
|
| 79 |
+
grads = list(zip(grads, params))
|
| 80 |
+
# zip aggregate each gradient with parameters associated
|
| 81 |
+
# For instance zip(ABCD, xyza) => Ax, By, Cz, Da
|
| 82 |
+
|
| 83 |
+
# 3. Make op for one policy and value update step of A2C
|
| 84 |
+
trainer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
|
| 85 |
+
|
| 86 |
+
_train = trainer.apply_gradients(grads)
|
| 87 |
+
|
| 88 |
+
lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)
|
| 89 |
+
|
| 90 |
+
def train(obs, states, rewards, masks, actions, values):
|
| 91 |
+
# Here we calculate advantage A(s,a) = R + yV(s') - V(s)
|
| 92 |
+
# rewards = R + yV(s')
|
| 93 |
+
advs = rewards - values
|
| 94 |
+
for step in range(len(obs)):
|
| 95 |
+
cur_lr = lr.value()
|
| 96 |
+
|
| 97 |
+
td_map = {train_model.X:obs, A:actions, ADV:advs, R:rewards, LR:cur_lr}
|
| 98 |
+
if states is not None:
|
| 99 |
+
td_map[train_model.S] = states
|
| 100 |
+
td_map[train_model.M] = masks
|
| 101 |
+
policy_loss, value_loss, policy_entropy, _ = sess.run(
|
| 102 |
+
[pg_loss, vf_loss, entropy, _train],
|
| 103 |
+
td_map
|
| 104 |
+
)
|
| 105 |
+
return policy_loss, value_loss, policy_entropy
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
self.train = train
|
| 109 |
+
self.train_model = train_model
|
| 110 |
+
self.step_model = step_model
|
| 111 |
+
self.step = step_model.step
|
| 112 |
+
self.value = step_model.value
|
| 113 |
+
self.initial_state = step_model.initial_state
|
| 114 |
+
self.save = functools.partial(tf_util.save_variables, sess=sess)
|
| 115 |
+
self.load = functools.partial(tf_util.load_variables, sess=sess)
|
| 116 |
+
tf.compat.v1.global_variables_initializer().run(session=sess)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def learn(
|
| 120 |
+
network,
|
| 121 |
+
env,
|
| 122 |
+
seed=None,
|
| 123 |
+
nsteps=5,
|
| 124 |
+
total_timesteps=int(80e6),
|
| 125 |
+
vf_coef=0.5,
|
| 126 |
+
ent_coef=0.01,
|
| 127 |
+
max_grad_norm=0.5,
|
| 128 |
+
lr=7e-4,
|
| 129 |
+
lrschedule='linear',
|
| 130 |
+
epsilon=1e-5,
|
| 131 |
+
alpha=0.99,
|
| 132 |
+
gamma=0.99,
|
| 133 |
+
log_interval=100,
|
| 134 |
+
load_path=None,
|
| 135 |
+
**network_kwargs):
|
| 136 |
+
|
| 137 |
+
'''
|
| 138 |
+
Main entrypoint for A2C algorithm. Train a policy with given network architecture on a given environment using a2c algorithm.
|
| 139 |
+
|
| 140 |
+
Parameters:
|
| 141 |
+
-----------
|
| 142 |
+
|
| 143 |
+
network: policy network architecture. Either string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for full list)
|
| 144 |
+
specifying the standard network architecture, or a function that takes tensorflow tensor as input and returns
|
| 145 |
+
tuple (output_tensor, extra_feed) where output tensor is the last network layer output, extra_feed is None for feed-forward
|
| 146 |
+
neural nets, and extra_feed is a dictionary describing how to feed state into the network for recurrent neural nets.
|
| 147 |
+
See baselines.common/policies.py/lstm for more details on using recurrent nets in policies
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
env: RL environment. Should implement interface similar to VecEnv (baselines.common/vec_env) or be wrapped with DummyVecEnv (baselines.common/vec_env/dummy_vec_env.py)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
seed: seed to make random number sequence in the algorithm reproducible. By default is None which means seed from system noise generator (not reproducible)
|
| 154 |
+
|
| 155 |
+
nsteps: int, number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
|
| 156 |
+
nenv is number of environment copies simulated in parallel)
|
| 157 |
+
|
| 158 |
+
total_timesteps: int, total number of timesteps to train on (default: 80M)
|
| 159 |
+
|
| 160 |
+
vf_coef: float, coefficient in front of value function loss in the total loss function (default: 0.5)
|
| 161 |
+
|
| 162 |
+
ent_coef: float, coefficient in front of the policy entropy in the total loss function (default: 0.01)
|
| 163 |
+
|
| 164 |
+
max_gradient_norm: float, gradient is clipped to have global L2 norm no more than this value (default: 0.5)
|
| 165 |
+
|
| 166 |
+
lr: float, learning rate for RMSProp (current implementation has RMSProp hardcoded in) (default: 7e-4)
|
| 167 |
+
|
| 168 |
+
lrschedule: schedule of learning rate. Can be 'linear', 'constant', or a function [0..1] -> [0..1] that takes fraction of the training progress as input and
|
| 169 |
+
returns fraction of the learning rate (specified as lr) as output
|
| 170 |
+
|
| 171 |
+
epsilon: float, RMSProp epsilon (stabilizes square root computation in denominator of RMSProp update) (default: 1e-5)
|
| 172 |
+
|
| 173 |
+
alpha: float, RMSProp decay parameter (default: 0.99)
|
| 174 |
+
|
| 175 |
+
gamma: float, reward discounting parameter (default: 0.99)
|
| 176 |
+
|
| 177 |
+
log_interval: int, specifies how frequently the logs are printed out (default: 100)
|
| 178 |
+
|
| 179 |
+
**network_kwargs: keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
|
| 180 |
+
For instance, 'mlp' network architecture has arguments num_hidden and num_layers.
|
| 181 |
+
|
| 182 |
+
'''
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
set_global_seeds(seed)
|
| 187 |
+
|
| 188 |
+
# Get the nb of env
|
| 189 |
+
nenvs = env.num_envs
|
| 190 |
+
policy = build_policy(env, network, **network_kwargs)
|
| 191 |
+
|
| 192 |
+
# Instantiate the model object (that creates step_model and train_model)
|
| 193 |
+
model = Model(policy=policy, env=env, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
|
| 194 |
+
max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
|
| 195 |
+
if load_path is not None:
|
| 196 |
+
model.load(load_path)
|
| 197 |
+
|
| 198 |
+
# Instantiate the runner object
|
| 199 |
+
runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
|
| 200 |
+
epinfobuf = deque(maxlen=100)
|
| 201 |
+
|
| 202 |
+
# Calculate the batch_size
|
| 203 |
+
nbatch = nenvs*nsteps
|
| 204 |
+
|
| 205 |
+
# Start total timer
|
| 206 |
+
tstart = time.time()
|
| 207 |
+
|
| 208 |
+
for update in range(1, total_timesteps//nbatch+1):
|
| 209 |
+
# Get mini batch of experiences
|
| 210 |
+
obs, states, rewards, masks, actions, values, epinfos = runner.run()
|
| 211 |
+
epinfobuf.extend(epinfos)
|
| 212 |
+
|
| 213 |
+
policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
|
| 214 |
+
nseconds = time.time()-tstart
|
| 215 |
+
|
| 216 |
+
# Calculate the fps (frame per second)
|
| 217 |
+
fps = int((update*nbatch)/nseconds)
|
| 218 |
+
if update % log_interval == 0 or update == 1:
|
| 219 |
+
# Calculates if value function is a good predictor of the returns (ev close to 1)
|
| 220 |
+
# or if it's just worse than predicting nothing (ev =< 0)
|
| 221 |
+
ev = explained_variance(values, rewards)
|
| 222 |
+
logger.record_tabular("nupdates", update)
|
| 223 |
+
logger.record_tabular("total_timesteps", update*nbatch)
|
| 224 |
+
logger.record_tabular("fps", fps)
|
| 225 |
+
logger.record_tabular("policy_entropy", float(policy_entropy))
|
| 226 |
+
logger.record_tabular("value_loss", float(value_loss))
|
| 227 |
+
logger.record_tabular("explained_variance", float(ev))
|
| 228 |
+
logger.record_tabular("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
|
| 229 |
+
logger.record_tabular("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
|
| 230 |
+
logger.dump_tabular()
|
| 231 |
+
return model
|
| 232 |
+
|
baselines/a2c/runner.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from baselines.a2c.utils import discount_with_dones
|
| 3 |
+
from baselines.common.runners import AbstractEnvRunner
|
| 4 |
+
|
| 5 |
+
class Runner(AbstractEnvRunner):
    """
    We use this class to generate batches of experiences

    __init__:
    - Initialize the runner

    run():
    - Make a mini batch of experiences
    """
    def __init__(self, env, model, nsteps=5, gamma=0.99):
        super().__init__(env=env, model=model, nsteps=nsteps)
        self.gamma = gamma
        # Unknown (None) dims become -1 so np.reshape can infer them later.
        self.batch_action_shape = [x if x is not None else -1 for x in model.train_model.action.shape.as_list()]
        self.ob_dtype = model.train_model.X.dtype.as_numpy_dtype

    def run(self):
        """Roll out nsteps in every env and return one flattened minibatch.

        Returns (obs, states, discounted_rewards, masks, actions, values, epinfos).
        """
        # We initialize the lists that will contain the mb of experiences
        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [], [], [], [], []
        mb_states = self.states
        epinfos = []
        for n in range(self.nsteps):
            # Given observations, take action and value (V(s))
            # We already have self.obs because Runner superclass run self.obs[:] = env.reset() on init
            actions, values, states, _ = self.model.step(self.obs, S=self.states, M=self.dones)

            # Append the experiences
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
            mb_values.append(values)
            mb_dones.append(self.dones)

            # Take actions in env and look the results
            obs, rewards, dones, infos = self.env.step(actions)
            for info in infos:
                maybeepinfo = info.get('episode')
                if maybeepinfo: epinfos.append(maybeepinfo)
            self.states = states
            self.dones = dones
            self.obs = obs
            mb_rewards.append(rewards)
        mb_dones.append(self.dones)

        # Batch of steps to batch of rollouts
        mb_obs = np.asarray(mb_obs, dtype=self.ob_dtype).swapaxes(1, 0).reshape(self.batch_ob_shape)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
        mb_actions = np.asarray(mb_actions, dtype=self.model.train_model.action.dtype.name).swapaxes(1, 0)
        mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(1, 0)
        # FIX: np.bool was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin `bool` is the documented replacement.
        mb_dones = np.asarray(mb_dones, dtype=bool).swapaxes(1, 0)
        mb_masks = mb_dones[:, :-1]  # done flags *before* each step (state-reset masks)
        mb_dones = mb_dones[:, 1:]   # done flags *after* each step


        if self.gamma > 0.0:
            # Discount/bootstrap off value fn
            last_values = self.model.value(self.obs, S=self.states, M=self.dones).tolist()
            for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
                rewards = rewards.tolist()
                dones = dones.tolist()
                if dones[-1] == 0:
                    # Episode not finished: bootstrap the tail with V(s_{t+nsteps}).
                    rewards = discount_with_dones(rewards + [value], dones + [0], self.gamma)[:-1]
                else:
                    rewards = discount_with_dones(rewards, dones, self.gamma)

                mb_rewards[n] = rewards

        mb_actions = mb_actions.reshape(self.batch_action_shape)

        mb_rewards = mb_rewards.flatten()
        mb_values = mb_values.flatten()
        mb_masks = mb_masks.flatten()
        return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values, epinfos
|
baselines/a2c/utils.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
from collections import deque
|
| 5 |
+
|
| 6 |
+
def sample(logits):
    # Draw one sample per row from the categorical distribution defined by
    # `logits`, using the Gumbel-max trick: argmax(logits + Gumbel noise),
    # where -log(-log(U)) with U ~ Uniform(0,1) is Gumbel-distributed.
    noise = tf.random.uniform(tf.shape(input=logits))
    return tf.argmax(input=logits - tf.math.log(-tf.math.log(noise)), axis=1)
|
| 9 |
+
|
| 10 |
+
def cat_entropy(logits):
    # Entropy of a categorical distribution given unnormalized logits.
    # Subtracting the row max before exp is the standard numerical-stability
    # trick; the returned value is sum_i p_i * (log Z - a_i) = -sum p log p.
    a0 = logits - tf.reduce_max(input_tensor=logits, axis=1, keepdims=True)
    ea0 = tf.exp(a0)
    z0 = tf.reduce_sum(input_tensor=ea0, axis=1, keepdims=True)
    p0 = ea0 / z0
    return tf.reduce_sum(input_tensor=p0 * (tf.math.log(z0) - a0), axis=1)
|
| 16 |
+
|
| 17 |
+
def cat_entropy_softmax(p0):
    # Entropy of an already-normalized probability tensor p0 (rows sum to 1);
    # the 1e-6 guards against log(0).
    return - tf.reduce_sum(input_tensor=p0 * tf.math.log(p0 + 1e-6), axis = 1)
|
| 19 |
+
|
| 20 |
+
def ortho_init(scale=1.0):
    """Return an initializer that produces (scaled) orthogonal weights.

    2-D shapes are orthogonalized directly; 4-D conv kernels (assumed NHWC)
    are flattened to (prod(kernel dims), out_channels) first. Based on the
    lasagne orthogonal initializer.
    """
    def _ortho_init(shape, dtype, partition_info=None):
        shape = tuple(shape)
        if len(shape) == 2:
            flat_shape = shape
        elif len(shape) == 4:  # assumes NHWC
            flat_shape = (np.prod(shape[:-1]), shape[-1])
        else:
            raise NotImplementedError
        gaussian = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(gaussian, full_matrices=False)
        # With full_matrices=False only one of u/v has flat_shape; use it.
        basis = u if u.shape == flat_shape else v
        basis = basis.reshape(shape)
        return (scale * basis[:shape[0], :shape[1]]).astype(np.float32)
    return _ortho_init
|
| 36 |
+
|
| 37 |
+
def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC', one_dim_bias=False):
    """2-D convolution layer with orthogonal weight init.

    :param x: input tensor (NHWC or NCHW per `data_format`)
    :param scope: variable scope name for the created w/b variables
    :param nf: number of output filters
    :param rf: square receptive-field (kernel) size
    :param stride: spatial stride
    :return: pre-activation conv output with bias added
    """
    if data_format == 'NHWC':
        channel_ax = 3
        strides = [1, stride, stride, 1]
        bshape = [1, 1, 1, nf]
    elif data_format == 'NCHW':
        channel_ax = 1
        strides = [1, 1, stride, stride]
        bshape = [1, nf, 1, 1]
    else:
        raise NotImplementedError
    bias_var_shape = [nf] if one_dim_bias else [1, nf, 1, 1]
    # TF1 static dims are Dimension objects (.value); TF2 returns plain ints.
    # FIX: narrowed the bare `except:` (which also swallowed KeyboardInterrupt
    # and real bugs) to the AttributeError this fallback is meant to handle.
    try:
        nin = x.get_shape()[channel_ax].value
    except AttributeError:
        nin = x.get_shape()[channel_ax]
    wshape = [rf, rf, nin, nf]
    with tf.compat.v1.variable_scope(scope):
        w = tf.compat.v1.get_variable("w", wshape, initializer=ortho_init(init_scale))
        b = tf.compat.v1.get_variable("b", bias_var_shape, initializer=tf.compat.v1.constant_initializer(0.0))
        if not one_dim_bias and data_format == 'NHWC':
            b = tf.reshape(b, bshape)
        return tf.nn.conv2d(input=x, filters=w, strides=strides, padding=pad, data_format=data_format) + b
|
| 60 |
+
|
| 61 |
+
def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):
    """Fully-connected layer: returns x @ w + b, variables created under `scope`."""
    with tf.compat.v1.variable_scope(scope):
        # TF1 static dims are Dimension objects (.value); TF2 returns plain ints.
        # FIX: narrowed the bare `except:` to the AttributeError it handles.
        try:
            nin = x.get_shape()[1].value
        except AttributeError:
            nin = x.get_shape()[1]
        w = tf.compat.v1.get_variable("w", [nin, nh], initializer=ortho_init(init_scale))
        b = tf.compat.v1.get_variable("b", [nh], initializer=tf.compat.v1.constant_initializer(init_bias))
        return tf.matmul(x, w)+b
|
| 70 |
+
|
| 71 |
+
def batch_to_seq(h, nbatch, nsteps, flat=False):
    # Split a flat batch of shape [nbatch*nsteps, ...] into a list of nsteps
    # tensors of shape [nbatch, ...] (one per timestep), e.g. for stepping an
    # RNN. With flat=True the per-step tensors are 1-D [nbatch] vectors.
    if flat:
        h = tf.reshape(h, [nbatch, nsteps])
    else:
        h = tf.reshape(h, [nbatch, nsteps, -1])
    return [tf.squeeze(v, [1]) for v in tf.split(axis=1, num_or_size_splits=nsteps, value=h)]
|
| 77 |
+
|
| 78 |
+
def seq_to_batch(h, flat = False):
    """Inverse of batch_to_seq: merge a list of per-timestep tensors back
    into one flat batch of shape [nbatch*nsteps, ...]."""
    shape = h[0].get_shape().as_list()
    if not flat:
        assert(len(shape) > 1)
        # TF1 static dims are Dimension objects (.value); TF2 returns plain
        # ints. FIX: guard added for consistency with conv/fc/conv_to_fc,
        # which already handle both cases.
        try:
            nh = h[0].get_shape()[-1].value
        except AttributeError:
            nh = h[0].get_shape()[-1]
        return tf.reshape(tf.concat(axis=1, values=h), [-1, nh])
    else:
        return tf.reshape(tf.stack(values=h, axis=1), [-1])
|
| 86 |
+
|
| 87 |
+
def lstm(xs, ms, s, scope, nh, init_scale=1.0):
    """Standard LSTM unrolled over a list of per-timestep inputs.

    :param xs: list (len nsteps) of [nbatch, nin] inputs; overwritten in place
               with the per-step hidden states and also returned
    :param ms: list of per-step done masks; a 1 resets the recurrent state
    :param s: [nbatch, 2*nh] packed (cell, hidden) state
    :param nh: number of hidden units
    :return: (outputs, final packed state)
    """
    # TF1 static shapes yield Dimension objects (.value); TF2 yields ints.
    # FIX: handle both, consistent with the try/except used elsewhere in file.
    try:
        nbatch, nin = [v.value for v in xs[0].get_shape()]
    except AttributeError:
        nbatch, nin = xs[0].get_shape()
    with tf.compat.v1.variable_scope(scope):
        wx = tf.compat.v1.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
        wh = tf.compat.v1.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
        b = tf.compat.v1.get_variable("b", [nh*4], initializer=tf.compat.v1.constant_initializer(0.0))

    c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
    for idx, (x, m) in enumerate(zip(xs, ms)):
        # Zero the recurrent state where the previous step ended an episode.
        c = c*(1-m)
        h = h*(1-m)
        z = tf.matmul(x, wx) + tf.matmul(h, wh) + b
        i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
        i = tf.nn.sigmoid(i)
        f = tf.nn.sigmoid(f)
        o = tf.nn.sigmoid(o)
        u = tf.tanh(u)
        c = f*c + i*u
        h = o*tf.tanh(c)
        xs[idx] = h
    s = tf.concat(axis=1, values=[c, h])
    return xs, s
|
| 109 |
+
|
| 110 |
+
def _ln(x, g, b, e=1e-5, axes=[1]):
    # Layer normalization: standardize x over `axes` (epsilon e for stability),
    # then apply learned gain g and bias b.
    # NOTE(review): mutable default `axes=[1]` is only read, never mutated.
    u, s = tf.nn.moments(x=x, axes=axes, keepdims=True)
    x = (x-u)/tf.sqrt(s+e)
    x = x*g+b
    return x
|
| 115 |
+
|
| 116 |
+
def lnlstm(xs, ms, s, scope, nh, init_scale=1.0):
    """Layer-normalized LSTM (same argument convention as `lstm`).

    Applies layer norm (`_ln`) to the input and recurrent projections and to
    the cell state before the output nonlinearity.
    """
    # TF1 static shapes yield Dimension objects (.value); TF2 yields ints.
    # FIX: handle both, consistent with the try/except used elsewhere in file.
    try:
        nbatch, nin = [v.value for v in xs[0].get_shape()]
    except AttributeError:
        nbatch, nin = xs[0].get_shape()
    with tf.compat.v1.variable_scope(scope):
        wx = tf.compat.v1.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
        gx = tf.compat.v1.get_variable("gx", [nh*4], initializer=tf.compat.v1.constant_initializer(1.0))
        bx = tf.compat.v1.get_variable("bx", [nh*4], initializer=tf.compat.v1.constant_initializer(0.0))

        wh = tf.compat.v1.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
        gh = tf.compat.v1.get_variable("gh", [nh*4], initializer=tf.compat.v1.constant_initializer(1.0))
        bh = tf.compat.v1.get_variable("bh", [nh*4], initializer=tf.compat.v1.constant_initializer(0.0))

        b = tf.compat.v1.get_variable("b", [nh*4], initializer=tf.compat.v1.constant_initializer(0.0))

        gc = tf.compat.v1.get_variable("gc", [nh], initializer=tf.compat.v1.constant_initializer(1.0))
        bc = tf.compat.v1.get_variable("bc", [nh], initializer=tf.compat.v1.constant_initializer(0.0))

    c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
    for idx, (x, m) in enumerate(zip(xs, ms)):
        # Zero the recurrent state where the previous step ended an episode.
        c = c*(1-m)
        h = h*(1-m)
        z = _ln(tf.matmul(x, wx), gx, bx) + _ln(tf.matmul(h, wh), gh, bh) + b
        i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
        i = tf.nn.sigmoid(i)
        f = tf.nn.sigmoid(f)
        o = tf.nn.sigmoid(o)
        u = tf.tanh(u)
        c = f*c + i*u
        h = o*tf.tanh(_ln(c, gc, bc))
        xs[idx] = h
    s = tf.concat(axis=1, values=[c, h])
    return xs, s
|
| 147 |
+
|
| 148 |
+
def conv_to_fc(x):
    """Flatten a conv feature map [batch, ...] to [batch, features]."""
    # TF1 static dims are Dimension objects (.value); TF2 returns plain ints.
    # FIX: narrowed the bare `except:` to the AttributeError it handles.
    try:
        nh = np.prod([v.value for v in x.get_shape()[1:]])
    except AttributeError:
        nh = np.prod([v for v in x.get_shape()[1:]])
    x = tf.reshape(x, [-1, nh])
    return x
|
| 155 |
+
|
| 156 |
+
def discount_with_dones(rewards, dones, gamma):
    """Compute discounted returns, resetting the accumulator at episode ends.

    Walks the trajectory backwards; a done flag of 1 zeroes the carried
    return so episodes do not bleed into each other.
    """
    out = []
    running = 0
    for reward, done in zip(reversed(rewards), reversed(dones)):
        running = reward + gamma * running * (1. - done)
        out.append(running)
    out.reverse()
    return out
|
| 163 |
+
|
| 164 |
+
def find_trainable_variables(key):
    # Return the trainable variables whose names fall under the scope `key`.
    return tf.compat.v1.trainable_variables(key)
|
| 166 |
+
|
| 167 |
+
def make_path(f):
    """Ensure directory `f` exists, creating parent directories as needed."""
    os.makedirs(f, exist_ok=True)
|
| 169 |
+
|
| 170 |
+
def constant(p):
    # Constant schedule: multiplier is 1 regardless of training progress p.
    return 1
|
| 172 |
+
|
| 173 |
+
def linear(p):
    """Linearly decaying multiplier: 1 at p=0 down to 0 at p=1."""
    remaining = 1 - p
    return remaining
|
| 175 |
+
|
| 176 |
+
def middle_drop(p):
    """Linear decay that drops to a small constant for the last 75% of training."""
    eps = 0.75
    remaining = 1 - p
    return eps * 0.1 if remaining < eps else remaining
|
| 181 |
+
|
| 182 |
+
def double_linear_con(p):
    """Linear decay at twice the normal rate, then held constant at eps."""
    remaining = 1 - p * 2
    eps = 0.125
    return eps if remaining < eps else remaining
|
| 188 |
+
|
| 189 |
+
def double_middle_drop(p):
    """Linear decay with two successive drops to constant plateaus."""
    remaining = 1 - p
    if remaining >= 0.75:
        return remaining
    if remaining >= 0.25:
        return 0.75 * 0.1   # first plateau
    return 0.25 * 0.5       # second, lower plateau
|
| 197 |
+
|
| 198 |
+
# Mapping from schedule name (e.g. the `lrschedule` argument) to the
# corresponding multiplier function of training progress p in [0, 1].
schedules = {
    'linear':linear,
    'constant':constant,
    'double_linear_con': double_linear_con,
    'middle_drop': middle_drop,
    'double_middle_drop': double_middle_drop
}
|
| 205 |
+
|
| 206 |
+
class Scheduler(object):
    """Anneal a base value `v` over `nvalues` steps using a named schedule."""

    def __init__(self, v, nvalues, schedule):
        self.n = 0.
        self.v = v
        self.nvalues = nvalues
        self.schedule = schedules[schedule]

    def value(self):
        """Return the current value and advance the internal step counter."""
        out = self.v * self.schedule(self.n / self.nvalues)
        self.n += 1.
        return out

    def value_steps(self, steps):
        """Return the value at an externally tracked step count (stateless)."""
        return self.v * self.schedule(steps / self.nvalues)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class EpisodeStats:
    """Accumulate per-env rewards and keep rolling stats over finished episodes."""

    def __init__(self, nsteps, nenvs):
        # One running reward list per environment.
        self.episode_rewards = [[] for _ in range(nenvs)]
        self.lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
        self.rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
        self.nsteps = nsteps
        self.nenvs = nenvs

    def feed(self, rewards, masks):
        """Ingest one batch of rewards/done-masks shaped [nenvs*nsteps]."""
        rewards = np.reshape(rewards, [self.nenvs, self.nsteps])
        masks = np.reshape(masks, [self.nenvs, self.nsteps])
        for env_idx in range(self.nenvs):
            for step_idx in range(self.nsteps):
                self.episode_rewards[env_idx].append(rewards[env_idx][step_idx])
                if masks[env_idx][step_idx]:
                    # Episode finished: record its length/return, start fresh.
                    self.lenbuffer.append(len(self.episode_rewards[env_idx]))
                    self.rewbuffer.append(sum(self.episode_rewards[env_idx]))
                    self.episode_rewards[env_idx] = []

    def mean_length(self):
        # 0 before any episode has finished (e.g. first params dump).
        return np.mean(self.lenbuffer) if self.lenbuffer else 0

    def mean_reward(self):
        return np.mean(self.rewbuffer) if self.rewbuffer else 0
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# For ACER
|
| 260 |
+
def get_by_index(x, idx):
    # Gather x[i, idx[i]] for every row i: flatten x and offset each per-row
    # index by row * ncols.
    assert(len(x.get_shape()) == 2)
    assert(len(idx.get_shape()) == 1)
    idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
    y = tf.gather(tf.reshape(x, [-1]),  # flatten input
                  idx_flattened)  # use flattened indices
    return y
|
| 267 |
+
|
| 268 |
+
def check_shape(ts, shapes):
    """Assert that each tensor in `ts` has exactly the static shape listed in
    `shapes`; the assertion message names the offending position.

    FIX (idiom): replaced the manually incremented counter with enumerate.
    """
    for i, (t, shape) in enumerate(zip(ts, shapes)):
        assert t.get_shape().as_list() == shape, "id " + str(i) + " shape " + str(t.get_shape()) + str(shape)
|
| 273 |
+
|
| 274 |
+
def avg_norm(t):
    # Mean (over leading dims) of the L2 norm taken along the last axis of t.
    return tf.reduce_mean(input_tensor=tf.sqrt(tf.reduce_sum(input_tensor=tf.square(t), axis=-1)))
|
| 276 |
+
|
| 277 |
+
def gradient_add(g1, g2, param):
    """Sum two (possibly None) gradients for the same parameter.

    At most one of g1/g2 may be None (treated as zero); both None is an
    error, since `param` would then receive no gradient at all.
    """
    print([g1, g2, param.name])  # debug trace retained from original
    assert (not (g1 is None and g2 is None)), param.name
    if g2 is None:
        return g1
    if g1 is None:
        return g2
    return g1 + g2
|
| 286 |
+
|
| 287 |
+
def q_explained_variance(qpred, q):
    # Explained variance of the Q predictions: 1 - Var[q - qpred] / Var[q].
    # 1 means perfect prediction; <= 0 means no better than a constant.
    _, vary = tf.nn.moments(x=q, axes=[0, 1])
    _, varpred = tf.nn.moments(x=q - qpred, axes=[0, 1])
    check_shape([vary, varpred], [[]] * 2)
    return 1.0 - (varpred / vary)
|
baselines/acer/README.md
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ACER
|
| 2 |
+
|
| 3 |
+
- Original paper: https://arxiv.org/abs/1611.01224
|
| 4 |
+
- `python -m baselines.run --alg=acer --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options.
|
| 5 |
+
- also refer to the repo-wide [README.md](../../README.md#training-models)
|
| 6 |
+
|
baselines/acer/__init__.py
ADDED
|
File without changes
|
baselines/acer/acer.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import functools
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
from baselines import logger
|
| 6 |
+
|
| 7 |
+
from baselines.common import set_global_seeds
|
| 8 |
+
from baselines.common.policies import build_policy
|
| 9 |
+
from baselines.common.tf_util import get_session, save_variables, load_variables
|
| 10 |
+
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
| 11 |
+
|
| 12 |
+
from baselines.a2c.utils import batch_to_seq, seq_to_batch
|
| 13 |
+
from baselines.a2c.utils import cat_entropy_softmax
|
| 14 |
+
from baselines.a2c.utils import Scheduler, find_trainable_variables
|
| 15 |
+
from baselines.a2c.utils import EpisodeStats
|
| 16 |
+
from baselines.a2c.utils import get_by_index, check_shape, avg_norm, gradient_add, q_explained_variance
|
| 17 |
+
from baselines.acer.buffer import Buffer
|
| 18 |
+
from baselines.acer.runner import Runner
|
| 19 |
+
|
| 20 |
+
# remove last step
|
| 21 |
+
def strip(var, nenvs, nsteps, flat = False):
    # Drop the last timestep from a [nenvs*(nsteps+1), ...] batch, returning
    # the corresponding [nenvs*nsteps, ...] batch.
    vars = batch_to_seq(var, nenvs, nsteps + 1, flat)
    return seq_to_batch(vars[:-1], flat)
|
| 24 |
+
|
| 25 |
+
def q_retrace(R, D, q_i, v, rho_i, nenvs, nsteps, gamma):
    """
    Calculates q_retrace targets

    :param R: Rewards
    :param D: Dones
    :param q_i: Q values for actions taken
    :param v: V values
    :param rho_i: Importance weight for each action
    :param nenvs: number of parallel environments
    :param nsteps: rollout length per environment
    :param gamma: discount factor
    :return: Q_retrace values
    """
    rho_bar = batch_to_seq(tf.minimum(1.0, rho_i), nenvs, nsteps, True)  # list of len steps, shape [nenvs]
    rs = batch_to_seq(R, nenvs, nsteps, True)  # list of len steps, shape [nenvs]
    ds = batch_to_seq(D, nenvs, nsteps, True)  # list of len steps, shape [nenvs]
    q_is = batch_to_seq(q_i, nenvs, nsteps, True)
    vs = batch_to_seq(v, nenvs, nsteps + 1, True)
    v_final = vs[-1]
    qret = v_final
    qrets = []
    # Backward recursion through the rollout; dones zero the bootstrap term,
    # and the truncated importance weight rho_bar corrects off-policy data.
    for i in range(nsteps - 1, -1, -1):
        check_shape([qret, ds[i], rs[i], rho_bar[i], q_is[i], vs[i]], [[nenvs]] * 6)
        qret = rs[i] + gamma * qret * (1.0 - ds[i])
        qrets.append(qret)
        qret = (rho_bar[i] * (qret - q_is[i])) + vs[i]
    qrets = qrets[::-1]
    qret = seq_to_batch(qrets, flat=True)
    return qret
|
| 52 |
+
|
| 53 |
+
# For ACER with PPO clipping instead of trust region
|
| 54 |
+
# def clip(ratio, eps_clip):
|
| 55 |
+
# # assume 0 <= eps_clip <= 1
|
| 56 |
+
# return tf.minimum(1 + eps_clip, tf.maximum(1 - eps_clip, ratio))
|
| 57 |
+
|
| 58 |
+
class Model(object):
    """Builds the ACER computation graph (TF1-style, under tf.compat.v1).

    Creates three policy instances sharing weights in scope 'acer_model':
    `step_model` (single step, for acting), `train_model` (nsteps+1 obs per
    env, for training) and `polyak_model` (EMA-averaged weights, used for
    the trust-region correction). Exposes `train`, `step`, `save`, `load`.
    """
    def __init__(self, policy, ob_space, ac_space, nenvs, nsteps, ent_coef, q_coef, gamma, max_grad_norm, lr,
                 rprop_alpha, rprop_epsilon, total_timesteps, lrschedule,
                 c, trust_region, alpha, delta):

        sess = get_session()
        nact = ac_space.n
        nbatch = nenvs * nsteps

        A = tf.compat.v1.placeholder(tf.int32, [nbatch]) # actions
        D = tf.compat.v1.placeholder(tf.float32, [nbatch]) # dones
        R = tf.compat.v1.placeholder(tf.float32, [nbatch]) # rewards, not returns
        MU = tf.compat.v1.placeholder(tf.float32, [nbatch, nact]) # mu's
        LR = tf.compat.v1.placeholder(tf.float32, [])
        eps = 1e-6  # numerical-stability constant for logs and divisions

        step_ob_placeholder = tf.compat.v1.placeholder(dtype=ob_space.dtype, shape=(nenvs,) + ob_space.shape)
        train_ob_placeholder = tf.compat.v1.placeholder(dtype=ob_space.dtype, shape=(nenvs*(nsteps+1),) + ob_space.shape)
        with tf.compat.v1.variable_scope('acer_model', reuse=tf.compat.v1.AUTO_REUSE):

            step_model = policy(nbatch=nenvs, nsteps=1, observ_placeholder=step_ob_placeholder, sess=sess)
            train_model = policy(nbatch=nbatch, nsteps=nsteps, observ_placeholder=train_ob_placeholder, sess=sess)


        params = find_trainable_variables("acer_model")
        print("Params {}".format(len(params)))
        for var in params:
            print(var)

        # create polyak averaged model
        ema = tf.train.ExponentialMovingAverage(alpha)
        ema_apply_op = ema.apply(params)

        def custom_getter(getter, *args, **kwargs):
            # Substitute each variable with its EMA average when building
            # the polyak model below.
            v = ema.average(getter(*args, **kwargs))
            print(v.name)
            return v

        with tf.compat.v1.variable_scope("acer_model", custom_getter=custom_getter, reuse=True):
            polyak_model = policy(nbatch=nbatch, nsteps=nsteps, observ_placeholder=train_ob_placeholder, sess=sess)

        # Notation: (var) = batch variable, (var)s = seqeuence variable, (var)_i = variable index by action at step i

        # action probability distributions according to train_model, polyak_model and step_model
        # poilcy.pi is probability distribution parameters; to obtain distribution that sums to 1 need to take softmax
        train_model_p = tf.nn.softmax(train_model.pi)
        polyak_model_p = tf.nn.softmax(polyak_model.pi)
        step_model_p = tf.nn.softmax(step_model.pi)
        v = tf.reduce_sum(input_tensor=train_model_p * train_model.q, axis = -1) # shape is [nenvs * (nsteps + 1)]

        # strip off last step
        f, f_pol, q = map(lambda var: strip(var, nenvs, nsteps), [train_model_p, polyak_model_p, train_model.q])
        # Get pi and q values for actions taken
        f_i = get_by_index(f, A)
        q_i = get_by_index(q, A)

        # Compute ratios for importance truncation
        rho = f / (MU + eps)
        rho_i = get_by_index(rho, A)

        # Calculate Q_retrace targets
        qret = q_retrace(R, D, q_i, v, rho_i, nenvs, nsteps, gamma)

        # Calculate losses
        # Entropy
        # entropy = tf.reduce_mean(strip(train_model.pd.entropy(), nenvs, nsteps))
        entropy = tf.reduce_mean(input_tensor=cat_entropy_softmax(f))

        # Policy Graident loss, with truncated importance sampling & bias correction
        v = strip(v, nenvs, nsteps, True)
        check_shape([qret, v, rho_i, f_i], [[nenvs * nsteps]] * 4)
        check_shape([rho, f, q], [[nenvs * nsteps, nact]] * 2)

        # Truncated importance sampling
        adv = qret - v
        logf = tf.math.log(f_i + eps)
        gain_f = logf * tf.stop_gradient(adv * tf.minimum(c, rho_i)) # [nenvs * nsteps]
        loss_f = -tf.reduce_mean(input_tensor=gain_f)

        # Bias correction for the truncation
        adv_bc = (q - tf.reshape(v, [nenvs * nsteps, 1])) # [nenvs * nsteps, nact]
        logf_bc = tf.math.log(f + eps) # / (f_old + eps)
        check_shape([adv_bc, logf_bc], [[nenvs * nsteps, nact]]*2)
        gain_bc = tf.reduce_sum(input_tensor=logf_bc * tf.stop_gradient(adv_bc * tf.nn.relu(1.0 - (c / (rho + eps))) * f), axis = 1) #IMP: This is sum, as expectation wrt f
        loss_bc= -tf.reduce_mean(input_tensor=gain_bc)

        loss_policy = loss_f + loss_bc

        # Value/Q function loss, and explained variance
        check_shape([qret, q_i], [[nenvs * nsteps]]*2)
        ev = q_explained_variance(tf.reshape(q_i, [nenvs, nsteps]), tf.reshape(qret, [nenvs, nsteps]))
        loss_q = tf.reduce_mean(input_tensor=tf.square(tf.stop_gradient(qret) - q_i)*0.5)

        # Net loss
        check_shape([loss_policy, loss_q, entropy], [[]] * 3)
        loss = loss_policy + q_coef * loss_q - ent_coef * entropy

        if trust_region:
            # Trust-region update: adjust the policy gradient wrt the softmax
            # statistics f so it stays within `delta` of the polyak policy.
            g = tf.gradients(ys=- (loss_policy - ent_coef * entropy) * nsteps * nenvs, xs=f) #[nenvs * nsteps, nact]
            # k = tf.gradients(KL(f_pol || f), f)
            k = - f_pol / (f + eps) #[nenvs * nsteps, nact] # Directly computed gradient of KL divergence wrt f
            k_dot_g = tf.reduce_sum(input_tensor=k * g, axis=-1)
            adj = tf.maximum(0.0, (tf.reduce_sum(input_tensor=k * g, axis=-1) - delta) / (tf.reduce_sum(input_tensor=tf.square(k), axis=-1) + eps)) #[nenvs * nsteps]

            # Calculate stats (before doing adjustment) for logging.
            avg_norm_k = avg_norm(k)
            avg_norm_g = avg_norm(g)
            avg_norm_k_dot_g = tf.reduce_mean(input_tensor=tf.abs(k_dot_g))
            avg_norm_adj = tf.reduce_mean(input_tensor=tf.abs(adj))

            g = g - tf.reshape(adj, [nenvs * nsteps, 1]) * k
            grads_f = -g/(nenvs*nsteps) # These are turst region adjusted gradients wrt f ie statistics of policy pi
            grads_policy = tf.gradients(ys=f, xs=params, grad_ys=grads_f)
            grads_q = tf.gradients(ys=loss_q * q_coef, xs=params)
            grads = [gradient_add(g1, g2, param) for (g1, g2, param) in zip(grads_policy, grads_q, params)]

            avg_norm_grads_f = avg_norm(grads_f) * (nsteps * nenvs)
            norm_grads_q = tf.linalg.global_norm(grads_q)
            norm_grads_policy = tf.linalg.global_norm(grads_policy)
        else:
            grads = tf.gradients(ys=loss, xs=params)

        if max_grad_norm is not None:
            grads, norm_grads = tf.clip_by_global_norm(grads, max_grad_norm)
        grads = list(zip(grads, params))
        trainer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=LR, decay=rprop_alpha, epsilon=rprop_epsilon)
        _opt_op = trainer.apply_gradients(grads)

        # so when you call _train, you first do the gradient step, then you apply ema
        with tf.control_dependencies([_opt_op]):
            _train = tf.group(ema_apply_op)

        lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)

        # Ops/Summaries to run, and their names for logging
        run_ops = [_train, loss, loss_q, entropy, loss_policy, loss_f, loss_bc, ev, norm_grads]
        names_ops = ['loss', 'loss_q', 'entropy', 'loss_policy', 'loss_f', 'loss_bc', 'explained_variance',
                     'norm_grads']
        if trust_region:
            run_ops = run_ops + [norm_grads_q, norm_grads_policy, avg_norm_grads_f, avg_norm_k, avg_norm_g, avg_norm_k_dot_g,
                                 avg_norm_adj]
            names_ops = names_ops + ['norm_grads_q', 'norm_grads_policy', 'avg_norm_grads_f', 'avg_norm_k', 'avg_norm_g',
                                     'avg_norm_k_dot_g', 'avg_norm_adj']

        def train(obs, actions, rewards, dones, mus, states, masks, steps):
            # Run one optimization step; returns (names, values) of logged ops.
            cur_lr = lr.value_steps(steps)
            td_map = {train_model.X: obs, polyak_model.X: obs, A: actions, R: rewards, D: dones, MU: mus, LR: cur_lr}
            if states is not None:
                td_map[train_model.S] = states
                td_map[train_model.M] = masks
                td_map[polyak_model.S] = states
                td_map[polyak_model.M] = masks

            return names_ops, sess.run(run_ops, td_map)[1:] # strip off _train

        def _step(observation, **kwargs):
            return step_model._evaluate([step_model.action, step_model_p, step_model.state], observation, **kwargs)



        self.train = train
        self.save = functools.partial(save_variables, sess=sess)
        self.load = functools.partial(load_variables, sess=sess)
        self.train_model = train_model
        self.step_model = step_model
        self._step = _step
        self.step = self.step_model.step

        self.initial_state = step_model.initial_state
        tf.compat.v1.global_variables_initializer().run(session=sess)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class Acer():
    """Training driver for ACER: alternates on-policy rollouts (which also
    fill the replay buffer) with off-policy replay steps, and logs stats."""
    def __init__(self, runner, model, buffer, log_interval):
        self.runner = runner
        self.model = model
        self.buffer = buffer
        self.log_interval = log_interval
        self.tstart = None  # start time; presumably set externally before logging — TODO confirm
        self.episode_stats = EpisodeStats(runner.nsteps, runner.nenv)
        self.steps = None  # total env steps so far; presumably set externally — TODO confirm

    def call(self, on_policy):
        # One training iteration: fresh rollout if on_policy, else a batch
        # sampled from the replay buffer.
        runner, model, buffer, steps = self.runner, self.model, self.buffer, self.steps
        if on_policy:
            enc_obs, obs, actions, rewards, mus, dones, masks = runner.run()
            self.episode_stats.feed(rewards, dones)
            if buffer is not None:
                buffer.put(enc_obs, actions, rewards, mus, dones, masks)
        else:
            # get obs, actions, rewards, mus, dones from buffer.
            obs, actions, rewards, mus, dones, masks = buffer.get()


        # reshape stuff correctly
        obs = obs.reshape(runner.batch_ob_shape)
        actions = actions.reshape([runner.nbatch])
        rewards = rewards.reshape([runner.nbatch])
        mus = mus.reshape([runner.nbatch, runner.nact])
        dones = dones.reshape([runner.nbatch])
        masks = masks.reshape([runner.batch_ob_shape[0]])

        names_ops, values_ops = model.train(obs, actions, rewards, dones, mus, model.initial_state, masks, steps)

        if on_policy and (int(steps/runner.nbatch) % self.log_interval == 0):
            logger.record_tabular("total_timesteps", steps)
            logger.record_tabular("fps", int(steps/(time.time() - self.tstart)))
            # IMP: In EpisodicLife env, during training, we get done=True at each loss of life, not just at the terminal state.
            # Thus, this is mean until end of life, not end of episode.
            # For true episode rewards, see the monitor files in the log folder.
            logger.record_tabular("mean_episode_length", self.episode_stats.mean_length())
            logger.record_tabular("mean_episode_reward", self.episode_stats.mean_reward())
            for name, val in zip(names_ops, values_ops):
                logger.record_tabular(name, float(val))
            logger.dump_tabular()
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=0.5, ent_coef=0.01,
          max_grad_norm=10, lr=7e-4, lrschedule='linear', rprop_epsilon=1e-5, rprop_alpha=0.99, gamma=0.99,
          log_interval=100, buffer_size=50000, replay_ratio=4, replay_start=10000, c=10.0,
          trust_region=True, alpha=0.99, delta=1, load_path=None, **network_kwargs):

    '''
    Main entrypoint for ACER (Actor-Critic with Experience Replay) algorithm (https://arxiv.org/pdf/1611.01224.pdf)
    Train an agent with given network architecture on a given environment using ACER.

    Parameters:
    ----------

    network: policy network architecture. Either string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for full list)
                           specifying the standard network architecture, or a function that takes tensorflow tensor as input and returns
                           tuple (output_tensor, extra_feed) where output tensor is the last network layer output, extra_feed is None for feed-forward
                           neural nets, and extra_feed is a dictionary describing how to feed state into the network for recurrent neural nets.
                           See baselines.common/policies.py/lstm for more details on using recurrent nets in policies

    env: environment. Needs to be vectorized for parallel environment simulation.
         The environments produced by gym.make can be wrapped using baselines.common.vec_env.DummyVecEnv class.

    nsteps: int, number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
            nenv is number of environment copies simulated in parallel) (default: 20)

    nstack: int, size of the frame stack, i.e. number of the frames passed to the step model. Frames are stacked along channel dimension
            (last image dimension) (default: 4)

    total_timesteps: int, number of timesteps (i.e. number of actions taken in the environment) (default: 80M)

    q_coef: float, value function loss coefficient in the optimization objective (analog of vf_coef for other actor-critic methods)

    ent_coef: float, policy entropy coefficient in the optimization objective (default: 0.01)

    max_grad_norm: float, gradient norm clipping coefficient. If set to None, no clipping. (default: 10),

    lr: float, learning rate for RMSProp (current implementation has RMSProp hardcoded in) (default: 7e-4)

    lrschedule: schedule of learning rate. Can be 'linear', 'constant', or a function [0..1] -> [0..1] that takes fraction of the training progress as input and
                returns fraction of the learning rate (specified as lr) as output

    rprop_epsilon: float, RMSProp epsilon (stabilizes square root computation in denominator of RMSProp update) (default: 1e-5)

    rprop_alpha: float, RMSProp decay parameter (default: 0.99)

    gamma: float, reward discounting factor (default: 0.99)

    log_interval: int, number of updates between logging events (default: 100)

    buffer_size: int, size of the replay buffer (default: 50k)

    replay_ratio: int, how many (on average) batches of data to sample from the replay buffer take after batch from the environment (default: 4)

    replay_start: int, the sampling from the replay buffer does not start until replay buffer has at least that many samples (default: 10k)

    c: float, importance weight clipping factor (default: 10)

    trust_region: bool, whether or not algorithms estimates the gradient KL divergence between the old and updated policy and uses it to determine step size (default: True)

    delta: float, max KL divergence between the old policy and updated policy (default: 1)

    alpha: float, momentum factor in the Polyak (exponential moving average) averaging of the model parameters (default: 0.99)

    load_path: str, path to load the model from (default: None)

    **network_kwargs: keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
                      For instance, 'mlp' network architecture has arguments num_hidden and num_layers.

    '''

    print("Running Acer Simple")
    print(locals())
    set_global_seeds(seed)
    # ACER requires a frame-stacked env; wrap with a trivial depth-1 stack when
    # the caller did not provide one.
    if not isinstance(env, VecFrameStack):
        env = VecFrameStack(env, 1)

    policy = build_policy(env, network, estimate_q=True, **network_kwargs)
    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space

    model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps,
                  ent_coef=ent_coef, q_coef=q_coef, gamma=gamma,
                  max_grad_norm=max_grad_norm, lr=lr, rprop_alpha=rprop_alpha, rprop_epsilon=rprop_epsilon,
                  total_timesteps=total_timesteps, lrschedule=lrschedule, c=c,
                  trust_region=trust_region, alpha=alpha, delta=delta)

    if load_path is not None:
        model.load(load_path)

    runner = Runner(env=env, model=model, nsteps=nsteps)
    if replay_ratio > 0:
        buffer = Buffer(env=env, nsteps=nsteps, size=buffer_size)
    else:
        buffer = None
    nbatch = nenvs*nsteps
    acer = Acer(runner, model, buffer, log_interval)
    acer.tstart = time.time()

    for acer.steps in range(0, total_timesteps, nbatch): #nbatch samples, 1 on_policy call and multiple off-policy calls
        acer.call(on_policy=True)
        if replay_ratio > 0 and buffer.has_atleast(replay_start):
            # On average run `replay_ratio` off-policy updates per on-policy batch.
            n = np.random.poisson(replay_ratio)
            for _ in range(n):
                acer.call(on_policy=False)  # no simulation steps in this

    return model
|
baselines/acer/buffer.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
class Buffer(object):
    """Trajectory replay buffer for ACER.

    Each buffer slot stores one whole rollout batch (nenv * nsteps frames) of
    observations, actions, rewards, behavior-policy probabilities (mu's),
    dones, and masks. Observations are stored unstacked ("encoded") and
    re-stacked on sampling via `decode`.
    """

    def __init__(self, env, nsteps, size=50000):
        """
        env: vectorized, frame-stacked environment (must expose num_envs,
             nstack, observation_space, action_space).
        nsteps: rollout length per slot.
        size: total capacity in frames per environment; converted below to a
              number of rollout slots.
        """
        self.nenv = env.num_envs
        self.nsteps = nsteps
        # self.nh, self.nw, self.nc = env.observation_space.shape
        self.obs_shape = env.observation_space.shape
        self.obs_dtype = env.observation_space.dtype
        self.ac_dtype = env.action_space.dtype
        # Channels of a single unstacked frame (observation_space includes the stack).
        self.nc = self.obs_shape[-1]
        self.nstack = env.nstack
        self.nc //= self.nstack
        self.nbatch = self.nenv * self.nsteps
        self.size = size // (self.nsteps)  # Each loc contains nenv * nsteps frames, thus total buffer is nenv * size frames

        # Memory: allocated lazily on first `put` so shapes match the runner output.
        self.enc_obs = None
        self.actions = None
        self.rewards = None
        self.mus = None
        self.dones = None
        self.masks = None

        # Size indexes
        self.next_idx = 0
        self.num_in_buffer = 0

    def has_atleast(self, frames):
        """Return True once at least `frames` frames per env have been stored."""
        # Frames per env, so total (nenv * frames) Frames needed
        # Each buffer loc has nenv * nsteps frames
        return self.num_in_buffer >= (frames // self.nsteps)

    def can_sample(self):
        """Return True if at least one rollout slot is filled."""
        return self.num_in_buffer > 0

    # Generate stacked frames
    def decode(self, enc_obs, dones):
        """Re-stack encoded observations along the channel axis.

        enc_obs has shape [nenvs, nsteps + nstack, nh, nw, nc];
        dones has shape [nenvs, nsteps];
        returns stacked obs of shape [nenv, (nsteps + 1), nh, nw, nstack*nc].
        """
        return _stack_obs(enc_obs, dones,
                          nsteps=self.nsteps)

    def put(self, enc_obs, actions, rewards, mus, dones, masks):
        """Store one rollout batch, overwriting the oldest slot when full.

        enc_obs [nenv, (nsteps + nstack), nh, nw, nc]
        actions, rewards, dones [nenv, nsteps]
        mus [nenv, nsteps, nact]
        """
        if self.enc_obs is None:
            self.enc_obs = np.empty([self.size] + list(enc_obs.shape), dtype=self.obs_dtype)
            self.actions = np.empty([self.size] + list(actions.shape), dtype=self.ac_dtype)
            self.rewards = np.empty([self.size] + list(rewards.shape), dtype=np.float32)
            self.mus = np.empty([self.size] + list(mus.shape), dtype=np.float32)
            # np.bool was deprecated in NumPy 1.20 and removed in 1.24;
            # the builtin bool is the supported spelling.
            self.dones = np.empty([self.size] + list(dones.shape), dtype=bool)
            self.masks = np.empty([self.size] + list(masks.shape), dtype=bool)

        self.enc_obs[self.next_idx] = enc_obs
        self.actions[self.next_idx] = actions
        self.rewards[self.next_idx] = rewards
        self.mus[self.next_idx] = mus
        self.dones[self.next_idx] = dones
        self.masks[self.next_idx] = masks

        # Ring-buffer bookkeeping.
        self.next_idx = (self.next_idx + 1) % self.size
        self.num_in_buffer = min(self.size, self.num_in_buffer + 1)

    def take(self, x, idx, envx):
        """Gather one (slot, env) trajectory per environment from array `x`."""
        nenv = self.nenv
        out = np.empty([nenv] + list(x.shape[2:]), dtype=x.dtype)
        for i in range(nenv):
            out[i] = x[idx[i], envx[i]]
        return out

    def get(self):
        """Sample one stored rollout per environment.

        Returns:
        obs [nenv, (nsteps + 1), nh, nw, nstack*nc]
        actions, rewards, dones [nenv, nsteps]
        mus [nenv, nsteps, nact]
        masks [nenv, nsteps + 1]
        """
        nenv = self.nenv
        assert self.can_sample()

        # Sample exactly one id per env. If you sample across envs, then higher correlation in samples from same env.
        idx = np.random.randint(0, self.num_in_buffer, nenv)
        envx = np.arange(nenv)

        take = lambda x: self.take(x, idx, envx)  # for i in range(nenv)], axis = 0)
        dones = take(self.dones)
        enc_obs = take(self.enc_obs)
        obs = self.decode(enc_obs, dones)
        actions = take(self.actions)
        rewards = take(self.rewards)
        mus = take(self.mus)
        masks = take(self.masks)
        return obs, actions, rewards, mus, dones, masks
+
|
| 101 |
+
def _stack_obs_ref(enc_obs, dones, nsteps):
    """Reference (unvectorized-in-channels) frame stacking, used to test _stack_obs.

    enc_obs: [nenv, nsteps + nstack, nh, nw, nc] unstacked frames; the first
             nstack frames are the pre-rollout history.
    dones:   [nenv, nsteps] episode-termination flags; a done zeroes out the
             frames carried over from before the boundary.
    Returns stacked observations of shape [nenv, nsteps + 1, nh, nw, nc * nstack].
    """
    nenv = enc_obs.shape[0]
    # The stack depth is implicit in the time dimension of enc_obs.
    nstack = enc_obs.shape[1] - nsteps
    nh, nw, nc = enc_obs.shape[2:]
    obs_dtype = enc_obs.dtype
    obs_shape = (nh, nw, nc*nstack)

    # mask[t] is 0 where an episode ended at step t (frames before a done must
    # not leak across the boundary); the leading history frames are always kept.
    mask = np.empty([nsteps + nstack - 1, nenv, 1, 1, 1], dtype=np.float32)
    # obs[k] holds the frames shifted by k steps; axis order is
    # [stack, time, env, h, w, c] and is transposed back at the end.
    obs = np.zeros([nstack, nsteps + nstack, nenv, nh, nw, nc], dtype=obs_dtype)
    x = np.reshape(enc_obs, [nenv, nsteps + nstack, nh, nw, nc]).swapaxes(1, 0)  # [nsteps + nstack, nenv, nh, nw, nc]

    mask[nstack-1:] = np.reshape(1.0 - dones, [nenv, nsteps, 1, 1, 1]).swapaxes(1, 0)  # keep
    mask[:nstack-1] = 1.0

    # y = np.reshape(1 - dones, [nenvs, nsteps, 1, 1, 1])
    for i in range(nstack):
        # Place the i-steps-delayed copy of the frames into stack slot -(i+1),
        # then zero frames that cross an episode boundary and shift the mask.
        obs[-(i + 1), i:] = x
        # obs[:,i:,:,:,-(i+1),:] = x
        x = x[:-1] * mask
        mask = mask[1:]

    # Drop the warm-up rows and fold the stack axis into the channel axis.
    return np.reshape(obs[:, (nstack-1):].transpose((2, 1, 3, 4, 0, 5)), (nenv, (nsteps + 1)) + obs_shape)
| 124 |
+
def _stack_obs(enc_obs, dones, nsteps):
    """Vectorized frame stacking along the channel axis.

    enc_obs: [nenv, nsteps + nstack, ..., nc] unstacked frames (stack depth is
             implicit: nstack = enc_obs.shape[1] - nsteps).
    dones:   [nenv, nsteps] termination flags; frames from before an episode
             boundary are zeroed so they do not leak into the next episode.
    Returns stacked observations of shape [nenv, nsteps + 1, ..., nc * nstack],
    equivalent to _stack_obs_ref (see test_stack_obs).
    """
    nenv = enc_obs.shape[0]
    nstack = enc_obs.shape[1] - nsteps
    nc = enc_obs.shape[-1]

    obs_ = np.zeros((nenv, nsteps + 1) + enc_obs.shape[2:-1] + (enc_obs.shape[-1] * nstack, ), dtype=enc_obs.dtype)
    # mask[:, t] is 0 where step t starts a new episode; broadcast over the
    # spatial/channel axes by appending singleton dimensions.
    mask = np.ones((nenv, nsteps+1), dtype=enc_obs.dtype)
    mask[:, 1:] = 1.0 - dones
    mask = mask.reshape(mask.shape + tuple(np.ones(len(enc_obs.shape)-2, dtype=np.uint8)))

    # Newest frame goes into the last channel group unmasked; each older frame
    # is masked by the cumulative product of dones between it and the present.
    for i in range(nstack-1, -1, -1):
        obs_[..., i * nc : (i + 1) * nc] = enc_obs[:, i : i + nsteps + 1, :]
        if i < nstack-1:
            obs_[..., i * nc : (i + 1) * nc] *= mask
            # NOTE(review): in-place multiply of overlapping views; relies on
            # NumPy's ufunc overlap detection (buffered since 1.13) so each
            # element sees the pre-update neighbor — confirm if porting.
            mask[:, 1:, ...] *= mask[:, :-1, ...]

    return obs_
| 142 |
+
def test_stack_obs():
    """Check the vectorized _stack_obs against the reference implementation
    on random observations with random episode boundaries."""
    nstack = 7
    nenv = 1
    nsteps = 5

    frame_shape = (2, 3, nstack)
    encoded_shape = (nenv, nsteps + nstack) + frame_shape[:-1] + (1,)

    encoded = np.random.random(encoded_shape)
    terminations = np.random.randint(low=0, high=2, size=(nenv, nsteps))

    expected = _stack_obs_ref(encoded, terminations, nsteps=nsteps)
    actual = _stack_obs(encoded, terminations, nsteps=nsteps)

    np.testing.assert_allclose(expected, actual)
|
baselines/acer/defaults.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def atari():
    """Default ACER hyperparameter overrides for Atari environments."""
    return {'lrschedule': 'constant'}
|
baselines/acer/policies.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
from baselines.common.policies import nature_cnn
|
| 4 |
+
from baselines.a2c.utils import fc, batch_to_seq, seq_to_batch, lstm, sample
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class AcerCnnPolicy(object):
    """Feed-forward CNN policy for ACER (TF1 graph mode).

    Builds a nature-CNN trunk with a softmax policy head (pi) and a per-action
    Q head, and exposes session-bound `step`, `out`, and `act` callables.
    """

    def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, reuse=False):
        """
        sess: TF session the closures below run in.
        ob_space, ac_space: gym spaces (ac_space must be discrete — `.n` is read).
        nenv, nsteps: batch geometry; placeholder batch is nenv * nsteps.
        nstack: number of stacked frames along the channel axis.
        reuse: reuse variables in the "model" scope (train vs. step model).
        """
        nbatch = nenv * nsteps
        nh, nw, nc = ob_space.shape
        ob_shape = (nbatch, nh, nw, nc * nstack)
        nact = ac_space.n
        X = tf.compat.v1.placeholder(tf.uint8, ob_shape)  # obs
        with tf.compat.v1.variable_scope("model", reuse=reuse):
            h = nature_cnn(X)
            pi_logits = fc(h, 'pi', nact, init_scale=0.01)
            pi = tf.nn.softmax(pi_logits)
            q = fc(h, 'q', nact)

        # Sampling op for the behavior action.
        a = sample(tf.nn.softmax(pi_logits))  # could change this to use self.pi instead
        self.initial_state = []  # not stateful
        self.X = X
        self.pi = pi  # actual policy params now
        self.pi_logits = pi_logits
        self.q = q
        # ACER uses Q-values in place of a separate value function.
        self.vf = q

        def step(ob, *args, **kwargs):
            # returns actions, mus, states
            a0, pi0 = sess.run([a, pi], {X: ob})
            return a0, pi0, []  # dummy state

        def out(ob, *args, **kwargs):
            # Returns (policy probabilities, Q-values) for the given observations.
            pi0, q0 = sess.run([pi, q], {X: ob})
            return pi0, q0

        def act(ob, *args, **kwargs):
            # Returns sampled actions only.
            return sess.run(a, {X: ob})

        self.step = step
        self.out = out
        self.act = act
| 45 |
+
class AcerLstmPolicy(object):
    """Recurrent (LSTM) CNN policy for ACER (TF1 graph mode).

    Like AcerCnnPolicy but with an LSTM between the CNN trunk and the heads;
    the LSTM state is reset where the done-mask M is set.
    """

    def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, reuse=False, nlstm=256):
        """
        sess: TF session the `step` closure runs in.
        nlstm: LSTM hidden size; the state placeholder carries (c, h) → 2*nlstm.
        Other parameters as in AcerCnnPolicy.
        """
        nbatch = nenv * nsteps
        nh, nw, nc = ob_space.shape
        ob_shape = (nbatch, nh, nw, nc * nstack)
        nact = ac_space.n
        X = tf.compat.v1.placeholder(tf.uint8, ob_shape)  # obs
        M = tf.compat.v1.placeholder(tf.float32, [nbatch])  #mask (done t-1)
        S = tf.compat.v1.placeholder(tf.float32, [nenv, nlstm*2])  #states
        with tf.compat.v1.variable_scope("model", reuse=reuse):
            h = nature_cnn(X)

            # lstm: reshape the flat batch into per-env sequences, run the
            # recurrent cell, and flatten back.
            xs = batch_to_seq(h, nenv, nsteps)
            ms = batch_to_seq(M, nenv, nsteps)
            h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm)
            h5 = seq_to_batch(h5)

            pi_logits = fc(h5, 'pi', nact, init_scale=0.01)
            pi = tf.nn.softmax(pi_logits)
            q = fc(h5, 'q', nact)

        a = sample(pi_logits)  # could change this to use self.pi instead
        # Zero LSTM state for the start of training / new episodes.
        self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32)
        self.X = X
        self.M = M
        self.S = S
        self.pi = pi  # actual policy params now
        self.q = q

        def step(ob, state, mask, *args, **kwargs):
            # returns actions, mus, states
            a0, pi0, s = sess.run([a, pi, snew], {X: ob, S: state, M: mask})
            return a0, pi0, s

        self.step = step
|
baselines/acer/runner.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from baselines.common.runners import AbstractEnvRunner
|
| 3 |
+
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
| 4 |
+
from gym import spaces
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Runner(AbstractEnvRunner):
    """Collects `nsteps` of experience per call from a vectorized,
    frame-stacked environment, in the layout ACER's Model/Buffer expect."""

    def __init__(self, env, model, nsteps):
        super().__init__(env=env, model=model, nsteps=nsteps)
        assert isinstance(env.action_space, spaces.Discrete), 'This ACER implementation works only with discrete action spaces!'
        assert isinstance(env, VecFrameStack)

        self.nact = env.action_space.n
        nenv = self.nenv
        self.nbatch = nenv * nsteps
        # Flattened shape of the (nsteps + 1)-long observation batch fed to train().
        self.batch_ob_shape = (nenv*(nsteps+1),) + env.observation_space.shape

        self.obs = env.reset()
        self.obs_dtype = env.observation_space.dtype
        self.ac_dtype = env.action_space.dtype
        self.nstack = self.env.nstack
        # Channels of a single unstacked frame (observation_space includes the stack).
        self.nc = self.batch_ob_shape[-1] // self.nstack

    def run(self):
        """Step the env `nsteps` times.

        Returns (enc_obs, obs, actions, rewards, mus, dones, masks), all with
        leading [nenv, time] axes; enc_obs additionally carries the nstack-frame
        pre-rollout history for Buffer.decode.
        """
        # enc_obs = np.split(self.obs, self.nstack, axis=3) # so now list of obs steps
        enc_obs = np.split(self.env.stackedobs, self.env.nstack, axis=-1)
        mb_obs, mb_actions, mb_mus, mb_dones, mb_rewards = [], [], [], [], []
        for _ in range(self.nsteps):
            actions, mus, states = self.model._step(self.obs, S=self.states, M=self.dones)
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
            mb_mus.append(mus)
            mb_dones.append(self.dones)
            obs, rewards, dones, _ = self.env.step(actions)
            # states information for stateful models like LSTM
            self.states = states
            self.dones = dones
            self.obs = obs
            mb_rewards.append(rewards)
            # Only the newest (unstacked) frame needs to be appended.
            enc_obs.append(obs[..., -self.nc:])
        mb_obs.append(np.copy(self.obs))
        mb_dones.append(self.dones)

        enc_obs = np.asarray(enc_obs, dtype=self.obs_dtype).swapaxes(1, 0)
        mb_obs = np.asarray(mb_obs, dtype=self.obs_dtype).swapaxes(1, 0)
        mb_actions = np.asarray(mb_actions, dtype=self.ac_dtype).swapaxes(1, 0)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
        mb_mus = np.asarray(mb_mus, dtype=np.float32).swapaxes(1, 0)

        # np.bool was deprecated in NumPy 1.20 and removed in 1.24; the builtin
        # bool is the supported spelling.
        mb_dones = np.asarray(mb_dones, dtype=bool).swapaxes(1, 0)

        mb_masks = mb_dones  # Used for stateful models like LSTM's to mask state when done
        mb_dones = mb_dones[:, 1:]  # Used for calculating returns. The dones array is now aligned with rewards

        # shapes are now [nenv, nsteps, []]
        # When pulling from buffer, arrays will now be reshaped in place, preventing a deep copy.

        return enc_obs, mb_obs, mb_actions, mb_rewards, mb_mus, mb_dones, mb_masks
|
baselines/acktr/README.md
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ACKTR
|
| 2 |
+
|
| 3 |
+
- Original paper: https://arxiv.org/abs/1708.05144
|
| 4 |
+
- Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/
|
| 5 |
+
- `python -m baselines.run --alg=acktr --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on an Atari Pong. See help (`-h`) for more options.
|
| 6 |
+
- also refer to the repo-wide [README.md](../../README.md#training-models)
|
| 7 |
+
|
| 8 |
+
## ACKTR with continuous action spaces
|
| 9 |
+
The code of ACKTR has been refactored to handle both discrete and continuous action spaces uniformly. In the original version, discrete and continuous action spaces were handled by different code (acktr_disc.py and acktr_cont.py) with little overlap. If interested in the original version of acktr for continuous action spaces, use the `old_acktr_cont` branch. Note that the original code performs better on the mujoco tasks than the refactored version; we are still investigating why.
|
baselines/acktr/__init__.py
ADDED
|
File without changes
|
baselines/acktr/acktr.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import time
|
| 3 |
+
import functools
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
from baselines import logger
|
| 6 |
+
|
| 7 |
+
from baselines.common import set_global_seeds, explained_variance
|
| 8 |
+
from baselines.common.policies import build_policy
|
| 9 |
+
from baselines.common.tf_util import get_session, save_variables, load_variables
|
| 10 |
+
|
| 11 |
+
from baselines.a2c.runner import Runner
|
| 12 |
+
from baselines.a2c.utils import Scheduler, find_trainable_variables
|
| 13 |
+
from baselines.acktr import kfac
|
| 14 |
+
from baselines.ppo2.ppo2 import safemean
|
| 15 |
+
from collections import deque
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Model(object):
    """ACKTR model: builds the TF1 graph for a step model and a train model
    sharing variables, with a KFAC natural-gradient optimizer.
    """

    def __init__(self, policy, ob_space, ac_space, nenvs,total_timesteps, nprocs=32, nsteps=20,
                 ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
                 kfac_clip=0.001, lrschedule='linear', is_async=True):
        """
        policy: callable from build_policy; policy(nbatch, nsteps, sess=...) builds a model.
        nenvs, nsteps: batch geometry (nbatch = nenvs * nsteps).
        ent_coef, vf_coef: entropy / value-loss weights in the objective.
        vf_fisher_coef: weight of the value-function term in the Fisher loss.
        lr, lrschedule, total_timesteps: learning-rate schedule parameters.
        kfac_clip, is_async: KFAC trust-region clip and async-stats flag.
        NOTE(review): nprocs is accepted but unused in this constructor.
        """
        self.sess = sess = get_session()
        nbatch = nenvs * nsteps
        with tf.compat.v1.variable_scope('acktr_model', reuse=tf.compat.v1.AUTO_REUSE):
            # Step model runs 1 step per env; train model runs the full batch.
            self.model = step_model = policy(nenvs, 1, sess=sess)
            self.model2 = train_model = policy(nenvs*nsteps, nsteps, sess=sess)

        A = train_model.pdtype.sample_placeholder([None])   # actions taken
        ADV = tf.compat.v1.placeholder(tf.float32, [nbatch])  # advantages
        R = tf.compat.v1.placeholder(tf.float32, [nbatch])    # returns
        PG_LR = tf.compat.v1.placeholder(tf.float32, [])      # policy learning rate
        VF_LR = tf.compat.v1.placeholder(tf.float32, [])      # value learning rate (same feed as PG_LR below)

        neglogpac = train_model.pd.neglogp(A)
        self.logits = train_model.pi

        ##training loss
        pg_loss = tf.reduce_mean(input_tensor=ADV*neglogpac)
        entropy = tf.reduce_mean(input_tensor=train_model.pd.entropy())
        pg_loss = pg_loss - ent_coef * entropy
        vf_loss = tf.compat.v1.losses.mean_squared_error(tf.squeeze(train_model.vf), R)
        train_loss = pg_loss + vf_coef * vf_loss


        ##Fisher loss construction
        # Policy term: log-likelihood of sampled actions; value term: fit to a
        # noise-perturbed copy of the value output (Gauss-Newton style target).
        self.pg_fisher = pg_fisher_loss = -tf.reduce_mean(input_tensor=neglogpac)
        sample_net = train_model.vf + tf.random.normal(tf.shape(input=train_model.vf))
        self.vf_fisher = vf_fisher_loss = - vf_fisher_coef*tf.reduce_mean(input_tensor=tf.pow(train_model.vf - tf.stop_gradient(sample_net), 2))
        self.joint_fisher = joint_fisher_loss = pg_fisher_loss + vf_fisher_loss

        self.params=params = find_trainable_variables("acktr_model")

        self.grads_check = grads = tf.gradients(ys=train_loss,xs=params)

        with tf.device('/gpu:0'):
            self.optim = optim = kfac.KfacOptimizer(learning_rate=PG_LR, clip_kl=kfac_clip,\
                momentum=0.9, kfac_update=1, epsilon=0.01,\
                stats_decay=0.99, is_async=is_async, cold_iter=10, max_grad_norm=max_grad_norm)

            # update_stats_op = optim.compute_and_apply_stats(joint_fisher_loss, var_list=params)
            optim.compute_and_apply_stats(joint_fisher_loss, var_list=params)
            train_op, q_runner = optim.apply_gradients(list(zip(grads,params)))
        self.q_runner = q_runner
        self.lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)

        def train(obs, states, rewards, masks, actions, values):
            """Run one KFAC update; returns (policy_loss, value_loss, policy_entropy)."""
            advs = rewards - values
            # Advance the LR schedule once per sample in the batch; only the
            # final value is fed to the optimizer.
            for step in range(len(obs)):
                cur_lr = self.lr.value()

            td_map = {train_model.X:obs, A:actions, ADV:advs, R:rewards, PG_LR:cur_lr, VF_LR:cur_lr}
            if states is not None:
                # Recurrent policies also need the LSTM state and done-masks.
                td_map[train_model.S] = states
                td_map[train_model.M] = masks

            policy_loss, value_loss, policy_entropy, _ = sess.run(
                [pg_loss, vf_loss, entropy, train_op],
                td_map
            )
            return policy_loss, value_loss, policy_entropy


        self.train = train
        self.save = functools.partial(save_variables, sess=sess)
        self.load = functools.partial(load_variables, sess=sess)
        self.train_model = train_model
        self.step_model = step_model
        self.step = step_model.step
        self.value = step_model.value
        self.initial_state = step_model.initial_state
        tf.compat.v1.global_variables_initializer().run(session=sess)
|
| 95 |
+
def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=100, nprocs=32, nsteps=20,
|
| 96 |
+
ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
|
| 97 |
+
kfac_clip=0.001, save_interval=None, lrschedule='linear', load_path=None, is_async=True, **network_kwargs):
|
| 98 |
+
set_global_seeds(seed)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if network == 'cnn':
|
| 102 |
+
network_kwargs['one_dim_bias'] = True
|
| 103 |
+
|
| 104 |
+
policy = build_policy(env, network, **network_kwargs)
|
| 105 |
+
|
| 106 |
+
nenvs = env.num_envs
|
| 107 |
+
ob_space = env.observation_space
|
| 108 |
+
ac_space = env.action_space
|
| 109 |
+
make_model = lambda : Model(policy, ob_space, ac_space, nenvs, total_timesteps, nprocs=nprocs, nsteps
|
| 110 |
+
=nsteps, ent_coef=ent_coef, vf_coef=vf_coef, vf_fisher_coef=
|
| 111 |
+
vf_fisher_coef, lr=lr, max_grad_norm=max_grad_norm, kfac_clip=kfac_clip,
|
| 112 |
+
lrschedule=lrschedule, is_async=is_async)
|
| 113 |
+
if save_interval and logger.get_dir():
|
| 114 |
+
import cloudpickle
|
| 115 |
+
with open(osp.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
|
| 116 |
+
fh.write(cloudpickle.dumps(make_model))
|
| 117 |
+
model = make_model()
|
| 118 |
+
|
| 119 |
+
if load_path is not None:
|
| 120 |
+
model.load(load_path)
|
| 121 |
+
|
| 122 |
+
runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
|
| 123 |
+
epinfobuf = deque(maxlen=100)
|
| 124 |
+
nbatch = nenvs*nsteps
|
| 125 |
+
tstart = time.time()
|
| 126 |
+
coord = tf.train.Coordinator()
|
| 127 |
+
if is_async:
|
| 128 |
+
enqueue_threads = model.q_runner.create_threads(model.sess, coord=coord, start=True)
|
| 129 |
+
else:
|
| 130 |
+
enqueue_threads = []
|
| 131 |
+
|
| 132 |
+
for update in range(1, total_timesteps//nbatch+1):
|
| 133 |
+
obs, states, rewards, masks, actions, values, epinfos = runner.run()
|
| 134 |
+
epinfobuf.extend(epinfos)
|
| 135 |
+
policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
|
| 136 |
+
model.old_obs = obs
|
| 137 |
+
nseconds = time.time()-tstart
|
| 138 |
+
fps = int((update*nbatch)/nseconds)
|
| 139 |
+
if update % log_interval == 0 or update == 1:
|
| 140 |
+
ev = explained_variance(values, rewards)
|
| 141 |
+
logger.record_tabular("nupdates", update)
|
| 142 |
+
logger.record_tabular("total_timesteps", update*nbatch)
|
| 143 |
+
logger.record_tabular("fps", fps)
|
| 144 |
+
logger.record_tabular("policy_entropy", float(policy_entropy))
|
| 145 |
+
logger.record_tabular("policy_loss", float(policy_loss))
|
| 146 |
+
logger.record_tabular("value_loss", float(value_loss))
|
| 147 |
+
logger.record_tabular("explained_variance", float(ev))
|
| 148 |
+
logger.record_tabular("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
|
| 149 |
+
logger.record_tabular("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
|
| 150 |
+
logger.dump_tabular()
|
| 151 |
+
|
| 152 |
+
if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir():
|
| 153 |
+
savepath = osp.join(logger.get_dir(), 'checkpoint%.5i'%update)
|
| 154 |
+
print('Saving to', savepath)
|
| 155 |
+
model.save(savepath)
|
| 156 |
+
coord.request_stop()
|
| 157 |
+
coord.join(enqueue_threads)
|
| 158 |
+
return model
|
baselines/acktr/defaults.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def mujoco():
|
| 2 |
+
return dict(
|
| 3 |
+
nsteps=2500,
|
| 4 |
+
value_network='copy'
|
| 5 |
+
)
|
baselines/acktr/kfac.py
ADDED
|
@@ -0,0 +1,928 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import numpy as np
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
# flake8: noqa F403, F405
|
| 6 |
+
from baselines.acktr.kfac_utils import *
|
| 7 |
+
from functools import reduce
|
| 8 |
+
|
| 9 |
+
KFAC_OPS = ['MatMul', 'Conv2D', 'BiasAdd']
|
| 10 |
+
KFAC_DEBUG = False
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class KfacOptimizer():
|
| 14 |
+
# note that KfacOptimizer will be truly synchronous (and thus deterministic) only if a single-threaded session is used
|
| 15 |
+
def __init__(self, learning_rate=0.01, momentum=0.9, clip_kl=0.01, kfac_update=2, stats_accum_iter=60, full_stats_init=False, cold_iter=100, cold_lr=None, is_async=False, async_stats=False, epsilon=1e-2, stats_decay=0.95, blockdiag_bias=False, channel_fac=False, factored_damping=False, approxT2=False, use_float64=False, weight_decay_dict={},max_grad_norm=0.5):
|
| 16 |
+
self.max_grad_norm = max_grad_norm
|
| 17 |
+
self._lr = learning_rate
|
| 18 |
+
self._momentum = momentum
|
| 19 |
+
self._clip_kl = clip_kl
|
| 20 |
+
self._channel_fac = channel_fac
|
| 21 |
+
self._kfac_update = kfac_update
|
| 22 |
+
self._async = is_async
|
| 23 |
+
self._async_stats = async_stats
|
| 24 |
+
self._epsilon = epsilon
|
| 25 |
+
self._stats_decay = stats_decay
|
| 26 |
+
self._blockdiag_bias = blockdiag_bias
|
| 27 |
+
self._approxT2 = approxT2
|
| 28 |
+
self._use_float64 = use_float64
|
| 29 |
+
self._factored_damping = factored_damping
|
| 30 |
+
self._cold_iter = cold_iter
|
| 31 |
+
if cold_lr == None:
|
| 32 |
+
# good heuristics
|
| 33 |
+
self._cold_lr = self._lr# * 3.
|
| 34 |
+
else:
|
| 35 |
+
self._cold_lr = cold_lr
|
| 36 |
+
self._stats_accum_iter = stats_accum_iter
|
| 37 |
+
self._weight_decay_dict = weight_decay_dict
|
| 38 |
+
self._diag_init_coeff = 0.
|
| 39 |
+
self._full_stats_init = full_stats_init
|
| 40 |
+
if not self._full_stats_init:
|
| 41 |
+
self._stats_accum_iter = self._cold_iter
|
| 42 |
+
|
| 43 |
+
self.sgd_step = tf.Variable(0, name='KFAC/sgd_step', trainable=False)
|
| 44 |
+
self.global_step = tf.Variable(
|
| 45 |
+
0, name='KFAC/global_step', trainable=False)
|
| 46 |
+
self.cold_step = tf.Variable(0, name='KFAC/cold_step', trainable=False)
|
| 47 |
+
self.factor_step = tf.Variable(
|
| 48 |
+
0, name='KFAC/factor_step', trainable=False)
|
| 49 |
+
self.stats_step = tf.Variable(
|
| 50 |
+
0, name='KFAC/stats_step', trainable=False)
|
| 51 |
+
self.vFv = tf.Variable(0., name='KFAC/vFv', trainable=False)
|
| 52 |
+
|
| 53 |
+
self.factors = {}
|
| 54 |
+
self.param_vars = []
|
| 55 |
+
self.stats = {}
|
| 56 |
+
self.stats_eigen = {}
|
| 57 |
+
|
| 58 |
+
def getFactors(self, g, varlist):
|
| 59 |
+
graph = tf.compat.v1.get_default_graph()
|
| 60 |
+
factorTensors = {}
|
| 61 |
+
fpropTensors = []
|
| 62 |
+
bpropTensors = []
|
| 63 |
+
opTypes = []
|
| 64 |
+
fops = []
|
| 65 |
+
|
| 66 |
+
def searchFactors(gradient, graph):
|
| 67 |
+
# hard coded search stratergy
|
| 68 |
+
bpropOp = gradient.op
|
| 69 |
+
bpropOp_name = bpropOp.name
|
| 70 |
+
|
| 71 |
+
bTensors = []
|
| 72 |
+
fTensors = []
|
| 73 |
+
|
| 74 |
+
# combining additive gradient, assume they are the same op type and
|
| 75 |
+
# indepedent
|
| 76 |
+
if 'AddN' in bpropOp_name:
|
| 77 |
+
factors = []
|
| 78 |
+
for g in gradient.op.inputs:
|
| 79 |
+
factors.append(searchFactors(g, graph))
|
| 80 |
+
op_names = [item['opName'] for item in factors]
|
| 81 |
+
# TO-DO: need to check all the attribute of the ops as well
|
| 82 |
+
print (gradient.name)
|
| 83 |
+
print (op_names)
|
| 84 |
+
print (len(np.unique(op_names)))
|
| 85 |
+
assert len(np.unique(op_names)) == 1, gradient.name + \
|
| 86 |
+
' is shared among different computation OPs'
|
| 87 |
+
|
| 88 |
+
bTensors = reduce(lambda x, y: x + y,
|
| 89 |
+
[item['bpropFactors'] for item in factors])
|
| 90 |
+
if len(factors[0]['fpropFactors']) > 0:
|
| 91 |
+
fTensors = reduce(
|
| 92 |
+
lambda x, y: x + y, [item['fpropFactors'] for item in factors])
|
| 93 |
+
fpropOp_name = op_names[0]
|
| 94 |
+
fpropOp = factors[0]['op']
|
| 95 |
+
else:
|
| 96 |
+
fpropOp_name = re.search(
|
| 97 |
+
'gradientsSampled(_[0-9]+|)/(.+?)_grad', bpropOp_name).group(2)
|
| 98 |
+
fpropOp = graph.get_operation_by_name(fpropOp_name)
|
| 99 |
+
if fpropOp.op_def.name in KFAC_OPS:
|
| 100 |
+
# Known OPs
|
| 101 |
+
###
|
| 102 |
+
bTensor = [
|
| 103 |
+
i for i in bpropOp.inputs if 'gradientsSampled' in i.name][-1]
|
| 104 |
+
bTensorShape = fpropOp.outputs[0].get_shape()
|
| 105 |
+
if bTensor.get_shape()[0].value == None:
|
| 106 |
+
bTensor.set_shape(bTensorShape)
|
| 107 |
+
bTensors.append(bTensor)
|
| 108 |
+
###
|
| 109 |
+
if fpropOp.op_def.name == 'BiasAdd':
|
| 110 |
+
fTensors = []
|
| 111 |
+
else:
|
| 112 |
+
fTensors.append(
|
| 113 |
+
[i for i in fpropOp.inputs if param.op.name not in i.name][0])
|
| 114 |
+
fpropOp_name = fpropOp.op_def.name
|
| 115 |
+
else:
|
| 116 |
+
# unknown OPs, block approximation used
|
| 117 |
+
bInputsList = [i for i in bpropOp.inputs[
|
| 118 |
+
0].op.inputs if 'gradientsSampled' in i.name if 'Shape' not in i.name]
|
| 119 |
+
if len(bInputsList) > 0:
|
| 120 |
+
bTensor = bInputsList[0]
|
| 121 |
+
bTensorShape = fpropOp.outputs[0].get_shape()
|
| 122 |
+
if len(bTensor.get_shape()) > 0 and bTensor.get_shape()[0].value == None:
|
| 123 |
+
bTensor.set_shape(bTensorShape)
|
| 124 |
+
bTensors.append(bTensor)
|
| 125 |
+
fpropOp_name = opTypes.append('UNK-' + fpropOp.op_def.name)
|
| 126 |
+
|
| 127 |
+
return {'opName': fpropOp_name, 'op': fpropOp, 'fpropFactors': fTensors, 'bpropFactors': bTensors}
|
| 128 |
+
|
| 129 |
+
for t, param in zip(g, varlist):
|
| 130 |
+
if KFAC_DEBUG:
|
| 131 |
+
print(('get factor for '+param.name))
|
| 132 |
+
factors = searchFactors(t, graph)
|
| 133 |
+
factorTensors[param] = factors
|
| 134 |
+
|
| 135 |
+
########
|
| 136 |
+
# check associated weights and bias for homogeneous coordinate representation
|
| 137 |
+
# and check redundent factors
|
| 138 |
+
# TO-DO: there may be a bug to detect associate bias and weights for
|
| 139 |
+
# forking layer, e.g. in inception models.
|
| 140 |
+
for param in varlist:
|
| 141 |
+
factorTensors[param]['assnWeights'] = None
|
| 142 |
+
factorTensors[param]['assnBias'] = None
|
| 143 |
+
for param in varlist:
|
| 144 |
+
if factorTensors[param]['opName'] == 'BiasAdd':
|
| 145 |
+
factorTensors[param]['assnWeights'] = None
|
| 146 |
+
for item in varlist:
|
| 147 |
+
if len(factorTensors[item]['bpropFactors']) > 0:
|
| 148 |
+
if (set(factorTensors[item]['bpropFactors']) == set(factorTensors[param]['bpropFactors'])) and (len(factorTensors[item]['fpropFactors']) > 0):
|
| 149 |
+
factorTensors[param]['assnWeights'] = item
|
| 150 |
+
factorTensors[item]['assnBias'] = param
|
| 151 |
+
factorTensors[param]['bpropFactors'] = factorTensors[
|
| 152 |
+
item]['bpropFactors']
|
| 153 |
+
|
| 154 |
+
########
|
| 155 |
+
|
| 156 |
+
########
|
| 157 |
+
# concatenate the additive gradients along the batch dimension, i.e.
|
| 158 |
+
# assuming independence structure
|
| 159 |
+
for key in ['fpropFactors', 'bpropFactors']:
|
| 160 |
+
for i, param in enumerate(varlist):
|
| 161 |
+
if len(factorTensors[param][key]) > 0:
|
| 162 |
+
if (key + '_concat') not in factorTensors[param]:
|
| 163 |
+
name_scope = factorTensors[param][key][0].name.split(':')[
|
| 164 |
+
0]
|
| 165 |
+
with tf.compat.v1.name_scope(name_scope):
|
| 166 |
+
factorTensors[param][
|
| 167 |
+
key + '_concat'] = tf.concat(factorTensors[param][key], 0)
|
| 168 |
+
else:
|
| 169 |
+
factorTensors[param][key + '_concat'] = None
|
| 170 |
+
for j, param2 in enumerate(varlist[(i + 1):]):
|
| 171 |
+
if (len(factorTensors[param][key]) > 0) and (set(factorTensors[param2][key]) == set(factorTensors[param][key])):
|
| 172 |
+
factorTensors[param2][key] = factorTensors[param][key]
|
| 173 |
+
factorTensors[param2][
|
| 174 |
+
key + '_concat'] = factorTensors[param][key + '_concat']
|
| 175 |
+
########
|
| 176 |
+
|
| 177 |
+
if KFAC_DEBUG:
|
| 178 |
+
for items in zip(varlist, fpropTensors, bpropTensors, opTypes):
|
| 179 |
+
print((items[0].name, factorTensors[item]))
|
| 180 |
+
self.factors = factorTensors
|
| 181 |
+
return factorTensors
|
| 182 |
+
|
| 183 |
+
def getStats(self, factors, varlist):
|
| 184 |
+
if len(self.stats) == 0:
|
| 185 |
+
# initialize stats variables on CPU because eigen decomp is
|
| 186 |
+
# computed on CPU
|
| 187 |
+
with tf.device('/cpu'):
|
| 188 |
+
tmpStatsCache = {}
|
| 189 |
+
|
| 190 |
+
# search for tensor factors and
|
| 191 |
+
# use block diag approx for the bias units
|
| 192 |
+
for var in varlist:
|
| 193 |
+
fpropFactor = factors[var]['fpropFactors_concat']
|
| 194 |
+
bpropFactor = factors[var]['bpropFactors_concat']
|
| 195 |
+
opType = factors[var]['opName']
|
| 196 |
+
if opType == 'Conv2D':
|
| 197 |
+
Kh = var.get_shape()[0]
|
| 198 |
+
Kw = var.get_shape()[1]
|
| 199 |
+
C = fpropFactor.get_shape()[-1]
|
| 200 |
+
|
| 201 |
+
Oh = bpropFactor.get_shape()[1]
|
| 202 |
+
Ow = bpropFactor.get_shape()[2]
|
| 203 |
+
if Oh == 1 and Ow == 1 and self._channel_fac:
|
| 204 |
+
# factorization along the channels do not support
|
| 205 |
+
# homogeneous coordinate
|
| 206 |
+
var_assnBias = factors[var]['assnBias']
|
| 207 |
+
if var_assnBias:
|
| 208 |
+
factors[var]['assnBias'] = None
|
| 209 |
+
factors[var_assnBias]['assnWeights'] = None
|
| 210 |
+
##
|
| 211 |
+
|
| 212 |
+
for var in varlist:
|
| 213 |
+
fpropFactor = factors[var]['fpropFactors_concat']
|
| 214 |
+
bpropFactor = factors[var]['bpropFactors_concat']
|
| 215 |
+
opType = factors[var]['opName']
|
| 216 |
+
self.stats[var] = {'opName': opType,
|
| 217 |
+
'fprop_concat_stats': [],
|
| 218 |
+
'bprop_concat_stats': [],
|
| 219 |
+
'assnWeights': factors[var]['assnWeights'],
|
| 220 |
+
'assnBias': factors[var]['assnBias'],
|
| 221 |
+
}
|
| 222 |
+
if fpropFactor is not None:
|
| 223 |
+
if fpropFactor not in tmpStatsCache:
|
| 224 |
+
if opType == 'Conv2D':
|
| 225 |
+
Kh = var.get_shape()[0]
|
| 226 |
+
Kw = var.get_shape()[1]
|
| 227 |
+
C = fpropFactor.get_shape()[-1]
|
| 228 |
+
|
| 229 |
+
Oh = bpropFactor.get_shape()[1]
|
| 230 |
+
Ow = bpropFactor.get_shape()[2]
|
| 231 |
+
if Oh == 1 and Ow == 1 and self._channel_fac:
|
| 232 |
+
# factorization along the channels
|
| 233 |
+
# assume independence between input channels and spatial
|
| 234 |
+
# 2K-1 x 2K-1 covariance matrix and C x C covariance matrix
|
| 235 |
+
# factorization along the channels do not
|
| 236 |
+
# support homogeneous coordinate, assnBias
|
| 237 |
+
# is always None
|
| 238 |
+
fpropFactor2_size = Kh * Kw
|
| 239 |
+
slot_fpropFactor_stats2 = tf.Variable(tf.linalg.tensor_diag(tf.ones(
|
| 240 |
+
[fpropFactor2_size])) * self._diag_init_coeff, name='KFAC_STATS/' + fpropFactor.op.name, trainable=False)
|
| 241 |
+
self.stats[var]['fprop_concat_stats'].append(
|
| 242 |
+
slot_fpropFactor_stats2)
|
| 243 |
+
|
| 244 |
+
fpropFactor_size = C
|
| 245 |
+
else:
|
| 246 |
+
# 2K-1 x 2K-1 x C x C covariance matrix
|
| 247 |
+
# assume BHWC
|
| 248 |
+
fpropFactor_size = Kh * Kw * C
|
| 249 |
+
else:
|
| 250 |
+
# D x D covariance matrix
|
| 251 |
+
fpropFactor_size = fpropFactor.get_shape()[-1]
|
| 252 |
+
|
| 253 |
+
# use homogeneous coordinate
|
| 254 |
+
if not self._blockdiag_bias and self.stats[var]['assnBias']:
|
| 255 |
+
fpropFactor_size += 1
|
| 256 |
+
|
| 257 |
+
slot_fpropFactor_stats = tf.Variable(tf.linalg.tensor_diag(tf.ones(
|
| 258 |
+
[fpropFactor_size])) * self._diag_init_coeff, name='KFAC_STATS/' + fpropFactor.op.name, trainable=False)
|
| 259 |
+
self.stats[var]['fprop_concat_stats'].append(
|
| 260 |
+
slot_fpropFactor_stats)
|
| 261 |
+
if opType != 'Conv2D':
|
| 262 |
+
tmpStatsCache[fpropFactor] = self.stats[
|
| 263 |
+
var]['fprop_concat_stats']
|
| 264 |
+
else:
|
| 265 |
+
self.stats[var][
|
| 266 |
+
'fprop_concat_stats'] = tmpStatsCache[fpropFactor]
|
| 267 |
+
|
| 268 |
+
if bpropFactor is not None:
|
| 269 |
+
# no need to collect backward stats for bias vectors if
|
| 270 |
+
# using homogeneous coordinates
|
| 271 |
+
if not((not self._blockdiag_bias) and self.stats[var]['assnWeights']):
|
| 272 |
+
if bpropFactor not in tmpStatsCache:
|
| 273 |
+
slot_bpropFactor_stats = tf.Variable(tf.linalg.tensor_diag(tf.ones([bpropFactor.get_shape(
|
| 274 |
+
)[-1]])) * self._diag_init_coeff, name='KFAC_STATS/' + bpropFactor.op.name, trainable=False)
|
| 275 |
+
self.stats[var]['bprop_concat_stats'].append(
|
| 276 |
+
slot_bpropFactor_stats)
|
| 277 |
+
tmpStatsCache[bpropFactor] = self.stats[
|
| 278 |
+
var]['bprop_concat_stats']
|
| 279 |
+
else:
|
| 280 |
+
self.stats[var][
|
| 281 |
+
'bprop_concat_stats'] = tmpStatsCache[bpropFactor]
|
| 282 |
+
|
| 283 |
+
return self.stats
|
| 284 |
+
|
| 285 |
+
def compute_and_apply_stats(self, loss_sampled, var_list=None):
|
| 286 |
+
varlist = var_list
|
| 287 |
+
if varlist is None:
|
| 288 |
+
varlist = tf.compat.v1.trainable_variables()
|
| 289 |
+
|
| 290 |
+
stats = self.compute_stats(loss_sampled, var_list=varlist)
|
| 291 |
+
return self.apply_stats(stats)
|
| 292 |
+
|
| 293 |
+
def compute_stats(self, loss_sampled, var_list=None):
|
| 294 |
+
varlist = var_list
|
| 295 |
+
if varlist is None:
|
| 296 |
+
varlist = tf.compat.v1.trainable_variables()
|
| 297 |
+
|
| 298 |
+
gs = tf.gradients(ys=loss_sampled, xs=varlist, name='gradientsSampled')
|
| 299 |
+
self.gs = gs
|
| 300 |
+
factors = self.getFactors(gs, varlist)
|
| 301 |
+
stats = self.getStats(factors, varlist)
|
| 302 |
+
|
| 303 |
+
updateOps = []
|
| 304 |
+
statsUpdates = {}
|
| 305 |
+
statsUpdates_cache = {}
|
| 306 |
+
for var in varlist:
|
| 307 |
+
opType = factors[var]['opName']
|
| 308 |
+
fops = factors[var]['op']
|
| 309 |
+
fpropFactor = factors[var]['fpropFactors_concat']
|
| 310 |
+
fpropStats_vars = stats[var]['fprop_concat_stats']
|
| 311 |
+
bpropFactor = factors[var]['bpropFactors_concat']
|
| 312 |
+
bpropStats_vars = stats[var]['bprop_concat_stats']
|
| 313 |
+
SVD_factors = {}
|
| 314 |
+
for stats_var in fpropStats_vars:
|
| 315 |
+
stats_var_dim = int(stats_var.get_shape()[0])
|
| 316 |
+
if stats_var not in statsUpdates_cache:
|
| 317 |
+
old_fpropFactor = fpropFactor
|
| 318 |
+
B = (tf.shape(input=fpropFactor)[0]) # batch size
|
| 319 |
+
if opType == 'Conv2D':
|
| 320 |
+
strides = fops.get_attr("strides")
|
| 321 |
+
padding = fops.get_attr("padding")
|
| 322 |
+
convkernel_size = var.get_shape()[0:3]
|
| 323 |
+
|
| 324 |
+
KH = int(convkernel_size[0])
|
| 325 |
+
KW = int(convkernel_size[1])
|
| 326 |
+
C = int(convkernel_size[2])
|
| 327 |
+
flatten_size = int(KH * KW * C)
|
| 328 |
+
|
| 329 |
+
Oh = int(bpropFactor.get_shape()[1])
|
| 330 |
+
Ow = int(bpropFactor.get_shape()[2])
|
| 331 |
+
|
| 332 |
+
if Oh == 1 and Ow == 1 and self._channel_fac:
|
| 333 |
+
# factorization along the channels
|
| 334 |
+
# assume independence among input channels
|
| 335 |
+
# factor = B x 1 x 1 x (KH xKW x C)
|
| 336 |
+
# patches = B x Oh x Ow x (KH xKW x C)
|
| 337 |
+
if len(SVD_factors) == 0:
|
| 338 |
+
if KFAC_DEBUG:
|
| 339 |
+
print(('approx %s act factor with rank-1 SVD factors' % (var.name)))
|
| 340 |
+
# find closest rank-1 approx to the feature map
|
| 341 |
+
S, U, V = tf.batch_svd(tf.reshape(
|
| 342 |
+
fpropFactor, [-1, KH * KW, C]))
|
| 343 |
+
# get rank-1 approx slides
|
| 344 |
+
sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
|
| 345 |
+
patches_k = U[:, :, 0] * sqrtS1 # B x KH*KW
|
| 346 |
+
full_factor_shape = fpropFactor.get_shape()
|
| 347 |
+
patches_k.set_shape(
|
| 348 |
+
[full_factor_shape[0], KH * KW])
|
| 349 |
+
patches_c = V[:, :, 0] * sqrtS1 # B x C
|
| 350 |
+
patches_c.set_shape([full_factor_shape[0], C])
|
| 351 |
+
SVD_factors[C] = patches_c
|
| 352 |
+
SVD_factors[KH * KW] = patches_k
|
| 353 |
+
fpropFactor = SVD_factors[stats_var_dim]
|
| 354 |
+
|
| 355 |
+
else:
|
| 356 |
+
# poor mem usage implementation
|
| 357 |
+
patches = tf.image.extract_patches(fpropFactor, sizes=[1, convkernel_size[
|
| 358 |
+
0], convkernel_size[1], 1], strides=strides, rates=[1, 1, 1, 1], padding=padding)
|
| 359 |
+
|
| 360 |
+
if self._approxT2:
|
| 361 |
+
if KFAC_DEBUG:
|
| 362 |
+
print(('approxT2 act fisher for %s' % (var.name)))
|
| 363 |
+
# T^2 terms * 1/T^2, size: B x C
|
| 364 |
+
fpropFactor = tf.reduce_mean(input_tensor=patches, axis=[1, 2])
|
| 365 |
+
else:
|
| 366 |
+
# size: (B x Oh x Ow) x C
|
| 367 |
+
fpropFactor = tf.reshape(
|
| 368 |
+
patches, [-1, flatten_size]) / Oh / Ow
|
| 369 |
+
fpropFactor_size = int(fpropFactor.get_shape()[-1])
|
| 370 |
+
if stats_var_dim == (fpropFactor_size + 1) and not self._blockdiag_bias:
|
| 371 |
+
if opType == 'Conv2D' and not self._approxT2:
|
| 372 |
+
# correct padding for numerical stability (we
|
| 373 |
+
# divided out OhxOw from activations for T1 approx)
|
| 374 |
+
fpropFactor = tf.concat([fpropFactor, tf.ones(
|
| 375 |
+
[tf.shape(input=fpropFactor)[0], 1]) / Oh / Ow], 1)
|
| 376 |
+
else:
|
| 377 |
+
# use homogeneous coordinates
|
| 378 |
+
fpropFactor = tf.concat(
|
| 379 |
+
[fpropFactor, tf.ones([tf.shape(input=fpropFactor)[0], 1])], 1)
|
| 380 |
+
|
| 381 |
+
# average over the number of data points in a batch
|
| 382 |
+
# divided by B
|
| 383 |
+
cov = tf.matmul(fpropFactor, fpropFactor,
|
| 384 |
+
transpose_a=True) / tf.cast(B, tf.float32)
|
| 385 |
+
updateOps.append(cov)
|
| 386 |
+
statsUpdates[stats_var] = cov
|
| 387 |
+
if opType != 'Conv2D':
|
| 388 |
+
# HACK: for convolution we recompute fprop stats for
|
| 389 |
+
# every layer including forking layers
|
| 390 |
+
statsUpdates_cache[stats_var] = cov
|
| 391 |
+
|
| 392 |
+
for stats_var in bpropStats_vars:
|
| 393 |
+
stats_var_dim = int(stats_var.get_shape()[0])
|
| 394 |
+
if stats_var not in statsUpdates_cache:
|
| 395 |
+
old_bpropFactor = bpropFactor
|
| 396 |
+
bpropFactor_shape = bpropFactor.get_shape()
|
| 397 |
+
B = tf.shape(input=bpropFactor)[0] # batch size
|
| 398 |
+
C = int(bpropFactor_shape[-1]) # num channels
|
| 399 |
+
if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
|
| 400 |
+
if fpropFactor is not None:
|
| 401 |
+
if self._approxT2:
|
| 402 |
+
if KFAC_DEBUG:
|
| 403 |
+
print(('approxT2 grad fisher for %s' % (var.name)))
|
| 404 |
+
bpropFactor = tf.reduce_sum(
|
| 405 |
+
input_tensor=bpropFactor, axis=[1, 2]) # T^2 terms * 1/T^2
|
| 406 |
+
else:
|
| 407 |
+
bpropFactor = tf.reshape(
|
| 408 |
+
bpropFactor, [-1, C]) * Oh * Ow # T * 1/T terms
|
| 409 |
+
else:
|
| 410 |
+
# just doing block diag approx. spatial independent
|
| 411 |
+
# structure does not apply here. summing over
|
| 412 |
+
# spatial locations
|
| 413 |
+
if KFAC_DEBUG:
|
| 414 |
+
print(('block diag approx fisher for %s' % (var.name)))
|
| 415 |
+
bpropFactor = tf.reduce_sum(input_tensor=bpropFactor, axis=[1, 2])
|
| 416 |
+
|
| 417 |
+
# assume sampled loss is averaged. TO-DO:figure out better
|
| 418 |
+
# way to handle this
|
| 419 |
+
bpropFactor *= tf.cast(B, dtype=tf.float32)
|
| 420 |
+
##
|
| 421 |
+
|
| 422 |
+
cov_b = tf.matmul(
|
| 423 |
+
bpropFactor, bpropFactor, transpose_a=True) / tf.cast(tf.shape(input=bpropFactor)[0], dtype=tf.float32)
|
| 424 |
+
|
| 425 |
+
updateOps.append(cov_b)
|
| 426 |
+
statsUpdates[stats_var] = cov_b
|
| 427 |
+
statsUpdates_cache[stats_var] = cov_b
|
| 428 |
+
|
| 429 |
+
if KFAC_DEBUG:
|
| 430 |
+
aKey = list(statsUpdates.keys())[0]
|
| 431 |
+
statsUpdates[aKey] = tf.compat.v1.Print(statsUpdates[aKey],
|
| 432 |
+
[tf.convert_to_tensor(value='step:'),
|
| 433 |
+
self.global_step,
|
| 434 |
+
tf.convert_to_tensor(
|
| 435 |
+
value='computing stats'),
|
| 436 |
+
])
|
| 437 |
+
self.statsUpdates = statsUpdates
|
| 438 |
+
return statsUpdates
|
| 439 |
+
|
| 440 |
+
def apply_stats(self, statsUpdates):
|
| 441 |
+
""" compute stats and update/apply the new stats to the running average
|
| 442 |
+
"""
|
| 443 |
+
|
| 444 |
+
def updateAccumStats():
|
| 445 |
+
if self._full_stats_init:
|
| 446 |
+
return tf.cond(pred=tf.greater(self.sgd_step, self._cold_iter), true_fn=lambda: tf.group(*self._apply_stats(statsUpdates, accumulate=True, accumulateCoeff=1. / self._stats_accum_iter)), false_fn=tf.no_op)
|
| 447 |
+
else:
|
| 448 |
+
return tf.group(*self._apply_stats(statsUpdates, accumulate=True, accumulateCoeff=1. / self._stats_accum_iter))
|
| 449 |
+
|
| 450 |
+
def updateRunningAvgStats(statsUpdates, fac_iter=1):
|
| 451 |
+
# return tf.cond(tf.greater_equal(self.factor_step,
|
| 452 |
+
# tf.convert_to_tensor(fac_iter)), lambda:
|
| 453 |
+
# tf.group(*self._apply_stats(stats_list, varlist)), tf.no_op)
|
| 454 |
+
return tf.group(*self._apply_stats(statsUpdates))
|
| 455 |
+
|
| 456 |
+
if self._async_stats:
|
| 457 |
+
# asynchronous stats update
|
| 458 |
+
update_stats = self._apply_stats(statsUpdates)
|
| 459 |
+
|
| 460 |
+
queue = tf.queue.FIFOQueue(1, [item.dtype for item in update_stats], shapes=[
|
| 461 |
+
item.get_shape() for item in update_stats])
|
| 462 |
+
enqueue_op = queue.enqueue(update_stats)
|
| 463 |
+
|
| 464 |
+
def dequeue_stats_op():
|
| 465 |
+
return queue.dequeue()
|
| 466 |
+
self.qr_stats = tf.compat.v1.train.QueueRunner(queue, [enqueue_op])
|
| 467 |
+
update_stats_op = tf.cond(pred=tf.equal(queue.size(), tf.convert_to_tensor(
|
| 468 |
+
value=0)), true_fn=tf.no_op, false_fn=lambda: tf.group(*[dequeue_stats_op(), ]))
|
| 469 |
+
else:
|
| 470 |
+
# synchronous stats update
|
| 471 |
+
update_stats_op = tf.cond(pred=tf.greater_equal(
|
| 472 |
+
self.stats_step, self._stats_accum_iter), true_fn=lambda: updateRunningAvgStats(statsUpdates), false_fn=updateAccumStats)
|
| 473 |
+
self._update_stats_op = update_stats_op
|
| 474 |
+
return update_stats_op
|
| 475 |
+
|
| 476 |
+
def _apply_stats(self, statsUpdates, accumulate=False, accumulateCoeff=0.):
|
| 477 |
+
updateOps = []
|
| 478 |
+
# obtain the stats var list
|
| 479 |
+
for stats_var in statsUpdates:
|
| 480 |
+
stats_new = statsUpdates[stats_var]
|
| 481 |
+
if accumulate:
|
| 482 |
+
# simple superbatch averaging
|
| 483 |
+
update_op = tf.compat.v1.assign_add(
|
| 484 |
+
stats_var, accumulateCoeff * stats_new, use_locking=True)
|
| 485 |
+
else:
|
| 486 |
+
# exponential running averaging
|
| 487 |
+
update_op = tf.compat.v1.assign(
|
| 488 |
+
stats_var, stats_var * self._stats_decay, use_locking=True)
|
| 489 |
+
update_op = tf.compat.v1.assign_add(
|
| 490 |
+
update_op, (1. - self._stats_decay) * stats_new, use_locking=True)
|
| 491 |
+
updateOps.append(update_op)
|
| 492 |
+
|
| 493 |
+
with tf.control_dependencies(updateOps):
|
| 494 |
+
stats_step_op = tf.compat.v1.assign_add(self.stats_step, 1)
|
| 495 |
+
|
| 496 |
+
if KFAC_DEBUG:
|
| 497 |
+
stats_step_op = (tf.compat.v1.Print(stats_step_op,
|
| 498 |
+
[tf.convert_to_tensor(value='step:'),
|
| 499 |
+
self.global_step,
|
| 500 |
+
tf.convert_to_tensor(value='fac step:'),
|
| 501 |
+
self.factor_step,
|
| 502 |
+
tf.convert_to_tensor(value='sgd step:'),
|
| 503 |
+
self.sgd_step,
|
| 504 |
+
tf.convert_to_tensor(value='Accum:'),
|
| 505 |
+
tf.convert_to_tensor(value=accumulate),
|
| 506 |
+
tf.convert_to_tensor(value='Accum coeff:'),
|
| 507 |
+
tf.convert_to_tensor(value=accumulateCoeff),
|
| 508 |
+
tf.convert_to_tensor(value='stat step:'),
|
| 509 |
+
self.stats_step, updateOps[0], updateOps[1]]))
|
| 510 |
+
return [stats_step_op, ]
|
| 511 |
+
|
| 512 |
+
def getStatsEigen(self, stats=None):
|
| 513 |
+
if len(self.stats_eigen) == 0:
|
| 514 |
+
stats_eigen = {}
|
| 515 |
+
if stats is None:
|
| 516 |
+
stats = self.stats
|
| 517 |
+
|
| 518 |
+
tmpEigenCache = {}
|
| 519 |
+
with tf.device('/cpu:0'):
|
| 520 |
+
for var in stats:
|
| 521 |
+
for key in ['fprop_concat_stats', 'bprop_concat_stats']:
|
| 522 |
+
for stats_var in stats[var][key]:
|
| 523 |
+
if stats_var not in tmpEigenCache:
|
| 524 |
+
stats_dim = stats_var.get_shape()[1].value
|
| 525 |
+
e = tf.Variable(tf.ones(
|
| 526 |
+
[stats_dim]), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/e', trainable=False)
|
| 527 |
+
Q = tf.Variable(tf.linalg.tensor_diag(tf.ones(
|
| 528 |
+
[stats_dim])), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/Q', trainable=False)
|
| 529 |
+
stats_eigen[stats_var] = {'e': e, 'Q': Q}
|
| 530 |
+
tmpEigenCache[
|
| 531 |
+
stats_var] = stats_eigen[stats_var]
|
| 532 |
+
else:
|
| 533 |
+
stats_eigen[stats_var] = tmpEigenCache[
|
| 534 |
+
stats_var]
|
| 535 |
+
self.stats_eigen = stats_eigen
|
| 536 |
+
return self.stats_eigen
|
| 537 |
+
|
| 538 |
+
def computeStatsEigen(self):
    """Build the ops that eigen-decompose every factor-statistics variable.

    Returns a flat list [e0, Q0, e1, Q1, ...] of eigenvalue/eigenvector
    tensors (one pair per unique stats variable).  Side effects:
    populates self.eigen_reverse_lookup (computed tensor -> storage
    variable created by getStatsEigen) and self.eigen_update_list, both
    consumed later by applyStatsEigen.

    Fix applied: removed the dead inner helpers `removeNone` and
    `copyStats` — their only call sites were commented out, so they were
    never invoked.  (The docstring's original mention of "copied var
    stats" refers to that disabled copying; decomposition currently reads
    the stats variables directly.)
    """
    # TO-DO: figure out why this op has delays (possibly moving
    # eigenvectors around?)
    with tf.device('/cpu:0'):
        stats_eigen = self.stats_eigen
        computedEigen = {}
        eigen_reverse_lookup = {}
        updateOps = []
        # control_dependencies([]) kept from the original: it previously
        # synchronised on the (now disabled) stat-copy ops
        with tf.control_dependencies([]):
            for stats_var in stats_eigen:
                if stats_var not in computedEigen:
                    eigens = tf.linalg.eigh(stats_var)
                    e = eigens[0]
                    Q = eigens[1]
                    if self._use_float64:
                        # decomposition may run in float64; results are
                        # stored back in float32
                        e = tf.cast(e, tf.float32)
                        Q = tf.cast(Q, tf.float32)
                    updateOps.append(e)
                    updateOps.append(Q)
                    computedEigen[stats_var] = {'e': e, 'Q': Q}
                    eigen_reverse_lookup[e] = stats_eigen[stats_var]['e']
                    eigen_reverse_lookup[Q] = stats_eigen[stats_var]['Q']

        self.eigen_reverse_lookup = eigen_reverse_lookup
        self.eigen_update_list = updateOps

        if KFAC_DEBUG:
            # keep a copy without the Print op so applyStatsEigen's
            # zip against eigen_update_list stays aligned
            self.eigen_update_list = [item for item in updateOps]
            with tf.control_dependencies(updateOps):
                updateOps.append(tf.compat.v1.Print(tf.constant(
                    0.), [tf.convert_to_tensor(value='computed factor eigen')]))

    return updateOps
|
| 601 |
+
|
| 602 |
+
def applyStatsEigen(self, eigen_list):
    """Assign freshly computed eigenvalues/eigenvectors back into their
    storage variables and bump self.factor_step.

    `eigen_list` must be ordered exactly like self.eigen_update_list
    (as produced by computeStatsEigen); each tensor is matched to its
    destination variable through self.eigen_reverse_lookup.
    Returns the list of assignment ops (plus the step increment).
    """
    print(('updating %d eigenvalue/vectors' % len(eigen_list)))
    assignOps = []
    for tensor, marker in zip(eigen_list, self.eigen_update_list):
        target = self.eigen_reverse_lookup[marker]
        assignOps.append(
            tf.compat.v1.assign(target, tensor, use_locking=True))

    with tf.control_dependencies(assignOps):
        assignOps.append(tf.compat.v1.assign_add(self.factor_step, 1))
        if KFAC_DEBUG:
            assignOps.append(tf.compat.v1.Print(tf.constant(
                0.), [tf.convert_to_tensor(value='updated kfac factors')]))
    return assignOps
|
| 617 |
+
|
| 618 |
+
def getKfacPrecondUpdates(self, gradlist, varlist):
    """Precondition gradients with the K-FAC approximate Fisher inverse.

    For every variable with registered factor statistics: reshape the
    gradient to 2-D (or 3-D for channel-factored conv kernels), project
    it into the factor eigenbasis, divide by damped eigenvalue products,
    project back, then rescale everything by a global vFv ("KL") clip.
    Returns the list of preconditioned gradients, ordered like `varlist`.

    Fixes applied: removed the unused locals `GRAD_TRANSPOSE` and the
    dead initial `updatelist = []` (it was unconditionally rebound before
    use).
    """
    vg = 0.

    assert len(self.stats) > 0
    assert len(self.stats_eigen) > 0
    assert len(self.factors) > 0
    counter = 0

    grad_dict = {var: grad for grad, var in zip(gradlist, varlist)}

    for grad, var in zip(gradlist, varlist):
        GRAD_RESHAPE = False

        fpropFactoredFishers = self.stats[var]['fprop_concat_stats']
        bpropFactoredFishers = self.stats[var]['bprop_concat_stats']

        if (len(fpropFactoredFishers) + len(bpropFactoredFishers)) > 0:
            counter += 1
            GRAD_SHAPE = grad.get_shape()
            if len(grad.get_shape()) > 2:
                # reshape conv kernel parameters
                KW = int(grad.get_shape()[0])
                KH = int(grad.get_shape()[1])
                C = int(grad.get_shape()[2])
                D = int(grad.get_shape()[3])

                if len(fpropFactoredFishers) > 1 and self._channel_fac:
                    # reshape conv kernel parameters into tensor
                    grad = tf.reshape(grad, [KW * KH, C, D])
                else:
                    # reshape conv kernel parameters into 2D grad
                    grad = tf.reshape(grad, [-1, D])
                GRAD_RESHAPE = True
            elif len(grad.get_shape()) == 1:
                # reshape bias or 1D parameters
                D = int(grad.get_shape()[0])
                grad = tf.expand_dims(grad, 0)
                GRAD_RESHAPE = True
            else:
                # 2D parameters
                C = int(grad.get_shape()[0])
                D = int(grad.get_shape()[1])

            if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
                # use homogeneous coordinates only works for 2D grad.
                # TO-DO: figure out how to factorize bias grad
                # stack bias grad onto the weight grad
                var_assnBias = self.stats[var]['assnBias']
                grad = tf.concat(
                    [grad, tf.expand_dims(grad_dict[var_assnBias], 0)], 0)

            # project gradient to eigen space and reshape the eigenvalues
            # for broadcasting
            eigVals = []

            for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']):
                Q = self.stats_eigen[stats]['Q']
                e = detectMinVal(self.stats_eigen[stats][
                                 'e'], var, name='act', debug=KFAC_DEBUG)

                Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='act')
                eigVals.append(e)
                grad = gmatmul(Q, grad, transpose_a=True, reduce_dim=idx)

            for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']):
                Q = self.stats_eigen[stats]['Q']
                e = detectMinVal(self.stats_eigen[stats][
                                 'e'], var, name='grad', debug=KFAC_DEBUG)

                Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='grad')
                eigVals.append(e)
                grad = gmatmul(grad, Q, transpose_b=False, reduce_dim=idx)

            #####
            # whiten using eigenvalues
            weightDecayCoeff = 0.
            if var in self._weight_decay_dict:
                weightDecayCoeff = self._weight_decay_dict[var]
                if KFAC_DEBUG:
                    print(('weight decay coeff for %s is %f' % (var.name, weightDecayCoeff)))

            if self._factored_damping:
                if KFAC_DEBUG:
                    print(('use factored damping for %s' % (var.name)))
                coeffs = 1.
                num_factors = len(eigVals)
                # compute the ratio of two trace norm of the left and right
                # KFac matrices, and their generalization
                if len(eigVals) == 1:
                    damping = self._epsilon + weightDecayCoeff
                else:
                    damping = tf.pow(
                        self._epsilon + weightDecayCoeff, 1. / num_factors)
                eigVals_tnorm_avg = [tf.reduce_mean(
                    input_tensor=tf.abs(e)) for e in eigVals]
                for e, e_tnorm in zip(eigVals, eigVals_tnorm_avg):
                    eig_tnorm_negList = [
                        item for item in eigVals_tnorm_avg if item != e_tnorm]
                    if len(eigVals) == 1:
                        adjustment = 1.
                    elif len(eigVals) == 2:
                        adjustment = tf.sqrt(
                            e_tnorm / eig_tnorm_negList[0])
                    else:
                        eig_tnorm_negList_prod = reduce(
                            lambda x, y: x * y, eig_tnorm_negList)
                        adjustment = tf.pow(
                            tf.pow(e_tnorm, num_factors - 1.) / eig_tnorm_negList_prod, 1. / num_factors)
                    coeffs *= (e + adjustment * damping)
            else:
                coeffs = 1.
                damping = (self._epsilon + weightDecayCoeff)
                for e in eigVals:
                    coeffs *= e
                coeffs += damping

            grad /= coeffs

            #####
            # project gradient back to euclidean space
            for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']):
                Q = self.stats_eigen[stats]['Q']
                grad = gmatmul(Q, grad, transpose_a=False, reduce_dim=idx)

            for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']):
                Q = self.stats_eigen[stats]['Q']
                grad = gmatmul(grad, Q, transpose_b=True, reduce_dim=idx)

            if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
                # un-stack the bias grad that was appended above
                var_assnBias = self.stats[var]['assnBias']
                C_plus_one = int(grad.get_shape()[0])
                grad_assnBias = tf.reshape(tf.slice(grad,
                                                    begin=[
                                                        C_plus_one - 1, 0],
                                                    size=[1, -1]), var_assnBias.get_shape())
                grad_assnWeights = tf.slice(grad,
                                            begin=[0, 0],
                                            size=[C_plus_one - 1, -1])
                grad_dict[var_assnBias] = grad_assnBias
                grad = grad_assnWeights

            if GRAD_RESHAPE:
                grad = tf.reshape(grad, GRAD_SHAPE)

            grad_dict[var] = grad

    print(('projecting %d gradient matrices' % counter))

    for g, var in zip(gradlist, varlist):
        grad = grad_dict[var]
        ### clipping ###
        if KFAC_DEBUG:
            print(('apply clipping to %s' % (var.name)))
            tf.compat.v1.Print(grad, [tf.sqrt(tf.reduce_sum(input_tensor=tf.pow(grad, 2)))], "Euclidean norm of new grad")
        local_vg = tf.reduce_sum(input_tensor=grad * g * (self._lr * self._lr))
        vg += local_vg

    # rescale everything
    if KFAC_DEBUG:
        print('apply vFv clipping')

    scaling = tf.minimum(1., tf.sqrt(self._clip_kl / vg))
    if KFAC_DEBUG:
        scaling = tf.compat.v1.Print(scaling, [tf.convert_to_tensor(
            value='clip: '), scaling, tf.convert_to_tensor(value=' vFv: '), vg])
    with tf.control_dependencies([tf.compat.v1.assign(self.vFv, vg)]):
        updatelist = [grad_dict[var] for var in varlist]
        for i, item in enumerate(updatelist):
            updatelist[i] = scaling * item

    return updatelist
|
| 802 |
+
|
| 803 |
+
def compute_gradients(self, loss, var_list=None):
    """Return (gradient, variable) pairs of `loss` w.r.t. `var_list`.

    When `var_list` is None, all trainable variables are used.
    """
    if var_list is None:
        var_list = tf.compat.v1.trainable_variables()
    grads = tf.gradients(ys=loss, xs=var_list)
    return list(zip(grads, var_list))
|
| 810 |
+
|
| 811 |
+
def apply_gradients_kfac(self, grads):
    """Build the K-FAC update op for (grad, var) pairs.

    Chains: global-step increment -> stats update -> periodic factor
    eigen-decomposition (synchronous, or asynchronous via a FIFOQueue
    serviced by the returned QueueRunner) -> preconditioned momentum
    update.  Returns (group_op, queue_runner); queue_runner is None in
    synchronous mode.

    Fix applied: `assert self._update_stats_op != None` replaced with the
    idiomatic `is not None` identity check.
    """
    g, varlist = list(zip(*grads))

    if len(self.stats_eigen) == 0:
        self.getStatsEigen()

    qr = None
    # launch eigen-decomp on a queue thread
    if self._async:
        print('Use async eigen decomp')
        # get a list of factor loading tensors
        factorOps_dummy = self.computeStatsEigen()

        # define a queue for the list of factor loading tensors
        queue = tf.queue.FIFOQueue(1, [item.dtype for item in factorOps_dummy], shapes=[
            item.get_shape() for item in factorOps_dummy])
        enqueue_op = tf.cond(pred=tf.logical_and(tf.equal(tf.math.floormod(self.stats_step, self._kfac_update), tf.convert_to_tensor(
            value=0)), tf.greater_equal(self.stats_step, self._stats_accum_iter)), true_fn=lambda: queue.enqueue(self.computeStatsEigen()), false_fn=tf.no_op)

        def dequeue_op():
            return queue.dequeue()

        qr = tf.compat.v1.train.QueueRunner(queue, [enqueue_op])

    updateOps = []
    global_step_op = tf.compat.v1.assign_add(self.global_step, 1)
    updateOps.append(global_step_op)

    with tf.control_dependencies([global_step_op]):

        # compute updates; the stats op must have been built by
        # compute_and_apply_stats before this point
        assert self._update_stats_op is not None
        updateOps.append(self._update_stats_op)
        dependency_list = []
        if not self._async:
            dependency_list.append(self._update_stats_op)

        with tf.control_dependencies(dependency_list):
            def no_op_wrapper():
                return tf.group(*[tf.compat.v1.assign_add(self.cold_step, 1)])

            if not self._async:
                # synchronous eigen-decomp updates every _kfac_update steps
                updateFactorOps = tf.cond(pred=tf.logical_and(tf.equal(tf.math.floormod(self.stats_step, self._kfac_update),
                                                                       tf.convert_to_tensor(value=0)),
                                                              tf.greater_equal(self.stats_step, self._stats_accum_iter)), true_fn=lambda: tf.group(*self.applyStatsEigen(self.computeStatsEigen())), false_fn=no_op_wrapper)
            else:
                # asynchronous eigen-decomp updates consumed from the queue
                updateFactorOps = tf.cond(pred=tf.greater_equal(self.stats_step, self._stats_accum_iter),
                                          true_fn=lambda: tf.cond(pred=tf.equal(queue.size(), tf.convert_to_tensor(value=0)),
                                                                  true_fn=tf.no_op,

                                                                  false_fn=lambda: tf.group(
                                                                      *self.applyStatsEigen(dequeue_op())),
                                                                  ),
                                          false_fn=no_op_wrapper)

            updateOps.append(updateFactorOps)

            with tf.control_dependencies([updateFactorOps]):
                def gradOp():
                    return list(g)

                def getKfacGradOp():
                    return self.getKfacPrecondUpdates(g, varlist)
                # raw gradients until at least one factor update happened
                u = tf.cond(pred=tf.greater(self.factor_step,
                                            tf.convert_to_tensor(value=0)), true_fn=getKfacGradOp, false_fn=gradOp)

                optim = tf.compat.v1.train.MomentumOptimizer(
                    self._lr * (1. - self._momentum), self._momentum)

                def optimOp():
                    def updateOptimOp():
                        if self._full_stats_init:
                            return tf.cond(pred=tf.greater(self.factor_step, tf.convert_to_tensor(value=0)), true_fn=lambda: optim.apply_gradients(list(zip(u, varlist))), false_fn=tf.no_op)
                        else:
                            return optim.apply_gradients(list(zip(u, varlist)))
                    if self._full_stats_init:
                        return tf.cond(pred=tf.greater_equal(self.stats_step, self._stats_accum_iter), true_fn=updateOptimOp, false_fn=tf.no_op)
                    else:
                        return tf.cond(pred=tf.greater_equal(self.sgd_step, self._cold_iter), true_fn=updateOptimOp, false_fn=tf.no_op)
                updateOps.append(optimOp())

    return tf.group(*updateOps), qr
|
| 896 |
+
|
| 897 |
+
def apply_gradients(self, grads):
    """Apply (grad, var) pairs: plain momentum SGD for the first
    `_cold_iter` steps ("cold start"), K-FAC preconditioned updates
    afterwards.

    Returns (update_op, queue_runner); queue_runner is None unless the
    asynchronous eigen-decomposition path is enabled.

    Fix applied: `self.max_grad_norm != None` replaced with the idiomatic
    `is not None` identity check.
    """
    coldOptim = tf.compat.v1.train.MomentumOptimizer(
        self._cold_lr, self._momentum)

    def coldSGDstart():
        sgd_grads, sgd_var = zip(*grads)

        if self.max_grad_norm is not None:
            sgd_grads, sgd_grad_norm = tf.clip_by_global_norm(sgd_grads, self.max_grad_norm)

        sgd_grads = list(zip(sgd_grads, sgd_var))

        sgd_step_op = tf.compat.v1.assign_add(self.sgd_step, 1)
        coldOptim_op = coldOptim.apply_gradients(sgd_grads)
        if KFAC_DEBUG:
            with tf.control_dependencies([sgd_step_op, coldOptim_op]):
                sgd_step_op = tf.compat.v1.Print(
                    sgd_step_op, [self.sgd_step, tf.convert_to_tensor(value='doing cold sgd step')])
        return tf.group(*[sgd_step_op, coldOptim_op])

    kfacOptim_op, qr = self.apply_gradients_kfac(grads)

    def warmKFACstart():
        return kfacOptim_op

    return tf.cond(pred=tf.greater(self.sgd_step, self._cold_iter), true_fn=warmKFACstart, false_fn=coldSGDstart), qr
|
| 923 |
+
|
| 924 |
+
def minimize(self, loss, loss_sampled, var_list=None):
    """Convenience wrapper: compute gradients of `loss`, build the factor
    statistics from `loss_sampled`, and return the apply-gradients op.
    """
    grads = self.compute_gradients(loss, var_list=var_list)
    # NOTE(review): the return value of compute_and_apply_stats is unused
    # here; the call appears to be made for its graph-building side
    # effects (apply_gradients_kfac later asserts self._update_stats_op
    # is set) — confirm against its definition.
    self.compute_and_apply_stats(loss_sampled, var_list=var_list)
    return self.apply_gradients(grads)
|
baselines/acktr/kfac_utils.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
|
| 3 |
+
def gmatmul(a, b, transpose_a=False, transpose_b=False, reduce_dim=None):
    """Generalised matmul contracting a 2-D factor against one axis of a
    possibly higher-rank tensor.

    `reduce_dim` selects the contracted axis of the higher-rank operand
    (counted from the left when `b` is the big tensor, from the right
    when `a` is).  With two 2-D operands this is a plain tf.matmul.
    """
    assert reduce_dim is not None

    rank_a = len(a.get_shape())
    rank_b = len(b.get_shape())

    if rank_a == 2 and rank_b > 2:
        # move reduce_dim to the front of b, flatten the rest, matmul,
        # then undo the reshuffle
        b_shape = b.get_shape()
        if reduce_dim != 0:
            perm = list(range(len(b_shape)))
            perm.remove(reduce_dim)
            perm.insert(0, reduce_dim)
            b = tf.transpose(a=b, perm=perm)
        b_t_shape = b.get_shape()
        b = tf.reshape(b, [int(b_shape[reduce_dim]), -1])
        out = tf.matmul(a, b, transpose_a=transpose_a,
                        transpose_b=transpose_b)
        out = tf.reshape(out, b_t_shape)
        if reduce_dim != 0:
            perm = list(range(len(b_shape)))
            perm.remove(0)
            perm.insert(reduce_dim, 0)
            out = tf.transpose(a=out, perm=perm)
        return out

    if rank_a > 2 and rank_b == 2:
        # move reduce_dim (counted from the right) to the back of a,
        # flatten the rest, matmul, then undo the reshuffle
        a_shape = a.get_shape()
        outer_dim = len(a_shape) - 1
        reduce_dim = len(a_shape) - reduce_dim - 1
        if reduce_dim != outer_dim:
            perm = list(range(len(a_shape)))
            perm.remove(reduce_dim)
            perm.insert(outer_dim, reduce_dim)
            a = tf.transpose(a=a, perm=perm)
        a_t_shape = a.get_shape()
        a = tf.reshape(a, [-1, int(a_shape[reduce_dim])])
        out = tf.matmul(a, b, transpose_a=transpose_a,
                        transpose_b=transpose_b)
        out = tf.reshape(out, a_t_shape)
        if reduce_dim != outer_dim:
            perm = list(range(len(a_shape)))
            perm.remove(outer_dim)
            perm.insert(reduce_dim, outer_dim)
            out = tf.transpose(a=out, perm=perm)
        return out

    if rank_a == 2 and rank_b == 2:
        return tf.matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b)

    assert False, 'something went wrong'
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def clipoutNeg(vec, threshold=1e-6):
    """Zero out every entry of `vec` that is not strictly above `threshold`."""
    keep = tf.cast(vec > threshold, tf.float32)
    return keep * vec
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def detectMinVal(input_mat, var, threshold=1e-6, name='', debug=False):
    """Clip near-zero/negative eigenvalues out of `input_mat`; with
    debug=True, additionally wire a runtime Print warning keyed on the
    max/min eigenvalue ratio.
    """
    eigen_min = tf.reduce_min(input_tensor=input_mat)
    eigen_max = tf.reduce_max(input_tensor=input_mat)
    eigen_ratio = eigen_max / eigen_min
    input_mat_clipped = clipoutNeg(input_mat, threshold)

    if debug:
        # NOTE(review): the predicate warns only when the ratio falls in
        # [-500, 0], which looks inverted — preserved as-is from the
        # original; confirm the intended condition.
        input_mat_clipped = tf.cond(
            pred=tf.logical_or(tf.greater(eigen_ratio, 0.), tf.less(eigen_ratio, -500)),
            true_fn=lambda: input_mat_clipped,
            false_fn=lambda: tf.compat.v1.Print(input_mat_clipped, [
                tf.convert_to_tensor(value='screwed ratio ' + name + ' eigen values!!!'),
                tf.convert_to_tensor(value=var.name),
                eigen_min, eigen_max, eigen_ratio]))

    return input_mat_clipped
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def factorReshape(Q, e, grad, facIndx=0, ftype='act'):
    """Reshape eigenvalues `e` so they broadcast against `grad` along the
    factor axis; `Q` is returned unchanged.

    'act' factors align with axis `facIndx` counted from the left;
    'grad' factors align with axis `facIndx` counted from the right.
    """
    grad_shape = grad.get_shape()
    if ftype == 'act':
        assert e.get_shape()[0] == grad_shape[facIndx]
        bcast_shape = [1] * len(grad_shape)
        bcast_shape[facIndx] = -1
        e = tf.reshape(e, bcast_shape)
    if ftype == 'grad':
        axis = len(grad_shape) - facIndx - 1
        assert e.get_shape()[0] == grad_shape[axis]
        bcast_shape = [1] * len(grad_shape)
        bcast_shape[axis] = -1
        e = tf.reshape(e, bcast_shape)

    return Q, e
|
baselines/acktr/utils.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
|
| 3 |
+
def dense(x, size, name, weight_init=None, bias_init=0, weight_loss_dict=None, reuse=None):
    """Fully-connected layer: returns x @ w + b under variable scope `name`.

    When `weight_loss_dict` is given, a weight-decay loss on `w` is added
    to the '<top-scope>_losses' collection and the decay coefficients for
    `w` (3e-4) and `b` (0.0) are recorded in the dict.

    Fix applied: removed the redundant nested duplicate
    `if weight_loss_dict is not None:` check.
    """
    with tf.compat.v1.variable_scope(name, reuse=reuse):
        # layer is expected to live exactly one scope deep
        assert (len(tf.compat.v1.get_variable_scope().name.split('/')) == 2)

        w = tf.compat.v1.get_variable("w", [x.get_shape()[1], size], initializer=weight_init)
        b = tf.compat.v1.get_variable("b", [size], initializer=tf.compat.v1.constant_initializer(bias_init))
        weight_decay_fc = 3e-4

        if weight_loss_dict is not None:
            weight_decay = tf.multiply(tf.nn.l2_loss(w), weight_decay_fc, name='weight_decay_loss')
            weight_loss_dict[w] = weight_decay_fc
            weight_loss_dict[b] = 0.0

            tf.compat.v1.add_to_collection(tf.compat.v1.get_variable_scope().name.split('/')[0] + '_' + 'losses', weight_decay)

        return tf.nn.bias_add(tf.matmul(x, w), b)
|
| 20 |
+
|
| 21 |
+
def kl_div(action_dist1, action_dist2, action_size):
    """KL divergence between two diagonal Gaussian action distributions.

    Each distribution is packed as [means | stds] along the last axis,
    with `action_size` means followed by `action_size` stds; the KL is
    summed over action dimensions.
    """
    mean1 = action_dist1[:, :action_size]
    std1 = action_dist1[:, action_size:]
    mean2 = action_dist2[:, :action_size]
    std2 = action_dist2[:, action_size:]

    numerator = tf.square(mean1 - mean2) + tf.square(std1) - tf.square(std2)
    denominator = 2 * tf.square(std2) + 1e-8  # epsilon guards division by zero
    per_dim = numerator / denominator + tf.math.log(std2) - tf.math.log(std1)
    return tf.reduce_sum(input_tensor=per_dim, axis=-1)
|
baselines/bench/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa F403
|
| 2 |
+
from baselines.bench.benchmarks import *
|
| 3 |
+
from baselines.bench.monitor import *
|
baselines/bench/benchmarks.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
import os

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

_atari7 = ['BeamRider', 'Breakout', 'Enduro', 'Pong', 'Qbert', 'Seaquest', 'SpaceInvaders']
_atariexpl7 = ['Freeway', 'Gravitar', 'MontezumaRevenge', 'Pitfall', 'PrivateEye', 'Solaris', 'Venture']

# global registry; registration order matters for
# find_task_for_env_id_in_any_benchmark (first match wins)
_BENCHMARKS = []

remove_version_re = re.compile(r'-v\d+$')


def register_benchmark(benchmark):
    """Add `benchmark` to the registry, rejecting duplicate names and
    auto-filling a 'desc' for every task that lacks one (env id minus
    its version suffix)."""
    for existing in _BENCHMARKS:
        if existing['name'] == benchmark['name']:
            raise ValueError('Benchmark with name %s already registered!' % existing['name'])

    for task in benchmark.get('tasks', []):
        if 'desc' not in task:
            task['desc'] = remove_version_re.sub('', task.get('env_id', task.get('id')))
    _BENCHMARKS.append(benchmark)


def list_benchmarks():
    """Return the names of all registered benchmarks."""
    return [b['name'] for b in _BENCHMARKS]


def get_benchmark(benchmark_name):
    """Return the benchmark dict with the given name, or raise ValueError."""
    found = next((b for b in _BENCHMARKS if b['name'] == benchmark_name), None)
    if found is None:
        raise ValueError('%s not found! Known benchmarks: %s' % (benchmark_name, list_benchmarks()))
    return found


def get_task(benchmark, env_id):
    """Get a task by env_id. Return None if the benchmark doesn't have the env"""
    return next((task for task in benchmark['tasks'] if task['env_id'] == env_id), None)


def find_task_for_env_id_in_any_benchmark(env_id):
    """Return (benchmark, task) for the first registered benchmark
    containing `env_id`, or (None, None)."""
    for bm in _BENCHMARKS:
        task = get_task(bm, env_id)
        if task is not None:
            return bm, task
    return None, None


_ATARI_SUFFIX = 'NoFrameskip-v4'

register_benchmark({
    'name': 'Atari50M',
    'description': '7 Atari games from Mnih et al. (2013), with pixel observations, 50M timesteps',
    'tasks': [{'desc': _game, 'env_id': _game + _ATARI_SUFFIX, 'trials': 2, 'num_timesteps': int(50e6)} for _game in _atari7]
})

register_benchmark({
    'name': 'Atari10M',
    'description': '7 Atari games from Mnih et al. (2013), with pixel observations, 10M timesteps',
    'tasks': [{'desc': _game, 'env_id': _game + _ATARI_SUFFIX, 'trials': 6, 'num_timesteps': int(10e6)} for _game in _atari7]
})

register_benchmark({
    'name': 'Atari1Hr',
    'description': '7 Atari games from Mnih et al. (2013), with pixel observations, 1 hour of walltime',
    'tasks': [{'desc': _game, 'env_id': _game + _ATARI_SUFFIX, 'trials': 2, 'num_seconds': 60 * 60} for _game in _atari7]
})

register_benchmark({
    'name': 'AtariExploration10M',
    'description': '7 Atari games emphasizing exploration, with pixel observations, 10M timesteps',
    'tasks': [{'desc': _game, 'env_id': _game + _ATARI_SUFFIX, 'trials': 2, 'num_timesteps': int(10e6)} for _game in _atariexpl7]
})


# MuJoCo

_mujocosmall = [
    'InvertedDoublePendulum-v2', 'InvertedPendulum-v2',
    'HalfCheetah-v2', 'Hopper-v2', 'Walker2d-v2',
    'Reacher-v2', 'Swimmer-v2']
register_benchmark({
    'name': 'Mujoco1M',
    'description': 'Some small 2D MuJoCo tasks, run for 1M timesteps',
    'tasks': [{'env_id': _envid, 'trials': 6, 'num_timesteps': int(1e6)} for _envid in _mujocosmall]
})

register_benchmark({
    'name': 'MujocoWalkers',
    'description': 'MuJoCo forward walkers, run for 8M, humanoid 100M',
    'tasks': [
        {'env_id': "Hopper-v1", 'trials': 4, 'num_timesteps': 8 * 1000000},
        {'env_id': "Walker2d-v1", 'trials': 4, 'num_timesteps': 8 * 1000000},
        {'env_id': "Humanoid-v1", 'trials': 4, 'num_timesteps': 100 * 1000000},
    ]
})

# Bullet
_bulletsmall = [
    'InvertedDoublePendulum', 'InvertedPendulum', 'HalfCheetah', 'Reacher', 'Walker2D', 'Hopper', 'Ant'
]
_bulletsmall = [e + 'BulletEnv-v0' for e in _bulletsmall]

register_benchmark({
    'name': 'Bullet1M',
    'description': '6 mujoco-like tasks from bullet, 1M steps',
    'tasks': [{'env_id': e, 'trials': 6, 'num_timesteps': int(1e6)} for e in _bulletsmall]
})


# Roboschool

register_benchmark({
    'name': 'Roboschool8M',
    'description': 'Small 2D tasks, up to 30 minutes to complete on 8 cores',
    'tasks': [
        {'env_id': "RoboschoolReacher-v1", 'trials': 4, 'num_timesteps': 2 * 1000000},
        {'env_id': "RoboschoolAnt-v1", 'trials': 4, 'num_timesteps': 8 * 1000000},
        {'env_id': "RoboschoolHalfCheetah-v1", 'trials': 4, 'num_timesteps': 8 * 1000000},
        {'env_id': "RoboschoolHopper-v1", 'trials': 4, 'num_timesteps': 8 * 1000000},
        {'env_id': "RoboschoolWalker2d-v1", 'trials': 4, 'num_timesteps': 8 * 1000000},
    ]
})
register_benchmark({
    'name': 'RoboschoolHarder',
    'description': 'Test your might!!! Up to 12 hours on 32 cores',
    'tasks': [
        {'env_id': "RoboschoolHumanoid-v1", 'trials': 4, 'num_timesteps': 100 * 1000000},
        {'env_id': "RoboschoolHumanoidFlagrun-v1", 'trials': 4, 'num_timesteps': 200 * 1000000},
        {'env_id': "RoboschoolHumanoidFlagrunHarder-v1", 'trials': 4, 'num_timesteps': 400 * 1000000},
    ]
})

# Other

_atari50 = [  # actually 47
    'Alien', 'Amidar', 'Assault', 'Asterix', 'Asteroids',
    'Atlantis', 'BankHeist', 'BattleZone', 'BeamRider', 'Bowling',
    'Breakout', 'Centipede', 'ChopperCommand', 'CrazyClimber',
    'DemonAttack', 'DoubleDunk', 'Enduro', 'FishingDerby', 'Freeway',
    'Frostbite', 'Gopher', 'Gravitar', 'IceHockey', 'Jamesbond',
    'Kangaroo', 'Krull', 'KungFuMaster', 'MontezumaRevenge', 'MsPacman',
    'NameThisGame', 'Pitfall', 'Pong', 'PrivateEye', 'Qbert',
    'RoadRunner', 'Robotank', 'Seaquest', 'SpaceInvaders', 'StarGunner',
    'Tennis', 'TimePilot', 'Tutankham', 'UpNDown', 'Venture',
    'VideoPinball', 'WizardOfWor', 'Zaxxon',
]

register_benchmark({
    'name': 'Atari50_10M',
    'description': '47 Atari games from Mnih et al. (2013), with pixel observations, 10M timesteps',
    'tasks': [{'desc': _game, 'env_id': _game + _ATARI_SUFFIX, 'trials': 2, 'num_timesteps': int(10e6)} for _game in _atari50]
})

# HER DDPG

_fetch_tasks = ['FetchReach-v1', 'FetchPush-v1', 'FetchSlide-v1']
register_benchmark({
    'name': 'Fetch1M',
    'description': 'Fetch* benchmarks for 1M timesteps',
    'tasks': [{'trials': 6, 'env_id': env_id, 'num_timesteps': int(1e6)} for env_id in _fetch_tasks]
})
|
| 164 |
+
|
baselines/bench/monitor.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__all__ = ['Monitor', 'get_monitor_files', 'load_results']
|
| 2 |
+
|
| 3 |
+
from gym.core import Wrapper
|
| 4 |
+
import time
|
| 5 |
+
from glob import glob
|
| 6 |
+
import csv
|
| 7 |
+
import os.path as osp
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
class Monitor(Wrapper):
    """Gym env wrapper that records per-episode reward, length and wall time.

    When ``filename`` is given, one CSV row per finished episode is appended
    via :class:`ResultsWriter`.  On the terminal step the same stats are also
    attached to ``info['episode']`` as ``{'r': reward, 'l': length, 't': time}``.
    """
    EXT = "monitor.csv"
    # Legacy class attribute kept for backward compatibility; the actual CSV
    # handle lives on self.results_writer.
    f = None

    def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
        """
        Args:
            env: the environment to wrap.
            filename: path (file or directory) for the monitor CSV; falsy disables writing.
            allow_early_resets: permit reset() before the current episode is done.
            reset_keywords: kwargs that MUST be passed to reset() and are logged per episode.
            info_keywords: keys copied from step()'s info dict into each episode record.
        """
        Wrapper.__init__(self, env=env)
        self.tstart = time.time()
        if filename:
            self.results_writer = ResultsWriter(filename,
                header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id},
                extra_keys=reset_keywords + info_keywords
            )
        else:
            self.results_writer = None
        self.reset_keywords = reset_keywords
        self.info_keywords = info_keywords
        self.allow_early_resets = allow_early_resets
        self.rewards = None          # per-step rewards of the current episode
        self.needs_reset = True
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_times = []
        self.total_steps = 0
        self.current_reset_info = {} # extra info about the current episode, that was passed in during reset()

    def reset(self, **kwargs):
        """Reset the env, recording the value of every reset_keywords kwarg.

        Raises:
            ValueError: if a declared reset keyword is missing (or None).
        """
        self.reset_state()
        for k in self.reset_keywords:
            v = kwargs.get(k)
            if v is None:
                raise ValueError('Expected you to pass kwarg %s into reset'%k)
            self.current_reset_info[k] = v
        return self.env.reset(**kwargs)

    def reset_state(self):
        """Clear per-episode bookkeeping; guard against disallowed early resets."""
        if not self.allow_early_resets and not self.needs_reset:
            raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, wrap your env with Monitor(env, path, allow_early_resets=True)")
        self.rewards = []
        self.needs_reset = False

    def step(self, action):
        """Step the env, updating episode statistics as a side effect."""
        if self.needs_reset:
            raise RuntimeError("Tried to step environment that needs reset")
        ob, rew, done, info = self.env.step(action)
        self.update(ob, rew, done, info)
        return (ob, rew, done, info)

    def update(self, ob, rew, done, info):
        """Accumulate the step reward; on episode end, record and write stats."""
        self.rewards.append(rew)
        if done:
            self.needs_reset = True
            eprew = sum(self.rewards)
            eplen = len(self.rewards)
            epinfo = {"r": round(eprew, 6), "l": eplen, "t": round(time.time() - self.tstart, 6)}
            for k in self.info_keywords:
                epinfo[k] = info[k]
            self.episode_rewards.append(eprew)
            self.episode_lengths.append(eplen)
            self.episode_times.append(time.time() - self.tstart)
            epinfo.update(self.current_reset_info)
            if self.results_writer:
                self.results_writer.write_row(epinfo)
            assert isinstance(info, dict)
            if isinstance(info, dict):
                info['episode'] = epinfo

        self.total_steps += 1

    def close(self):
        """Close the monitor's CSV file.

        Bug fix: the original only checked the class attribute ``f`` (which is
        always None here), so the handle opened by ResultsWriter was never
        closed.  Close it explicitly now.
        """
        if self.f is not None:
            self.f.close()
        if self.results_writer is not None:
            self.results_writer.f.close()

    def get_total_steps(self):
        """Total env steps taken across all episodes."""
        return self.total_steps

    def get_episode_rewards(self):
        """List of total rewards, one entry per completed episode."""
        return self.episode_rewards

    def get_episode_lengths(self):
        """List of episode lengths (in steps), one entry per completed episode."""
        return self.episode_lengths

    def get_episode_times(self):
        """Wall-clock time (seconds since wrapper creation) at each episode end."""
        return self.episode_times
|
| 94 |
+
|
| 95 |
+
class LoadMonitorResultsError(Exception):
    """Raised by load_results() when no monitor files are found."""
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class ResultsWriter(object):
    """Appends episode records to a CSV file with a '#'-prefixed JSON header line."""

    def __init__(self, filename, header='', extra_keys=()):
        """
        Args:
            filename: target path.  A directory gets 'monitor.csv' appended;
                any other path lacking the suffix gets '.monitor.csv' appended.
            header: dict (JSON-encoded as a '#' comment line) or a raw string
                written verbatim before the CSV header.
            extra_keys: extra CSV columns beyond the standard ('r', 'l', 't').
        """
        self.extra_keys = extra_keys
        assert filename is not None
        if not filename.endswith(Monitor.EXT):
            if osp.isdir(filename):
                filename = osp.join(filename, Monitor.EXT)
            else:
                filename = filename + "." + Monitor.EXT
        self.f = open(filename, "wt")
        if isinstance(header, dict):
            header = '# {} \n'.format(json.dumps(header))
        self.f.write(header)
        self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+tuple(extra_keys))
        self.logger.writeheader()
        self.f.flush()

    def write_row(self, epinfo):
        """Append one episode record dict; flush so readers see it immediately."""
        if self.logger:
            self.logger.writerow(epinfo)
            self.f.flush()

    def close(self):
        """Release the file handle.

        Added: previously the handle was opened in __init__ but there was no
        way to close it, leaking the descriptor for the process lifetime.
        """
        self.f.close()
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def get_monitor_files(dir):
    """Return paths of all monitor CSV files directly inside `dir`."""
    pattern = osp.join(dir, "*" + Monitor.EXT)
    return glob(pattern)
|
| 124 |
+
|
| 125 |
+
def load_results(dir):
    """Load every monitor file in `dir` into a single pandas DataFrame.

    Reads both the current CSV format (first line is a '#'-prefixed JSON
    header) and the deprecated JSON-lines format.  Rows are sorted by wall
    time 't', shifted to be relative to the earliest t_start across files.
    The list of per-file headers is attached as ``df.headers``.

    Raises:
        LoadMonitorResultsError: if no monitor files are found.
    """
    import pandas
    monitor_files = (
        glob(osp.join(dir, "*monitor.json")) +
        glob(osp.join(dir, "*monitor.csv"))) # get both csv and (old) json files
    if not monitor_files:
        raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, dir))
    dfs = []
    headers = []
    for fname in monitor_files:
        with open(fname, 'rt') as fh:
            if fname.endswith('csv'):
                firstline = fh.readline()
                if not firstline:
                    # empty file (e.g. a run that died before writing the header)
                    continue
                assert firstline[0] == '#'
                header = json.loads(firstline[1:])
                df = pandas.read_csv(fh, index_col=None)
                headers.append(header)
            elif fname.endswith('json'): # Deprecated json format
                episodes = []
                lines = fh.readlines()
                header = json.loads(lines[0])
                headers.append(header)
                for line in lines[1:]:
                    episode = json.loads(line)
                    episodes.append(episode)
                df = pandas.DataFrame(episodes)
            else:
                # The glob above only matches .json/.csv, so this cannot happen;
                # raise instead of the original `assert 0`, which would be
                # stripped under `python -O`.
                raise LoadMonitorResultsError('unexpected monitor file type: %s' % fname)
            # convert episode-relative times to absolute wall times
            df['t'] += header['t_start']
        dfs.append(df)
    df = pandas.concat(dfs)
    df.sort_values('t', inplace=True)
    df.reset_index(inplace=True)
    # make times relative to the earliest run start
    df['t'] -= min(header['t_start'] for header in headers)
    df.headers = headers # HACK to preserve backwards compatibility
    return df
|
baselines/bench/test_monitor.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .monitor import Monitor
|
| 2 |
+
import gym
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
def test_monitor():
    """Smoke test: Monitor writes a parseable CSV with the expected metadata."""
    import pandas
    import os
    import uuid

    env = gym.make("CartPole-v1")
    env.seed(0)
    mon_file = "/tmp/baselines-test-%s.monitor.csv" % uuid.uuid4()
    menv = Monitor(env, mon_file)
    menv.reset()
    for _ in range(1000):
        _, _, done, _ = menv.step(0)
        if done:
            menv.reset()

    with open(mon_file, 'rt') as f:
        firstline = f.readline()
        assert firstline.startswith('#')
        metadata = json.loads(firstline[1:])
        assert metadata['env_id'] == "CartPole-v1"
        assert set(metadata.keys()) == {'env_id', 't_start'}, "Incorrect keys in monitor metadata"
        last_logline = pandas.read_csv(f, index_col=None)
        assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline"
    os.remove(mon_file)
|
baselines/common/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa F403
|
| 2 |
+
from baselines.common.console_util import *
|
| 3 |
+
from baselines.common.dataset import Dataset
|
| 4 |
+
from baselines.common.math_util import *
|
| 5 |
+
from baselines.common.misc_util import *
|
baselines/common/atari_wrappers.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
os.environ.setdefault('PATH', '')
|
| 4 |
+
from collections import deque
|
| 5 |
+
import gym
|
| 6 |
+
from gym import spaces
|
| 7 |
+
import cv2
|
| 8 |
+
cv2.ocl.setUseOpenCL(False)
|
| 9 |
+
from .wrappers import TimeLimit
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.

        Args:
            env: Atari env whose action 0 is NOOP (asserted below).
            noop_max: inclusive upper bound on the number of no-op steps.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        # When set, forces this exact no-op count (used for deterministic resets).
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self, **kwargs):
        """ Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(self.noop_action)
            if done:
                # episode ended during the no-ops; start over from a fresh reset
                obs = self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        # pass-through; this wrapper only alters reset behavior
        return self.env.step(ac)
|
| 40 |
+
|
| 41 |
+
class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        # requires an action set of at least 3 actions with FIRE at index 1
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs):
        # Press FIRE (action 1) then action 2 after resetting; if either step
        # immediately ends the episode, reset again before continuing.
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        # pass-through; this wrapper only alters reset behavior
        return self.env.step(ac)
|
| 60 |
+
|
| 61 |
+
class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0               # life count observed on the previous step
        self.was_real_done = True    # whether the underlying env reported done

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert sometimes we stay in lives == 0 condition for a few frames
            # so it's important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset(**kwargs)
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _ = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs
|
| 96 |
+
|
| 97 |
+
class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            # keep only the last two raw frames; the pixel-wise max removes
            # Atari sprite flicker
            if i == self._skip - 2: self._obs_buffer[0] = obs
            if i == self._skip - 1: self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        # NOTE(review): if the episode ends before the last two sub-steps, the
        # buffer slots retain their previous contents (zeros on the very first
        # episode) -- intentional per the comment above, but worth confirming.
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        # pass-through; this wrapper only alters step behavior
        return self.env.reset(**kwargs)
|
| 124 |
+
|
| 125 |
+
class ClipRewardEnv(gym.RewardWrapper):
    """Clip every reward to its sign, as done in the DQN line of work."""

    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None):
        """
        Warp frames to 84x84 as done in the Nature paper and later work.

        If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
        observation should be warped.

        Args:
            env: env producing uint8 HxWxC image observations (asserted below).
            width, height: target frame size in pixels.
            grayscale: convert RGB to a single-channel grayscale frame.
            dict_space_key: key of the image entry when observations are dicts.
        """
        super().__init__(env)
        self._width = width
        self._height = height
        self._grayscale = grayscale
        self._key = dict_space_key
        if self._grayscale:
            num_colors = 1
        else:
            num_colors = 3

        new_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(self._height, self._width, num_colors),
            dtype=np.uint8,
        )
        if self._key is None:
            original_space = self.observation_space
            self.observation_space = new_space
        else:
            original_space = self.observation_space.spaces[self._key]
            self.observation_space.spaces[self._key] = new_space
        # only uint8 3-D (H, W, C) image observations are supported
        assert original_space.dtype == np.uint8 and len(original_space.shape) == 3

    def observation(self, obs):
        """Resize (and optionally grayscale) one observation frame."""
        if self._key is None:
            frame = obs
        else:
            frame = obs[self._key]

        if self._grayscale:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(
            frame, (self._width, self._height), interpolation=cv2.INTER_AREA
        )
        if self._grayscale:
            # cvtColor dropped the channel axis; restore a trailing axis of size 1
            frame = np.expand_dims(frame, -1)

        if self._key is None:
            obs = frame
        else:
            obs = obs.copy()  # avoid mutating the caller's dict in place
            obs[self._key] = frame
        return obs
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.

        Returns lazy array, which is much more memory efficient.

        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        # rolling window of the k most recent observations
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        # stacked along the last (channel) axis: (..., C) -> (..., C*k)
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype)

    def reset(self):
        # fill the window with k copies of the initial observation
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        # LazyFrames defers concatenation until the array is actually used
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))
|
| 218 |
+
|
| 219 |
+
class ScaledFloatFrame(gym.ObservationWrapper):
    """Rescale uint8 pixel observations from [0, 255] to float32 in [0, 1]."""

    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)

    def observation(self, observation):
        # careful! This undoes the memory optimization, use
        # with smaller replay buffers only.
        return np.array(observation).astype(np.float32) / 255.0
|
| 228 |
+
|
| 229 |
+
class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between the observations are only stored once.
        It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
        buffers.

        This object should only be converted to numpy array before being passed to the model.

        You'd not believe how complex the previous solution was."""
        self._frames = frames
        self._out = None  # cache of the concatenated array, built on first access

    def _force(self):
        # Concatenate once along the channel axis, cache the result, and drop
        # the frame list so the per-frame references can be freed.
        cached = self._out
        if cached is None:
            cached = np.concatenate(self._frames, axis=-1)
            self._out = cached
            self._frames = None
        return cached

    def __array__(self, dtype=None):
        arr = self._force()
        return arr if dtype is None else arr.astype(dtype)

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]

    def count(self):
        """Number of stacked channels (size of the last axis)."""
        return self._force().shape[-1]

    def frame(self, i):
        """Return the i-th channel slice of the stacked array."""
        return self._force()[..., i]
|
| 265 |
+
|
| 266 |
+
def make_atari(env_id, max_episode_steps=None):
    """Build a raw Atari env with random no-op resets and 4-frame max+skip.

    `env_id` must be a NoFrameskip variant (asserted), since MaxAndSkipEnv
    performs the frame skipping itself.
    """
    env = gym.make(env_id)
    assert 'NoFrameskip' in env.spec.id
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if max_episode_steps is not None:
        env = TimeLimit(env, max_episode_steps=max_episode_steps)
    return env
|
| 274 |
+
|
| 275 |
+
def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
    """Configure environment for DeepMind-style Atari.

    Args:
        episode_life: treat loss of life as episode end (EpisodicLifeEnv).
        clip_rewards: clip rewards to their sign (ClipRewardEnv).
        frame_stack: stack the last 4 frames (FrameStack).
        scale: rescale pixels to [0, 1] floats (ScaledFloatFrame).
    """
    if episode_life:
        env = EpisodicLifeEnv(env)
    # games that wait for FIRE to start get the automatic fire-on-reset wrapper
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if scale:
        env = ScaledFloatFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    if frame_stack:
        env = FrameStack(env, 4)
    return env
|
| 290 |
+
|
baselines/common/cg.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10):
|
| 3 |
+
"""
|
| 4 |
+
Demmel p 312
|
| 5 |
+
"""
|
| 6 |
+
p = b.copy()
|
| 7 |
+
r = b.copy()
|
| 8 |
+
x = np.zeros_like(b)
|
| 9 |
+
rdotr = r.dot(r)
|
| 10 |
+
|
| 11 |
+
fmtstr = "%10i %10.3g %10.3g"
|
| 12 |
+
titlestr = "%10s %10s %10s"
|
| 13 |
+
if verbose: print(titlestr % ("iter", "residual norm", "soln norm"))
|
| 14 |
+
|
| 15 |
+
for i in range(cg_iters):
|
| 16 |
+
if callback is not None:
|
| 17 |
+
callback(x)
|
| 18 |
+
if verbose: print(fmtstr % (i, rdotr, np.linalg.norm(x)))
|
| 19 |
+
z = f_Ax(p)
|
| 20 |
+
v = rdotr / p.dot(z)
|
| 21 |
+
x += v*p
|
| 22 |
+
r -= v*z
|
| 23 |
+
newrdotr = r.dot(r)
|
| 24 |
+
mu = newrdotr/rdotr
|
| 25 |
+
p = r + mu*p
|
| 26 |
+
|
| 27 |
+
rdotr = newrdotr
|
| 28 |
+
if rdotr < residual_tol:
|
| 29 |
+
break
|
| 30 |
+
|
| 31 |
+
if callback is not None:
|
| 32 |
+
callback(x)
|
| 33 |
+
if verbose: print(fmtstr % (i+1, rdotr, np.linalg.norm(x))) # pylint: disable=W0631
|
| 34 |
+
return x
|
baselines/common/cmd_util.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Helpers for scripts like run_atari.py.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
try:
|
| 7 |
+
from mpi4py import MPI
|
| 8 |
+
except ImportError:
|
| 9 |
+
MPI = None
|
| 10 |
+
|
| 11 |
+
import gym
|
| 12 |
+
from gym.wrappers import FlattenObservation, FilterObservation
|
| 13 |
+
from baselines import logger
|
| 14 |
+
from baselines.bench import Monitor
|
| 15 |
+
from baselines.common import set_global_seeds
|
| 16 |
+
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
|
| 17 |
+
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
| 18 |
+
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
| 19 |
+
from baselines.common import retro_wrappers
|
| 20 |
+
from baselines.common.wrappers import ClipActionsWrapper
|
| 21 |
+
|
| 22 |
+
def make_vec_env(env_id, env_type, num_env, seed,
                 wrapper_kwargs=None,
                 env_kwargs=None,
                 start_index=0,
                 reward_scale=1.0,
                 flatten_dict_observations=True,
                 gamestate=None,
                 initializer=None,
                 force_dummy=False):
    """
    Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.

    Builds `num_env` copies via make_env; uses subprocess workers
    (SubprocVecEnv) when num_env > 1 unless `force_dummy` is set, otherwise
    an in-process DummyVecEnv.
    """
    wrapper_kwargs = wrapper_kwargs or {}
    env_kwargs = env_kwargs or {}
    # offset the seed per MPI rank so parallel workers don't replay identical rollouts
    mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    seed = seed + 10000 * mpi_rank if seed is not None else None
    logger_dir = logger.get_dir()
    def make_thunk(rank, initializer=None):
        # bind `rank` now; the returned lambda is invoked later (possibly in a
        # subprocess) by the vec-env implementation
        return lambda: make_env(
            env_id=env_id,
            env_type=env_type,
            mpi_rank=mpi_rank,
            subrank=rank,
            seed=seed,
            reward_scale=reward_scale,
            gamestate=gamestate,
            flatten_dict_observations=flatten_dict_observations,
            wrapper_kwargs=wrapper_kwargs,
            env_kwargs=env_kwargs,
            logger_dir=logger_dir,
            initializer=initializer
        )

    set_global_seeds(seed)
    if not force_dummy and num_env > 1:
        return SubprocVecEnv([make_thunk(i + start_index, initializer=initializer) for i in range(num_env)])
    else:
        # in-process envs get no initializer (it exists for subprocess setup)
        return DummyVecEnv([make_thunk(i + start_index, initializer=None) for i in range(num_env)])
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None, initializer=None):
    """Create and wrap a single environment (Monitor, per-type wrappers, scaling).

    `env_id` may be prefixed as 'module:EnvId'; the module is imported first so
    custom envs can register themselves with gym before gym.make is called.
    """
    if initializer is not None:
        initializer(mpi_rank=mpi_rank, subrank=subrank)

    wrapper_kwargs = wrapper_kwargs or {}
    env_kwargs = env_kwargs or {}
    if ':' in env_id:
        import re
        import importlib
        module_name = re.sub(':.*','',env_id)
        env_id = re.sub('.*:', '', env_id)
        importlib.import_module(module_name)
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
    else:
        env = gym.make(env_id, **env_kwargs)

    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
        env = FlattenObservation(env)

    env.seed(seed + subrank if seed is not None else None)
    # one monitor file per (mpi_rank, subrank) worker; None disables writing
    env = Monitor(env,
                  logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)


    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        if 'frame_stack' not in wrapper_kwargs:
            wrapper_kwargs['frame_stack'] = 1
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    # continuous action spaces get their actions clipped to the space bounds
    if isinstance(env.action_space, gym.spaces.Box):
        env = ClipActionsWrapper(env)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def make_mujoco_env(env_id, seed, reward_scale=1.0):
    """
    Create a wrapped, monitored gym.Env for MuJoCo.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    # per-rank offset so global RNGs differ across MPI workers
    myseed = seed + 1000 * rank if seed is not None else None
    set_global_seeds(myseed)
    env = gym.make(env_id)
    logger_path = None if logger.get_dir() is None else os.path.join(logger.get_dir(), str(rank))
    env = Monitor(env, logger_path, allow_early_resets=True)
    # NOTE(review): the env is seeded with the unshifted `seed`, so every rank
    # shares the same env seed while global RNGs use `myseed` -- confirm intended
    env.seed(seed)
    if reward_scale != 1.0:
        from baselines.common.retro_wrappers import RewardScaler
        env = RewardScaler(env, reward_scale)
    return env
|
| 123 |
+
|
| 124 |
+
def make_robotics_env(env_id, seed, rank=0):
    """
    Create a wrapped, monitored gym.Env for the robotics (Fetch*) tasks.

    Flattens the dict observation down to ('observation', 'desired_goal') and
    logs the 'is_success' info key per episode.
    """
    set_global_seeds(seed)
    env = gym.make(env_id)
    env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
    env = Monitor(
        env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
        info_keywords=('is_success',))
    env.seed(seed)
    return env
|
| 136 |
+
|
| 137 |
+
def arg_parser():
    """Return an empty argparse.ArgumentParser that shows defaults in help."""
    import argparse
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    return parser
|
| 143 |
+
|
| 144 |
+
def atari_arg_parser():
    """Deprecated alias: delegates to common_arg_parser() with a warning."""
    print('Obsolete - use common_arg_parser instead')
    return common_arg_parser()
|
| 150 |
+
|
| 151 |
+
def mujoco_arg_parser():
    """Deprecated alias: delegates to common_arg_parser() with a warning."""
    print('Obsolete - use common_arg_parser instead')
    return common_arg_parser()
|
| 154 |
+
|
| 155 |
+
def common_arg_parser():
    """
    Create the argparse.ArgumentParser shared by the run scripts
    (env/alg selection, seeding, logging and playback options).
    """
    parser = arg_parser()
    parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
    parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
    parser.add_argument('--seed', help='RNG seed', type=int, default=None)
    parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
    # fix: the original line ended with a stray trailing comma, building a
    # throwaway one-element tuple around the add_argument result
    parser.add_argument('--num_timesteps', type=float, default=1e6)
    parser.add_argument('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default=None)
    parser.add_argument('--gamestate', help='game state to load (so far only used in retro games)', default=None)
    parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int)
    parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float)
    parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str)
    parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
    parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
    parser.add_argument('--log_path', help='Directory to save learning curve data.', default=None, type=str)
    parser.add_argument('--play', default=False, action='store_true')
    return parser
|
| 175 |
+
|
| 176 |
+
def robotics_arg_parser():
    """
    Create an argparse.ArgumentParser for the robotics run script
    (default env is FetchReach-v0).
    """
    parser = arg_parser()
    parser.add_argument('--env', help='environment ID', type=str, default='FetchReach-v0')
    parser.add_argument('--seed', help='RNG seed', type=int, default=None)
    parser.add_argument('--num-timesteps', type=int, default=int(1e6))
    return parser
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def parse_unknown_args(args):
    """
    Parse arguments not consumed by arg parser into a dictionary.

    Supports both '--key=value' and '--key value' forms. A positional token that
    does not follow a bare '--key' flag is ignored, as is a trailing bare flag
    with no value.

    Fixes over the previous version:
    - values may themselves contain '=' (split on the first '=' only);
    - a '--key=value' pair no longer leaves a stale pending key, so a later
      positional token cannot overwrite it.
    """
    retval = {}
    pending_key = None
    for arg in args:
        if arg.startswith('--'):
            if '=' in arg:
                # split only on the first '=' so values like 'a=b' survive intact
                key, value = arg[2:].split('=', 1)
                retval[key] = value
                pending_key = None
            else:
                pending_key = arg[2:]
        elif pending_key is not None:
            retval[pending_key] = arg
            pending_key = None

    return retval
|
baselines/common/console_util.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
import numpy as np
|
| 4 |
+
import time
|
| 5 |
+
import shlex
|
| 6 |
+
import subprocess
|
| 7 |
+
|
| 8 |
+
# ================================================================
|
| 9 |
+
# Misc
|
| 10 |
+
# ================================================================
|
| 11 |
+
|
| 12 |
+
def fmt_row(width, row, header=False):
    """Format one table row, right-justifying each cell to `width` characters.

    When header is True, append an underline of dashes matching the row length.
    """
    cells = [fmt_item(cell, width) for cell in row]
    line = " | ".join(cells)
    if header:
        line += "\n" + "-" * len(line)
    return line
|
| 16 |
+
|
| 17 |
+
def fmt_item(x, l):
    """Render x as a string right-justified in a field of width l.

    0-d numpy arrays are unwrapped first; floats use scientific notation when
    their magnitude is outside [1e-4, 1e4] (and non-zero), fixed-point otherwise.
    """
    if isinstance(x, np.ndarray):
        assert x.ndim == 0
        x = x.item()
    if isinstance(x, (float, np.float32, np.float64)):
        magnitude = abs(x)
        scientific = magnitude > 0 and (magnitude < 1e-4 or magnitude > 1e+4)
        text = ("%7.2e" % x) if scientific else ("%7.5f" % x)
    else:
        text = str(x)
    padding = " " * (l - len(text))
    return padding + text
|
| 29 |
+
|
| 30 |
+
# ANSI foreground color codes used by colorize(); +10 gives the background variant.
color2num = dict(
    gray=30,
    red=31,
    green=32,
    yellow=33,
    blue=34,
    magenta=35,
    cyan=36,
    white=37,
    crimson=38
)
|
| 41 |
+
|
| 42 |
+
def colorize(string, color='green', bold=False, highlight=False):
    """Wrap `string` in ANSI escape codes for terminal coloring.

    highlight selects the background variant of the color (+10); bold adds the
    bold attribute.
    """
    code = color2num[color]
    if highlight:
        code += 10
    attrs = [str(code)]
    if bold:
        attrs.append('1')
    return '\x1b[%sm%s\x1b[0m' % (';'.join(attrs), string)
|
| 49 |
+
|
| 50 |
+
def print_cmd(cmd, dry=False):
    """Pretty-print a command, shell-quoting list-form commands for display.

    String commands (for shell=True) are shown verbatim; the prefix indicates
    whether this is a dry run.
    """
    if not isinstance(cmd, str):
        cmd = ' '.join(shlex.quote(arg) for arg in cmd)
    prefix = 'DRY: ' if dry else 'CMD: '
    print(colorize(prefix + cmd))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_git_commit(cwd=None):
    """Return the short hash of HEAD of the git repository at cwd (utf8 string)."""
    output = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=cwd)
    return output.decode('utf8')
|
| 60 |
+
|
| 61 |
+
def get_git_commit_message(cwd=None):
    """Return the full commit message of HEAD of the git repository at cwd."""
    output = subprocess.check_output(['git', 'show', '-s', '--format=%B', 'HEAD'], cwd=cwd)
    return output.decode('utf8')
|
| 63 |
+
|
| 64 |
+
def ccap(cmd, dry=False, env=None, **kwargs):
    """Print the command, then run it with subprocess.check_call unless dry."""
    print_cmd(cmd, dry)
    if dry:
        return
    subprocess.check_call(cmd, env=env, **kwargs)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Nesting depth of active timed() blocks; controls indentation of the output.
MESSAGE_DEPTH = 0

@contextmanager
def timed(msg):
    """Context manager that prints `msg`, times the enclosed block, and prints
    the elapsed time, indenting by the current nesting depth.

    Fix: the depth decrement and the "done" line are now in a finally block, so
    an exception inside the managed block no longer leaves MESSAGE_DEPTH
    permanently incremented (which would mis-indent all later messages).
    """
    global MESSAGE_DEPTH #pylint: disable=W0603
    print(colorize('\t'*MESSAGE_DEPTH + '=: ' + msg, color='magenta'))
    tstart = time.time()
    MESSAGE_DEPTH += 1
    try:
        yield
    finally:
        MESSAGE_DEPTH -= 1
        print(colorize('\t'*MESSAGE_DEPTH + "done in %.3f seconds"%(time.time() - tstart), color='magenta'))
|
baselines/common/dataset.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
class Dataset(object):
    """In-memory dataset of parallel arrays supporting shuffled mini-batch iteration.

    data_map maps name -> array; all arrays must share the same first dimension.
    deterministic=True disables shuffling entirely; shuffle=False only disables
    the automatic re-shuffle between epochs.
    """
    def __init__(self, data_map, deterministic=False, shuffle=True):
        self.data_map = data_map
        self.deterministic = deterministic
        self.enable_shuffle = shuffle
        # number of rows, taken from any one of the arrays
        self.n = next(iter(data_map.values())).shape[0]
        self._next_id = 0
        self.shuffle()

    def shuffle(self):
        """Apply one shared random permutation to every array; no-op when deterministic."""
        if self.deterministic:
            return
        order = np.arange(self.n)
        np.random.shuffle(order)

        for name in self.data_map:
            self.data_map[name] = self.data_map[name][order]

        self._next_id = 0

    def next_batch(self, batch_size):
        """Return the next mini-batch as a dict; may be smaller at the end of an epoch."""
        if self._next_id >= self.n and self.enable_shuffle:
            self.shuffle()

        start = self._next_id
        size = min(batch_size, self.n - start)
        self._next_id = start + size

        return {name: arr[start:start + size] for name, arr in self.data_map.items()}

    def iterate_once(self, batch_size):
        """Yield full-size batches covering one epoch, then reset the cursor."""
        if self.enable_shuffle: self.shuffle()

        while self._next_id <= self.n - batch_size:
            yield self.next_batch(batch_size)
        self._next_id = 0

    def subset(self, num_elements, deterministic=True):
        """Return a new Dataset over the first num_elements rows of every array."""
        sliced = {name: arr[:num_elements] for name, arr in self.data_map.items()}
        return Dataset(sliced, deterministic)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True):
    """Yield tuples of aligned mini-batches drawn from several parallel arrays.

    Exactly one of num_batches / batch_size must be given. When shuffle is True
    the row order is randomized once up front. A trailing batch smaller than
    batch_size is skipped unless include_final_partial_batch is True.
    """
    assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both'
    arrays = tuple(np.asarray(a) for a in arrays)
    n = arrays[0].shape[0]
    assert all(a.shape[0] == n for a in arrays[1:])
    order = np.arange(n)
    if shuffle:
        np.random.shuffle(order)
    if num_batches is None:
        # split points every batch_size rows
        sections = np.arange(0, n, batch_size)[1:]
    else:
        sections = num_batches
    for batch_inds in np.array_split(order, sections):
        if include_final_partial_batch or len(batch_inds) == batch_size:
            yield tuple(a[batch_inds] for a in arrays)
|
baselines/common/distributions.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import numpy as np
|
| 3 |
+
import baselines.common.tf_util as U
|
| 4 |
+
from baselines.a2c.utils import fc
|
| 5 |
+
from tensorflow.python.ops import math_ops
|
| 6 |
+
|
| 7 |
+
class Pd(object):
    """
    A particular probability distribution.

    Abstract base: subclasses implement flatparam/mode/neglogp/kl/entropy/sample;
    the remaining methods are derived from those.
    """
    def flatparam(self):
        # Flat tensor of distribution parameters.
        raise NotImplementedError
    def mode(self):
        # Most likely value.
        raise NotImplementedError
    def neglogp(self, x):
        # Usually it's easier to define the negative logprob
        raise NotImplementedError
    def kl(self, other):
        # KL(self || other).
        raise NotImplementedError
    def entropy(self):
        raise NotImplementedError
    def sample(self):
        raise NotImplementedError
    def logp(self, x):
        # Log-probability, derived from neglogp.
        return - self.neglogp(x)
    def get_shape(self):
        return self.flatparam().shape
    @property
    def shape(self):
        return self.get_shape()
    def __getitem__(self, idx):
        # Index into the batch of distributions by slicing the flat parameters.
        return self.__class__(self.flatparam()[idx])
|
| 33 |
+
|
| 34 |
+
class PdType(object):
    """
    Parametrized family of probability distributions.

    Knows how to build a Pd from flat parameters or from a latent vector, and
    describes the shapes/dtypes of parameters and samples.
    """
    def pdclass(self):
        # Pd subclass this family instantiates.
        raise NotImplementedError
    def pdfromflat(self, flat):
        # Build a distribution from a flat parameter tensor.
        return self.pdclass()(flat)
    def pdfromlatent(self, latent_vector, init_scale, init_bias):
        # Build (distribution, raw parameters) from a network latent vector.
        raise NotImplementedError
    def param_shape(self):
        raise NotImplementedError
    def sample_shape(self):
        raise NotImplementedError
    def sample_dtype(self):
        raise NotImplementedError

    def param_placeholder(self, prepend_shape, name=None):
        # Placeholder for flat parameters with extra leading (e.g. batch) dims.
        return tf.compat.v1.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name)
    def sample_placeholder(self, prepend_shape, name=None):
        # Placeholder for samples with extra leading (e.g. batch) dims.
        return tf.compat.v1.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name)

    def __eq__(self, other):
        # Two PdTypes are equal iff same class and same configuration.
        return (type(self) == type(other)) and (self.__dict__ == other.__dict__)
|
| 58 |
+
|
| 59 |
+
class CategoricalPdType(PdType):
    """Family of categorical distributions over `ncat` discrete choices."""
    def __init__(self, ncat):
        self.ncat = ncat  # number of categories
    def pdclass(self):
        return CategoricalPd
    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        # Project latent to logits (identity if the sizes already match).
        pdparam = _matching_fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam

    def param_shape(self):
        return [self.ncat]
    def sample_shape(self):
        # Samples are scalar category indices.
        return []
    def sample_dtype(self):
        return tf.int32
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class MultiCategoricalPdType(PdType):
    """Family of products of independent categoricals; nvec gives each factor's size."""
    def __init__(self, nvec):
        self.ncats = nvec.astype('int32')
        assert (self.ncats > 0).all()
    def pdclass(self):
        return MultiCategoricalPd
    def pdfromflat(self, flat):
        # flat concatenates the logits of all factors.
        return MultiCategoricalPd(self.ncats, flat)

    def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
        # One projection producing all factors' logits at once.
        pdparam = _matching_fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam

    def param_shape(self):
        return [sum(self.ncats)]
    def sample_shape(self):
        # One integer index per factor.
        return [len(self.ncats)]
    def sample_dtype(self):
        return tf.int32
|
| 95 |
+
|
| 96 |
+
class DiagGaussianPdType(PdType):
    """Family of diagonal Gaussians over `size`-dimensional continuous actions."""
    def __init__(self, size):
        self.size = size  # action dimensionality
    def pdclass(self):
        return DiagGaussianPd

    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        # Mean comes from the latent; logstd is a free (state-independent) variable.
        mean = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
        logstd = tf.compat.v1.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.compat.v1.zeros_initializer())
        # 'mean * 0.0 + logstd' broadcasts logstd to the batch shape of mean.
        pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
        return self.pdfromflat(pdparam), mean

    def param_shape(self):
        # Mean and logstd concatenated.
        return [2*self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.float32
|
| 114 |
+
|
| 115 |
+
class BernoulliPdType(PdType):
    """Family of products of `size` independent Bernoulli variables."""
    def __init__(self, size):
        self.size = size  # number of independent binary components
    def pdclass(self):
        return BernoulliPd
    def param_shape(self):
        return [self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.int32
    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        # Project latent to per-component logits.
        pdparam = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam
|
| 129 |
+
|
| 130 |
+
# WRONG SECOND DERIVATIVES
|
| 131 |
+
# class CategoricalPd(Pd):
|
| 132 |
+
# def __init__(self, logits):
|
| 133 |
+
# self.logits = logits
|
| 134 |
+
# self.ps = tf.nn.softmax(logits)
|
| 135 |
+
# @classmethod
|
| 136 |
+
# def fromflat(cls, flat):
|
| 137 |
+
# return cls(flat)
|
| 138 |
+
# def flatparam(self):
|
| 139 |
+
# return self.logits
|
| 140 |
+
# def mode(self):
|
| 141 |
+
# return U.argmax(self.logits, axis=-1)
|
| 142 |
+
# def logp(self, x):
|
| 143 |
+
# return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)
|
| 144 |
+
# def kl(self, other):
|
| 145 |
+
# return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \
|
| 146 |
+
# - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
|
| 147 |
+
# def entropy(self):
|
| 148 |
+
# return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
|
| 149 |
+
# def sample(self):
|
| 150 |
+
# u = tf.random_uniform(tf.shape(self.logits))
|
| 151 |
+
# return U.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
|
| 152 |
+
|
| 153 |
+
class CategoricalPd(Pd):
    """Categorical distribution parametrized by (unnormalized) logits."""
    def __init__(self, logits):
        self.logits = logits
    def flatparam(self):
        return self.logits
    def mode(self):
        # Index of the largest logit.
        return tf.argmax(input=self.logits, axis=-1)

    @property
    def mean(self):
        # Category probabilities.
        return tf.nn.softmax(self.logits)
    def neglogp(self, x):
        # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
        # Note: we can't use sparse_softmax_cross_entropy_with_logits because
        # the implementation does not allow second-order derivatives...
        if x.dtype in {tf.uint8, tf.int32, tf.int64}:
            # one-hot encoding
            x_shape_list = x.shape.as_list()
            logits_shape_list = self.logits.get_shape().as_list()[:-1]
            for xs, ls in zip(x_shape_list, logits_shape_list):
                if xs is not None and ls is not None:
                    assert xs == ls, 'shape mismatch: {} in x vs {} in logits'.format(xs, ls)

            x = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
        else:
            # already encoded
            assert x.shape.as_list() == self.logits.shape.as_list()

        return tf.nn.softmax_cross_entropy_with_logits(
            logits=self.logits,
            labels=x)
    def kl(self, other):
        # Numerically stable KL: shift both logit sets by their max before exp.
        a0 = self.logits - tf.reduce_max(input_tensor=self.logits, axis=-1, keepdims=True)
        a1 = other.logits - tf.reduce_max(input_tensor=other.logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = tf.reduce_sum(input_tensor=ea0, axis=-1, keepdims=True)
        z1 = tf.reduce_sum(input_tensor=ea1, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(input_tensor=p0 * (a0 - tf.math.log(z0) - a1 + tf.math.log(z1)), axis=-1)
    def entropy(self):
        # Same max-shift trick for numerical stability.
        a0 = self.logits - tf.reduce_max(input_tensor=self.logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        z0 = tf.reduce_sum(input_tensor=ea0, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(input_tensor=p0 * (tf.math.log(z0) - a0), axis=-1)
    def sample(self):
        # Gumbel-max trick: argmax over logits perturbed with Gumbel noise.
        u = tf.random.uniform(tf.shape(input=self.logits), dtype=self.logits.dtype)
        return tf.argmax(input=self.logits - tf.math.log(-tf.math.log(u)), axis=-1)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
|
| 205 |
+
|
| 206 |
+
class MultiCategoricalPd(Pd):
    """Product of independent categoricals; flat concatenates each factor's logits."""
    def __init__(self, nvec, flat):
        self.flat = flat
        # Split the flat logits into one CategoricalPd per factor.
        self.categoricals = list(map(CategoricalPd,
            tf.split(flat, np.array(nvec, dtype=np.int32), axis=-1)))
    def flatparam(self):
        return self.flat
    def mode(self):
        return tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)
    def neglogp(self, x):
        # Independence: neglogps add across factors.
        return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x, axis=-1))])
    def kl(self, other):
        return tf.add_n([p.kl(q) for p, q in zip(self.categoricals, other.categoricals)])
    def entropy(self):
        return tf.add_n([p.entropy() for p in self.categoricals])
    def sample(self):
        return tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)
    @classmethod
    def fromflat(cls, flat):
        # nvec is required to split flat, so this constructor is unsupported.
        raise NotImplementedError
|
| 226 |
+
|
| 227 |
+
class DiagGaussianPd(Pd):
    """Diagonal Gaussian; flat is [mean, logstd] concatenated along the last axis."""
    def __init__(self, flat):
        self.flat = flat
        mean, logstd = tf.split(axis=len(flat.shape)-1, num_or_size_splits=2, value=flat)
        self.mean = mean
        self.logstd = logstd
        self.std = tf.exp(logstd)
    def flatparam(self):
        return self.flat
    def mode(self):
        # Mode of a Gaussian is its mean.
        return self.mean
    def neglogp(self, x):
        # Standard Gaussian negative log-density, summed over dimensions.
        return 0.5 * tf.reduce_sum(input_tensor=tf.square((x - self.mean) / self.std), axis=-1) \
               + 0.5 * np.log(2.0 * np.pi) * tf.cast(tf.shape(input=x)[-1], dtype=tf.float32) \
               + tf.reduce_sum(input_tensor=self.logstd, axis=-1)
    def kl(self, other):
        assert isinstance(other, DiagGaussianPd)
        # Closed-form KL between diagonal Gaussians.
        return tf.reduce_sum(input_tensor=other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=-1)
    def entropy(self):
        # Closed-form differential entropy.
        return tf.reduce_sum(input_tensor=self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1)
    def sample(self):
        # Reparametrized sample: mean + std * standard normal noise.
        return self.mean + self.std * tf.random.normal(tf.shape(input=self.mean))
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class BernoulliPd(Pd):
    """Product of independent Bernoullis parametrized by per-component logits."""
    def __init__(self, logits):
        self.logits = logits
        self.ps = tf.sigmoid(logits)  # success probabilities
    def flatparam(self):
        return self.logits
    @property
    def mean(self):
        return self.ps
    def mode(self):
        # Round probabilities to the nearer of {0, 1}.
        return tf.round(self.ps)
    def neglogp(self, x):
        # Sigmoid cross-entropy summed over components.
        return tf.reduce_sum(input_tensor=tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.cast(x, dtype=tf.float32)), axis=-1)
    def kl(self, other):
        # KL as cross-entropy(other) minus entropy(self).
        return tf.reduce_sum(input_tensor=tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=-1) - tf.reduce_sum(input_tensor=tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
    def entropy(self):
        return tf.reduce_sum(input_tensor=tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
    def sample(self):
        # Sample each component by thresholding uniform noise at its probability.
        u = tf.random.uniform(tf.shape(input=self.ps))
        return tf.cast(math_ops.less(u, self.ps), dtype=tf.float32)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
|
| 277 |
+
|
| 278 |
+
def make_pdtype(ac_space):
    """Map a gym action space to the matching probability-distribution family.

    Box -> diagonal Gaussian, Discrete -> categorical, MultiDiscrete -> product
    of categoricals, MultiBinary -> Bernoulli. Anything else is unsupported.
    """
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        assert len(ac_space.shape) == 1
        return DiagGaussianPdType(ac_space.shape[0])
    if isinstance(ac_space, spaces.Discrete):
        return CategoricalPdType(ac_space.n)
    if isinstance(ac_space, spaces.MultiDiscrete):
        return MultiCategoricalPdType(ac_space.nvec)
    if isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    raise NotImplementedError
|
| 291 |
+
|
| 292 |
+
def shape_el(v, i):
    """Return dimension i of tensor v: the static size when known, else a dynamic op."""
    static = v.get_shape()[i]
    if static is None:
        return tf.shape(input=v)[i]
    return static
|
| 298 |
+
|
| 299 |
+
@U.in_session
def test_probtypes():
    """Statistically validate every distribution family via validate_probtype."""
    np.random.seed(0)

    # Diagonal Gaussian: first half means, second half logstds.
    pdparam_diag_gauss = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8])
    diag_gauss = DiagGaussianPdType(pdparam_diag_gauss.size // 2) #pylint: disable=E1101
    validate_probtype(diag_gauss, pdparam_diag_gauss)

    # Categorical over 3 classes (logits).
    pdparam_categorical = np.array([-.2, .3, .5])
    categorical = CategoricalPdType(pdparam_categorical.size) #pylint: disable=E1101
    validate_probtype(categorical, pdparam_categorical)

    # Product of categoricals with sizes 1, 2, 3 (6 logits total).
    nvec = [1,2,3]
    pdparam_multicategorical = np.array([-.2, .3, .5, .1, 1, -.1])
    multicategorical = MultiCategoricalPdType(nvec) #pylint: disable=E1101
    validate_probtype(multicategorical, pdparam_multicategorical)

    # Three independent Bernoullis (logits).
    pdparam_bernoulli = np.array([-.2, .3, .5])
    bernoulli = BernoulliPdType(pdparam_bernoulli.size) #pylint: disable=E1101
    validate_probtype(bernoulli, pdparam_bernoulli)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def validate_probtype(probtype, pdparam):
    """Monte-Carlo check that a PdType's entropy/logp/kl ops are mutually consistent.

    Draws N samples and verifies (within 3 standard errors) that
    (1) mean negative log-likelihood matches the analytic entropy, and
    (2) kl(p, q) matches -ent(p) - E_p[log q].
    """
    N = 100000
    # Check to see if mean negative log likelihood == differential entropy
    Mval = np.repeat(pdparam[None, :], N, axis=0)
    M = probtype.param_placeholder([N])
    X = probtype.sample_placeholder([N])
    pd = probtype.pdfromflat(M)
    calcloglik = U.function([X, M], pd.logp(X))
    calcent = U.function([M], pd.entropy())
    Xval = tf.compat.v1.get_default_session().run(pd.sample(), feed_dict={M:Mval})
    logliks = calcloglik(Xval, Mval)
    entval_ll = - logliks.mean() #pylint: disable=E1101
    entval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
    entval = calcent(Mval).mean() #pylint: disable=E1101
    assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr # within 3 sigmas

    # Check to see if kldiv[p,q] = - ent[p] - E_p[log q]
    M2 = probtype.param_placeholder([N])
    pd2 = probtype.pdfromflat(M2)
    # q: slightly perturbed copy of p's parameters.
    q = pdparam + np.random.randn(pdparam.size) * 0.1
    Mval2 = np.repeat(q[None, :], N, axis=0)
    calckl = U.function([M, M2], pd.kl(pd2))
    klval = calckl(Mval, Mval2).mean() #pylint: disable=E1101
    logliks = calcloglik(Xval, Mval2)
    klval_ll = - entval - logliks.mean() #pylint: disable=E1101
    klval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
    assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
    print('ok on', probtype, pdparam)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def _matching_fc(tensor, name, size, init_scale, init_bias):
    """Return tensor unchanged if its last dim already equals size, else project
    it through a fully-connected layer named `name`."""
    if tensor.shape[-1] == size:
        return tensor
    else:
        return fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias)
|
baselines/common/input.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
tf.compat.v1.disable_eager_execution()
|
| 4 |
+
from gym.spaces import Discrete, Box, MultiDiscrete
|
| 5 |
+
|
| 6 |
+
def observation_placeholder(ob_space, batch_size=None, name='Ob'):
    '''
    Create placeholder to feed observations into of the size appropriate to the observation space

    Parameters:
    ----------

    ob_space: gym.Space     observation space

    batch_size: int         size of the batch to be fed into input. Can be left None in most cases.

    name: str               name of the placeholder

    Returns:
    -------

    tensorflow placeholder tensor
    '''

    assert isinstance(ob_space, (Discrete, Box, MultiDiscrete)), \
        'Can only deal with Discrete and Box observation spaces for now'

    # int8 observation spaces are fed as uint8
    dtype = np.uint8 if ob_space.dtype == np.int8 else ob_space.dtype

    shape = (batch_size,) + ob_space.shape
    return tf.compat.v1.placeholder(shape=shape, dtype=dtype, name=name)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def observation_input(ob_space, batch_size=None, name='Ob'):
    '''
    Create placeholder to feed observations into of the size appropriate to the observation space, and add input
    encoder of the appropriate type.
    '''
    ph = observation_placeholder(ob_space, batch_size, name)
    encoded = encode_observation(ob_space, ph)
    return ph, encoded
|
| 43 |
+
|
| 44 |
+
def encode_observation(ob_space, placeholder):
    '''
    Encode input in the way that is appropriate to the observation space

    Parameters:
    ----------

    ob_space: gym.Space             observation space

    placeholder: tf.placeholder     observation input placeholder
    '''
    if isinstance(ob_space, Discrete):
        # Discrete observations become one-hot float vectors.
        return tf.cast(tf.one_hot(placeholder, ob_space.n), dtype=tf.float32)
    if isinstance(ob_space, Box):
        # Box observations are simply cast to float.
        return tf.cast(placeholder, dtype=tf.float32)
    if isinstance(ob_space, MultiDiscrete):
        # One-hot each component separately, then concatenate.
        placeholder = tf.cast(placeholder, tf.int32)
        one_hots = []
        for i in range(placeholder.shape[-1]):
            one_hots.append(tf.cast(tf.one_hot(placeholder[..., i], ob_space.nvec[i]), dtype=tf.float32))
        return tf.concat(one_hots, axis=-1)
    raise NotImplementedError
|
| 65 |
+
|
baselines/common/math_util.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import scipy.signal
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def discount(x, gamma):
    """
    computes discounted sums along 0th dimension of x.

    inputs
    ------
    x: ndarray
    gamma: float

    outputs
    -------
    y: ndarray with same shape as x, satisfying

        y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k],
                where k = len(x) - t - 1

    """
    assert x.ndim >= 1
    # Run an IIR filter over the time-reversed signal, then reverse back.
    reversed_x = x[::-1]
    filtered = scipy.signal.lfilter([1], [1, -gamma], reversed_x, axis=0)
    return filtered[::-1]
|
| 24 |
+
|
| 25 |
+
def explained_variance(ypred, y):
    """
    Computes fraction of variance that ypred explains about y.
    Returns 1 - Var[y-ypred] / Var[y]

    interpretation:
        ev=0  =>  might as well have predicted zero
        ev=1  =>  perfect prediction
        ev<0  =>  worse than just predicting zero

    """
    assert y.ndim == 1 and ypred.ndim == 1
    target_var = np.var(y)
    if target_var == 0:
        # Undefined when the target is constant.
        return np.nan
    return 1 - np.var(y - ypred) / target_var
|
| 39 |
+
|
| 40 |
+
def explained_variance_2d(ypred, y):
    """Per-column explained variance for 2-D (time x features) arrays.

    Returns, for each column j, 1 - Var[y[:,j] - ypred[:,j]] / Var[y[:,j]],
    with columns whose target variance is (near) zero reported as 0.

    Fix: the residual variance previously omitted axis=0, so a single scalar
    variance over the whole residual matrix was divided by per-column target
    variances, giving wrong values for every column.
    """
    assert y.ndim == 2 and ypred.ndim == 2
    vary = np.var(y, axis=0)
    out = 1 - np.var(y - ypred, axis=0) / vary
    out[vary < 1e-10] = 0  # avoid nonsense ratios for (near-)constant targets
    return out
|
| 46 |
+
|
| 47 |
+
def ncc(ypred, y):
    """Pearson correlation coefficient between ypred and y."""
    corr = np.corrcoef(ypred, y)
    return corr[1, 0]
|
| 49 |
+
|
| 50 |
+
def flatten_arrays(arrs):
    """Concatenate the flattened contents of every array in arrs into one 1-D array."""
    flats = [arr.flat for arr in arrs]
    return np.concatenate(flats)
|
| 52 |
+
|
| 53 |
+
def unflatten_vector(vec, shapes):
    """Split the flat vector vec back into a list of arrays with the given shapes."""
    arrs = []
    offset = 0
    for shape in shapes:
        count = np.prod(shape)
        arrs.append(vec[offset:offset + count].reshape(shape))
        offset += count
    return arrs
|
| 62 |
+
|
| 63 |
+
def discount_with_boundaries(X, New, gamma):
    """
    X: 2d array of floats, time x features
    New: 2d array of bools, indicating when a new episode has started

    Returns the discounted-sum array where the accumulation is cut whenever
    New marks an episode boundary at the next step.
    """
    T = X.shape[0]
    Y = np.zeros_like(X)
    Y[T - 1] = X[T - 1]
    # Walk backwards, zeroing the carried sum across episode boundaries.
    for t in reversed(range(T - 1)):
        Y[t] = X[t] + gamma * Y[t + 1] * (1 - New[t + 1])
    return Y
|
| 74 |
+
|
| 75 |
+
def test_discount_with_boundaries():
    """Sanity-check discount_with_boundaries against a hand-computed example."""
    gamma = 0.9
    x = np.array([1.0, 2.0, 3.0, 4.0], 'float32')
    starts = [1.0, 0.0, 0.0, 1.0]
    y = discount_with_boundaries(x, starts, gamma)
    expected = [
        1 + gamma * 2 + gamma ** 2 * 3,
        2 + gamma * 3,
        3,
        4,
    ]
    assert np.allclose(y, expected)
|
baselines/common/misc_util.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gym
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
import pickle
|
| 5 |
+
import random
|
| 6 |
+
import tempfile
|
| 7 |
+
import zipfile
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def zipsame(*seqs):
    """zip() that refuses silently-truncating inputs: all sequences must have equal length."""
    expected = len(seqs[0])
    for seq in seqs[1:]:
        assert len(seq) == expected
    return zip(*seqs)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class EzPickle(object):
    """Objects that are pickled and unpickled via their constructor arguments.

    Example usage:

        class Dog(Animal, EzPickle):
            def __init__(self, furcolor, tailkind="bushy"):
                Animal.__init__()
                EzPickle.__init__(furcolor, tailkind)
                ...

    When this object is unpickled, a new instance is constructed by passing the
    saved constructor arguments back in. However, philosophers are still not sure
    whether it is still the same dog.

    This is generally needed only for environments which wrap C/C++ code, such as MuJoCo
    and Atari.
    """

    def __init__(self, *args, **kwargs):
        # Remember exactly how we were constructed so pickling can replay it.
        self._ezpickle_args = args
        self._ezpickle_kwargs = kwargs

    def __getstate__(self):
        return {
            "_ezpickle_args": self._ezpickle_args,
            "_ezpickle_kwargs": self._ezpickle_kwargs,
        }

    def __setstate__(self, d):
        # Re-run the constructor with the saved arguments, then adopt the
        # freshly built instance's attribute dict wholesale.
        fresh = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
        self.__dict__.update(fresh.__dict__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def set_global_seeds(i):
    """Seed the python, numpy and (if available) tensorflow RNGs.

    Each MPI rank is offset by 1000 * rank so parallel workers do not draw
    identical random streams.  Passing i=None leaves the RNGs unseeded.

    Parameters
    ----------
    i: int or None
        base seed
    """
    try:
        # The original `import MPI` always raised ImportError (the package is
        # mpi4py), so the per-rank offset silently never applied.
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
    except ImportError:
        rank = 0

    myseed = i + 1000 * rank if i is not None else None
    try:
        import tensorflow as tf
        tf.compat.v1.set_random_seed(myseed)
    except ImportError:
        pass
    np.random.seed(myseed)
    random.seed(myseed)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def pretty_eta(seconds_left):
    """Convert a number of seconds into a human readable ETA string.

    Examples:
        2 days
        2 hours and 37 minutes
        less than a minute

    Parameters
    ----------
    seconds_left: int
        Number of seconds to be converted to the ETA
    Returns
    -------
    eta: str
        String representing the pretty ETA.
    """
    minutes_left, seconds_left = divmod(seconds_left, 60)
    hours_left, minutes_left = divmod(minutes_left, 60)
    days_left, hours_left = divmod(hours_left, 24)

    def fmt(count, unit):
        # Pluralize only for counts greater than one.
        return "{} {}{}".format(count, unit, 's' if count > 1 else '')

    if days_left > 0:
        msg = fmt(days_left, 'day')
        if hours_left > 0:
            msg += ' and ' + fmt(hours_left, 'hour')
        return msg
    if hours_left > 0:
        msg = fmt(hours_left, 'hour')
        if minutes_left > 0:
            msg += ' and ' + fmt(minutes_left, 'minute')
        return msg
    if minutes_left > 0:
        return fmt(minutes_left, 'minute')
    return 'less than a minute'
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class RunningAvg(object):
    """Exponentially-weighted running estimate of a scalar quantity.

    Behaves like a mean, but weighted toward recent observations.
    """

    def __init__(self, gamma, init_value=None):
        """
        Parameters
        ----------
        gamma: float
            Must be between 0 and 1, where 0 is the most sensitive to recent
            changes.
        init_value: float or None
            Initial value of the estimate. If None, it will be set on the
            first update.
        """
        self._gamma = gamma
        self._value = init_value

    def update(self, new_val):
        """Fold a new observation into the estimate.

        Parameters
        ----------
        new_val: float
            new observed value of the estimated quantity.
        """
        if self._value is None:
            # First observation initializes the estimate directly.
            self._value = new_val
            return
        self._value = self._gamma * self._value + (1.0 - self._gamma) * new_val

    def __float__(self):
        """Current estimate as a float."""
        return self._value
|
| 139 |
+
|
| 140 |
+
def boolean_flag(parser, name, default=False, help=None):
    """Add a --<name>/--no-<name> boolean flag pair to an argparse parser.

    Parameters
    ----------
    parser: argparse.Parser
        parser to add the flag to
    name: str
        --<name> will enable the flag, while --no-<name> will disable it
    default: bool or None
        default value of the flag
    help: str
        help string for the flag
    """
    dest = name.replace('-', '_')
    parser.add_argument('--{}'.format(name), action='store_true',
                        default=default, dest=dest, help=help)
    parser.add_argument('--no-{}'.format(name), action='store_false', dest=dest)
|
| 158 |
+
|
| 159 |
+
def get_wrapper_by_name(env, classname):
    """Walk a (possibly multiply-wrapped) gym environment, returning the wrapper
    whose class_name() matches classname.

    Parameters
    ----------
    env: gym.Env or gym.Wrapper
        gym environment
    classname: str
        name of the wrapper

    Returns
    -------
    wrapper: gym.Wrapper
        wrapper named classname

    Raises
    ------
    ValueError
        if no wrapper in the chain carries the requested name
    """
    current = env
    while True:
        if current.class_name() == classname:
            return current
        if not isinstance(current, gym.Wrapper):
            # Reached the innermost env without a match.
            raise ValueError("Couldn't find wrapper named %s" % classname)
        current = current.env
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def relatively_safe_pickle_dump(obj, path, compression=False):
    """This is just like regular pickle dump, except from the fact that failure cases are
    different:

        - It's never possible that we end up with a pickle in corrupted state.
        - If there was a different file at the path, that file will remain unchanged in the
          event of failure (provided that filesystem rename is atomic).
        - It is sometimes possible that we end up with a useless temp file which needs to be
          deleted manually (it will be removed automatically on the next function call)

    The intended use case is periodic checkpoints of experiment state, such that we never
    corrupt previous checkpoints if the current one fails.

    Parameters
    ----------
    obj: object
        object to pickle
    path: str
        path to the output file
    compression: bool
        if true pickle will be compressed
    """
    temp_storage = path + ".relatively_safe"
    if compression:
        # Using gzip here would be simpler, but the size is limited to 2GB
        with tempfile.NamedTemporaryFile() as uncompressed_file:
            pickle.dump(obj, uncompressed_file)
            uncompressed_file.file.flush()
            with zipfile.ZipFile(temp_storage, "w", compression=zipfile.ZIP_DEFLATED) as myzip:
                myzip.write(uncompressed_file.name, "data")
    else:
        with open(temp_storage, "wb") as f:
            pickle.dump(obj, f)
    # os.replace (unlike os.rename) atomically overwrites an existing target on
    # every platform, including Windows, preserving the safety guarantee above.
    os.replace(temp_storage, path)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def pickle_load(path, compression=False):
    """Unpickle a possibly compressed pickle.

    Parameters
    ----------
    path: str
        path to the output file
    compression: bool
        if true assumes that pickle was compressed when created and attempts
        decompression.

    Returns
    -------
    obj: object
        the unpickled object
    """
    if not compression:
        with open(path, "rb") as f:
            return pickle.load(f)
    # Compressed pickles are stored as a single "data" member inside a zip.
    with zipfile.ZipFile(path, "r", compression=zipfile.ZIP_DEFLATED) as myzip:
        with myzip.open("data") as f:
            return pickle.load(f)
|
baselines/common/models.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
from baselines.a2c import utils
|
| 6 |
+
from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch
|
| 7 |
+
from baselines.common.mpi_running_mean_std import RunningMeanStd
|
| 8 |
+
from keras import layers
|
| 9 |
+
from itertools import combinations
|
| 10 |
+
|
| 11 |
+
# Global registry: network name -> builder function.
mapping = {}


def register(name):
    """Decorator that registers a network-builder function under `name` in `mapping`."""
    def _store(builder):
        mapping[name] = builder
        return builder

    return _store
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def nature_cnn(unscaled_images, **conv_kwargs):
    """
    CNN from Nature paper: three ReLU conv layers followed by a 512-unit
    fully-connected ReLU layer, on pixels scaled from [0, 255] to [0, 1].
    """
    relu = tf.nn.relu
    x = tf.cast(unscaled_images, tf.float32) / 255.
    x = relu(conv(x, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), **conv_kwargs))
    x = relu(conv(x, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
    x = relu(conv(x, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2), **conv_kwargs))
    x = conv_to_fc(x)
    return relu(fc(x, 'fc1', nh=512, init_scale=np.sqrt(2)))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def build_impala_cnn(unscaled_images, depths=[16, 32, 32], **conv_kwargs):
    """
    Model used in the paper "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures" https://arxiv.org/abs/1802.01561

    Each entry of `depths` produces one conv + max-pool + two-residual-block
    sequence; the result is flattened into a 256-unit dense ReLU layer.
    """

    layer_num = 0

    def get_layer_num_str():
        # Hand out unique, increasing layer-name suffixes.
        nonlocal layer_num
        num_str = str(layer_num)
        layer_num += 1
        return num_str

    def conv_layer(out, depth):
        return tf.compat.v1.layers.conv2d(out, depth, 3, padding='same', name='layer_' + get_layer_num_str())

    def residual_block(inputs):
        try:
            # TF1 shapes hold Dimension objects with a .value attribute ...
            depth = inputs.get_shape()[-1].value
        except AttributeError:
            # ... while TF2 shapes hold plain ints.  Catch only that case;
            # the original bare `except:` would also have swallowed
            # KeyboardInterrupt and genuine bugs.
            depth = inputs.get_shape()[-1]

        out = tf.nn.relu(inputs)

        out = conv_layer(out, depth)
        out = tf.nn.relu(out)
        out = conv_layer(out, depth)
        return out + inputs

    def conv_sequence(inputs, depth):
        out = conv_layer(inputs, depth)
        out = tf.compat.v1.layers.max_pooling2d(out, pool_size=3, strides=2, padding='same')
        out = residual_block(out)
        out = residual_block(out)
        return out

    out = tf.cast(unscaled_images, tf.float32) / 255.

    for depth in depths:
        out = conv_sequence(out, depth)

    out = tf.compat.v1.layers.flatten(out)
    out = tf.nn.relu(out)
    out = tf.compat.v1.layers.dense(out, 256, activation=tf.nn.relu, name='layer_' + get_layer_num_str())

    return out
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def build_skill_impala_cnn(unscaled_images, depths=[16, 32, 32], emb_dim=256, num_embeddings=8, seed=0,
                           **conv_kwargs):
    """
    Modified impala cnn model by adding the skill module.

    On top of the IMPALA CNN trunk, a small "skill" head projects the features
    down to 2 dimensions and vector-quantizes them against a codebook of
    `num_embeddings` entries (see VectorQuantizer below); the normalized code
    index is then added back onto the features.

    Returns a 7-tuple:
        out              -- features with the normalized skill index added
        skill_out        -- 2-D pre-quantization skill projection
        pure_out         -- plain CNN features (before the skill addition)
        vq_out           -- quantized skill vector (straight-through gradients)
        pure_vq_out      -- quantized skill vector (no gradient trick)
        embeddings       -- the VQ codebook variable
        encoding_indices -- integer code index per sample
    """

    layer_num = 0

    def get_layer_num_str():
        # Hand out unique, increasing layer-name suffixes.
        nonlocal layer_num
        num_str = str(layer_num)
        layer_num += 1
        return num_str

    def conv_layer(out, depth):
        return tf.compat.v1.layers.conv2d(out, depth, 3, padding='same', name='layer_' + get_layer_num_str())

    def residual_block(inputs):
        # depth = inputs.get_shape()[-1].value
        depth = inputs.get_shape()[-1]

        out = tf.nn.relu(inputs)

        out = conv_layer(out, depth)
        out = tf.nn.relu(out)
        out = conv_layer(out, depth)
        return out + inputs

    def conv_sequence(inputs, depth):
        out = conv_layer(inputs, depth)
        out = tf.compat.v1.layers.max_pooling2d(out, pool_size=3, strides=2, padding='same')
        out = residual_block(out)
        out = residual_block(out)
        return out

    out = tf.cast(unscaled_images, tf.float32) / 255.

    for depth in depths:
        out = conv_sequence(out, depth)

    out = tf.compat.v1.layers.flatten(out)
    out = tf.nn.relu(out)
    pure_out = tf.compat.v1.layers.dense(out, emb_dim, activation=tf.nn.relu, name='layer_' + get_layer_num_str())

    # skill module
    skill_out = tf.compat.v1.layers.dense(pure_out, emb_dim // 2, activation=None, name='layer_' + get_layer_num_str())
    skill_out = tf.compat.v1.layers.dense(skill_out, 2, activation=None, name='layer_' + get_layer_num_str())
    vq_layer = VectorQuantizer(num_embeddings, 2, seed=seed, name="vector_quantizer")
    vq_out, pure_vq_out, encoding_indices = vq_layer(skill_out)

    # Tile (index / num_embeddings) across the feature dimension.
    # NOTE(review): the result's value depends on encoding_indices' dtype
    # (int division vs float) -- presumably intended as a float in [0, 1);
    # verify against the training code.
    encoding_indices_ = tf.cast(
        tf.tile(encoding_indices / vq_layer.num_embeddings, tf.constant([1, emb_dim], tf.int32)), tf.float32)

    # add the normalized skill indices to features
    out = tf.math.add(pure_out, encoding_indices_)

    return out, skill_out, pure_out, vq_out, pure_vq_out, vq_layer.embeddings, encoding_indices
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@register("mlp")
|
| 145 |
+
def mlp(num_layers=2, num_hidden=64, activation=tf.tanh, layer_norm=False):
|
| 146 |
+
"""
|
| 147 |
+
Stack of fully-connected layers to be used in a policy / q-function approximator
|
| 148 |
+
|
| 149 |
+
Parameters:
|
| 150 |
+
----------
|
| 151 |
+
|
| 152 |
+
num_layers: int number of fully-connected layers (default: 2)
|
| 153 |
+
|
| 154 |
+
num_hidden: int size of fully-connected layers (default: 64)
|
| 155 |
+
|
| 156 |
+
activation: activation function (default: tf.tanh)
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
-------
|
| 160 |
+
|
| 161 |
+
function that builds fully connected network with a given input tensor / placeholder
|
| 162 |
+
"""
|
| 163 |
+
|
| 164 |
+
def network_fn(X):
|
| 165 |
+
h = tf.compat.v1.layers.flatten(X)
|
| 166 |
+
for i in range(num_layers):
|
| 167 |
+
h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden, init_scale=np.sqrt(2))
|
| 168 |
+
if layer_norm:
|
| 169 |
+
h = tf.contrib.layers.layer_norm(h, center=True, scale=True)
|
| 170 |
+
h = activation(h)
|
| 171 |
+
|
| 172 |
+
return h
|
| 173 |
+
|
| 174 |
+
return network_fn
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@register("cnn")
|
| 178 |
+
def cnn(**conv_kwargs):
|
| 179 |
+
def network_fn(X):
|
| 180 |
+
return nature_cnn(X, **conv_kwargs)
|
| 181 |
+
|
| 182 |
+
return network_fn
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@register("impala_cnn")
|
| 186 |
+
def impala_cnn(**conv_kwargs):
|
| 187 |
+
def network_fn(X):
|
| 188 |
+
return build_impala_cnn(X)
|
| 189 |
+
|
| 190 |
+
return network_fn
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@register("cnn_small")
|
| 194 |
+
def cnn_small(**conv_kwargs):
|
| 195 |
+
def network_fn(X):
|
| 196 |
+
h = tf.cast(X, tf.float32) / 255.
|
| 197 |
+
|
| 198 |
+
activ = tf.nn.relu
|
| 199 |
+
h = activ(conv(h, 'c1', nf=8, rf=8, stride=4, init_scale=np.sqrt(2), **conv_kwargs))
|
| 200 |
+
h = activ(conv(h, 'c2', nf=16, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
|
| 201 |
+
h = conv_to_fc(h)
|
| 202 |
+
h = activ(fc(h, 'fc1', nh=128, init_scale=np.sqrt(2)))
|
| 203 |
+
return h
|
| 204 |
+
|
| 205 |
+
return network_fn
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
@register("lstm")
|
| 209 |
+
def lstm(nlstm=128, layer_norm=False):
|
| 210 |
+
"""
|
| 211 |
+
Builds LSTM (Long-Short Term Memory) network to be used in a policy.
|
| 212 |
+
Note that the resulting function returns not only the output of the LSTM
|
| 213 |
+
(i.e. hidden state of lstm for each step in the sequence), but also a dictionary
|
| 214 |
+
with auxiliary tensors to be set as policy attributes.
|
| 215 |
+
|
| 216 |
+
Specifically,
|
| 217 |
+
S is a placeholder to feed current state (LSTM state has to be managed outside policy)
|
| 218 |
+
M is a placeholder for the mask (used to mask out observations after the end of the episode, but can be used for other purposes too)
|
| 219 |
+
initial_state is a numpy array containing initial lstm state (usually zeros)
|
| 220 |
+
state is the output LSTM state (to be fed into S at the next call)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
An example of usage of lstm-based policy can be found here: common/tests/test_doc_examples.py/test_lstm_example
|
| 224 |
+
|
| 225 |
+
Parameters:
|
| 226 |
+
----------
|
| 227 |
+
|
| 228 |
+
nlstm: int LSTM hidden state size
|
| 229 |
+
|
| 230 |
+
layer_norm: bool if True, layer-normalized version of LSTM is used
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
-------
|
| 234 |
+
|
| 235 |
+
function that builds LSTM with a given input tensor / placeholder
|
| 236 |
+
"""
|
| 237 |
+
|
| 238 |
+
def network_fn(X, nenv=1):
|
| 239 |
+
nbatch = X.shape[0]
|
| 240 |
+
nsteps = nbatch // nenv
|
| 241 |
+
|
| 242 |
+
h = tf.compat.v1.layers.flatten(X)
|
| 243 |
+
|
| 244 |
+
M = tf.compat.v1.placeholder(tf.float32, [nbatch]) # mask (done t-1)
|
| 245 |
+
S = tf.compat.v1.placeholder(tf.float32, [nenv, 2 * nlstm]) # states
|
| 246 |
+
|
| 247 |
+
xs = batch_to_seq(h, nenv, nsteps)
|
| 248 |
+
ms = batch_to_seq(M, nenv, nsteps)
|
| 249 |
+
|
| 250 |
+
if layer_norm:
|
| 251 |
+
h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
|
| 252 |
+
else:
|
| 253 |
+
h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)
|
| 254 |
+
|
| 255 |
+
h = seq_to_batch(h5)
|
| 256 |
+
initial_state = np.zeros(S.shape.as_list(), dtype=float)
|
| 257 |
+
|
| 258 |
+
return h, {'S': S, 'M': M, 'state': snew, 'initial_state': initial_state}
|
| 259 |
+
|
| 260 |
+
return network_fn
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@register("cnn_lstm")
|
| 264 |
+
def cnn_lstm(nlstm=128, layer_norm=False, conv_fn=nature_cnn, **conv_kwargs):
|
| 265 |
+
def network_fn(X, nenv=1):
|
| 266 |
+
nbatch = X.shape[0]
|
| 267 |
+
nsteps = nbatch // nenv
|
| 268 |
+
|
| 269 |
+
h = conv_fn(X, **conv_kwargs)
|
| 270 |
+
|
| 271 |
+
M = tf.compat.v1.placeholder(tf.float32, [nbatch]) # mask (done t-1)
|
| 272 |
+
S = tf.compat.v1.placeholder(tf.float32, [nenv, 2 * nlstm]) # states
|
| 273 |
+
|
| 274 |
+
xs = batch_to_seq(h, nenv, nsteps)
|
| 275 |
+
ms = batch_to_seq(M, nenv, nsteps)
|
| 276 |
+
|
| 277 |
+
if layer_norm:
|
| 278 |
+
h5, snew = utils.lnlstm(xs, ms, S, scope='lnlstm', nh=nlstm)
|
| 279 |
+
else:
|
| 280 |
+
h5, snew = utils.lstm(xs, ms, S, scope='lstm', nh=nlstm)
|
| 281 |
+
|
| 282 |
+
h = seq_to_batch(h5)
|
| 283 |
+
initial_state = np.zeros(S.shape.as_list(), dtype=float)
|
| 284 |
+
|
| 285 |
+
return h, {'S': S, 'M': M, 'state': snew, 'initial_state': initial_state}
|
| 286 |
+
|
| 287 |
+
return network_fn
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
@register("impala_cnn_lstm")
|
| 291 |
+
def impala_cnn_lstm():
|
| 292 |
+
return cnn_lstm(nlstm=256, conv_fn=build_impala_cnn)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
@register("cnn_lnlstm")
|
| 296 |
+
def cnn_lnlstm(nlstm=128, **conv_kwargs):
|
| 297 |
+
return cnn_lstm(nlstm, layer_norm=True, **conv_kwargs)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@register("conv_only")
|
| 301 |
+
def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
|
| 302 |
+
'''
|
| 303 |
+
convolutions-only net
|
| 304 |
+
|
| 305 |
+
Parameters:
|
| 306 |
+
----------
|
| 307 |
+
|
| 308 |
+
conv: list of triples (filter_number, filter_size, stride) specifying parameters for each layer.
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
|
| 312 |
+
function that takes tensorflow tensor as input and returns the output of the last convolutional layer
|
| 313 |
+
|
| 314 |
+
'''
|
| 315 |
+
|
| 316 |
+
def network_fn(X):
|
| 317 |
+
out = tf.cast(X, tf.float32) / 255.
|
| 318 |
+
with tf.compat.v1.variable_scope("convnet"):
|
| 319 |
+
for num_outputs, kernel_size, stride in convs:
|
| 320 |
+
out = tf.contrib.layers.convolution2d(out,
|
| 321 |
+
num_outputs=num_outputs,
|
| 322 |
+
kernel_size=kernel_size,
|
| 323 |
+
stride=stride,
|
| 324 |
+
activation_fn=tf.nn.relu,
|
| 325 |
+
**conv_kwargs)
|
| 326 |
+
|
| 327 |
+
return out
|
| 328 |
+
|
| 329 |
+
return network_fn
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
|
| 333 |
+
rms = RunningMeanStd(shape=x.shape[1:])
|
| 334 |
+
norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
|
| 335 |
+
return norm_x, rms
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def get_network_builder(name):
    """
    If you want to register your own network outside models.py, you just need:

    Usage Example:
    -------------
    from baselines.common.models import register
    @register("your_network_name")
    def your_network_define(**net_kwargs):
        ...
        return network_fn

    """
    # A callable is accepted verbatim; a string is resolved via the registry.
    if callable(name):
        return name
    if name in mapping:
        return mapping[name]
    raise ValueError('Unknown network type: {}'.format(name))
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class VectorQuantizer(layers.Layer):
    """Vector-quantization layer (VQ-VAE style).

    Holds a codebook of `num_embeddings` vectors of size `embedding_dim` and
    maps each input vector to its nearest codebook entry.
    """

    def __init__(self, num_embeddings, embedding_dim, seed=0, **kwargs):
        # num_embeddings: number of codebook entries.
        # embedding_dim: dimensionality of each code vector.
        # seed: seed for the uniform codebook initializer.
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        # Initialize the embeddings which we will quantize.
        w_init = tf.compat.v1.random_uniform_initializer(minval=-1 / num_embeddings, maxval=1 / num_embeddings,
                                                         seed=seed)

        self.embeddings = tf.compat.v1.get_variable(
            initializer=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        """Quantize x; returns (straight-through output, raw quantized output, code indices)."""
        # Calculate the input shape of the inputs and
        # then flatten the inputs keeping `embedding_dim` intact.
        input_shape = tf.shape(input=x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization: replace each row with its nearest codebook vector.
        encoding_indices = self.get_code_indices(flattened)
        encoding_indices = tf.reshape(encoding_indices, [input_shape[0], -1])
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Straight-through estimator: the forward pass yields the quantized
        # value while gradients flow through to x unchanged.
        quantized_ = x + tf.stop_gradient(quantized - x)

        return quantized_, quantized, encoding_indices

    def get_code_indices(self, flattened_inputs):
        """Return, for each row of flattened_inputs, the index of the nearest code."""
        # Calculate L2-normalized distance between the inputs and the codes
        # via the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(input_tensor=flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(input_tensor=self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(input=distances, axis=1)
        return encoding_indices
|
baselines/common/mpi_adam.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import baselines.common.tf_util as U
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
import numpy as np
|
| 4 |
+
try:
|
| 5 |
+
from mpi4py import MPI
|
| 6 |
+
except ImportError:
|
| 7 |
+
MPI = None
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MpiAdam(object):
    """Adam optimizer whose gradients are averaged across MPI workers.

    Each worker supplies a local flat gradient; `update` all-reduces it
    (summing, optionally dividing by world size) and applies one Adam step to
    a flat copy of the variables, which is written back via SetFromFlat.
    Falls back to single-process behaviour when mpi4py is unavailable.
    """

    def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
        # var_list: tf variables to optimize (treated as one flat vector).
        # beta1 / beta2 / epsilon: standard Adam hyperparameters.
        # scale_grad_by_procs: divide the summed gradient by the number of workers.
        # comm: MPI communicator; defaults to COMM_WORLD when MPI is available.
        self.var_list = var_list
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.scale_grad_by_procs = scale_grad_by_procs
        size = sum(U.numel(v) for v in var_list)
        # First and second moment estimates, kept as flat numpy buffers.
        self.m = np.zeros(size, 'float32')
        self.v = np.zeros(size, 'float32')
        self.t = 0  # Adam timestep, used for bias correction.
        self.setfromflat = U.SetFromFlat(var_list)
        self.getflat = U.GetFlat(var_list)
        self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm

    def update(self, localg, stepsize):
        # All-reduce localg across workers and apply one Adam step of size stepsize.
        if self.t % 100 == 0:
            # Periodically verify all workers still hold identical parameters.
            self.check_synced()
        localg = localg.astype('float32')
        if self.comm is not None:
            globalg = np.zeros_like(localg)
            self.comm.Allreduce(localg, globalg, op=MPI.SUM)
            if self.scale_grad_by_procs:
                globalg /= self.comm.Get_size()
        else:
            globalg = np.copy(localg)

        self.t += 1
        # Bias-corrected effective step size (standard Adam).
        a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t)
        self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
        self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
        step = (- a) * self.m / (np.sqrt(self.v) + self.epsilon)
        self.setfromflat(self.getflat() + step)

    def sync(self):
        # Broadcast rank 0's parameters to every worker.
        if self.comm is None:
            return
        theta = self.getflat()
        self.comm.Bcast(theta, root=0)
        self.setfromflat(theta)

    def check_synced(self):
        # Assert that this worker's parameters match rank 0's.  Every rank
        # participates in the same Bcast; non-root ranks compare the result.
        if self.comm is None:
            return
        if self.comm.Get_rank() == 0: # this is root
            theta = self.getflat()
            self.comm.Bcast(theta, root=0)
        else:
            thetalocal = self.getflat()
            thetaroot = np.empty_like(thetalocal)
            self.comm.Bcast(thetaroot, root=0)
            assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
|
| 62 |
+
|
| 63 |
+
@U.in_session
def test_MpiAdam():
    """Check MpiAdam against TensorFlow's AdamOptimizer on a tiny problem:
    both should produce (numerically) the same 10-step loss trajectory."""
    np.random.seed(0)
    tf.compat.v1.set_random_seed(0)

    a = tf.Variable(np.random.randn(3).astype('float32'))
    b = tf.Variable(np.random.randn(2,5).astype('float32'))
    loss = tf.reduce_sum(input_tensor=tf.square(a)) + tf.reduce_sum(input_tensor=tf.sin(b))

    stepsize = 1e-2
    # Reference run: TensorFlow's own Adam.
    update_op = tf.compat.v1.train.AdamOptimizer(stepsize).minimize(loss)
    do_update = U.function([], loss, updates=[update_op])

    tf.compat.v1.get_default_session().run(tf.compat.v1.global_variables_initializer())
    losslist_ref = []
    for i in range(10):
        l = do_update()
        print(i, l)
        losslist_ref.append(l)



    # Reset the variables and repeat the same optimization with MpiAdam.
    tf.compat.v1.set_random_seed(0)
    tf.compat.v1.get_default_session().run(tf.compat.v1.global_variables_initializer())

    var_list = [a,b]
    lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
    adam = MpiAdam(var_list)

    losslist_test = []
    for i in range(10):
        l,g = lossandgrad()
        adam.update(g, stepsize)
        print(i,l)
        losslist_test.append(l)

    # Trajectories must agree to within Adam's numerical tolerance.
    np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == '__main__':
|
| 103 |
+
test_MpiAdam()
|
baselines/common/mpi_adam_optimizer.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
from baselines.common import tf_util as U
|
| 4 |
+
from baselines.common.tests.test_with_mpi import with_mpi
|
| 5 |
+
from baselines import logger
|
| 6 |
+
try:
|
| 7 |
+
from mpi4py import MPI
|
| 8 |
+
except ImportError:
|
| 9 |
+
MPI = None
|
| 10 |
+
|
| 11 |
+
class MpiAdamOptimizer(tf.compat.v1.train.AdamOptimizer):
    """Adam optimizer that averages gradients across mpi processes."""
    def __init__(self, comm, grad_clip=None, mpi_rank_weight=1, **kwargs):
        # comm: MPI communicator used for the gradient allreduce.
        # grad_clip: when not None, enables gradient-norm clipping.
        #   NOTE(review): the clip threshold is the constant 1, not grad_clip
        #   itself -- grad_clip only toggles the behavior; confirm intended.
        # mpi_rank_weight: this worker's weight in the weighted gradient mean.
        self.comm = comm
        self.grad_clip = grad_clip
        self.mpi_rank_weight = mpi_rank_weight
        tf.compat.v1.train.AdamOptimizer.__init__(self, **kwargs)
    def compute_gradients(self, loss, var_list, **kwargs):
        """Compute gradients locally, then average them across MPI workers."""
        grads_and_vars = tf.compat.v1.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs)
        grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
        # Flatten all gradients into one vector, pre-scaled by this rank's weight.
        flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0) * self.mpi_rank_weight
        shapes = [v.shape.as_list() for g, v in grads_and_vars]
        sizes = [int(np.prod(s)) for s in shapes]

        # Sum of all ranks' weights; divides the summed gradient into a weighted mean.
        total_weight = np.zeros(1, np.float32)
        self.comm.Allreduce(np.array([self.mpi_rank_weight], dtype=np.float32), total_weight, op=MPI.SUM)
        total_weight = total_weight[0]

        buf = np.zeros(sum(sizes), np.float32)
        countholder = [0] # Counts how many times _collect_grads has been called
        stat = tf.reduce_sum(input_tensor=grads_and_vars[0][1]) # sum of first variable
        def _collect_grads(flat_grad, np_stat):
            # Runs as a py_func: clip (optional), allreduce, and average in numpy.
            if self.grad_clip is not None:
                gradnorm = np.linalg.norm(flat_grad)
                if gradnorm > 1:
                    flat_grad /= gradnorm
                logger.logkv_mean('gradnorm', gradnorm)
                logger.logkv_mean('gradclipfrac', float(gradnorm > 1))
            self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
            np.divide(buf, float(total_weight), out=buf)
            # Every 100 calls, check that the first variable agrees on all workers.
            if countholder[0] % 100 == 0:
                check_synced(np_stat, self.comm)
            countholder[0] += 1
            return buf

        avg_flat_grad = tf.compat.v1.py_func(_collect_grads, [flat_grad, stat], tf.float32)
        avg_flat_grad.set_shape(flat_grad.shape)
        # Unflatten the averaged vector back into per-variable gradients.
        avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
        avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
                              for g, (_, v) in zip(avg_grads, grads_and_vars)]
        return avg_grads_and_vars
|
| 52 |
+
|
| 53 |
+
def check_synced(localval, comm=None):
    """
    It's common to forget to initialize your variables to the same values, or
    (less commonly) if you update them in some other way than adam, to get them out of sync.
    This function checks that variables on all MPI workers are the same, and raises
    an AssertionError otherwise

    Arguments:
        comm: MPI communicator (defaults to MPI.COMM_WORLD when None)
        localval: list of local variables (list of variables on current worker to be compared with the other workers)
    """
    comm = comm or MPI.COMM_WORLD
    # Gather every worker's value on rank 0 and compare them there; non-root
    # ranks return without checking.
    vals = comm.gather(localval)
    if comm.rank == 0:
        assert all(val==vals[0] for val in vals[1:]),\
            'MpiAdamOptimizer detected that different workers have different weights: {}'.format(vals)
|
| 69 |
+
|
| 70 |
+
@with_mpi(timeout=5)
def test_nonfreeze():
    """Smoke test: MpiAdamOptimizer runs 100 steps on a toy loss without hanging."""
    np.random.seed(0)
    tf.compat.v1.set_random_seed(0)

    a = tf.Variable(np.random.randn(3).astype('float32'))
    b = tf.Variable(np.random.randn(2,5).astype('float32'))
    loss = tf.reduce_sum(input_tensor=tf.square(a)) + tf.reduce_sum(input_tensor=tf.sin(b))

    stepsize = 1e-2
    # for some reason the session config with inter_op_parallelism_threads was causing
    # nested sess.run calls to freeze
    config = tf.compat.v1.ConfigProto(inter_op_parallelism_threads=1)
    sess = U.get_session(config=config)
    update_op = MpiAdamOptimizer(comm=MPI.COMM_WORLD, learning_rate=stepsize).minimize(loss)
    sess.run(tf.compat.v1.global_variables_initializer())
    losslist_ref = []
    for i in range(100):
        l,_ = sess.run([loss, update_op])
        print(i, l)
        losslist_ref.append(l)
|
baselines/common/mpi_fork.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, subprocess, sys
|
| 2 |
+
|
| 3 |
+
def mpi_fork(n, bind_to_core=False):
    """Re-launches the current script with workers
    Returns "parent" for original parent, "child" for MPI children
    """
    # A single worker (or none) means there is nothing to fork.
    if n <= 1:
        return "child"
    # Already inside an mpirun-launched process: behave as a child.
    if os.getenv("IN_MPI") is not None:
        return "child"
    # Parent path: spawn n copies of this script under mpirun.
    child_env = os.environ.copy()
    child_env.update(
        MKL_NUM_THREADS="1",
        OMP_NUM_THREADS="1",
        IN_MPI="1"
    )
    cmd = ["mpirun", "-np", str(n)]
    if bind_to_core:
        cmd.extend(["-bind-to", "core"])
    cmd.extend([sys.executable] + sys.argv)
    subprocess.check_call(cmd, env=child_env)
    return "parent"
|
baselines/common/mpi_moments.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from mpi4py import MPI
|
| 2 |
+
import numpy as np
|
| 3 |
+
from baselines.common import zipsame
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def mpi_mean(x, axis=0, comm=None, keepdims=False):
    """Mean of x along `axis`, pooled over all MPI workers.

    Returns (mean, total_count) where total_count is the summed length of
    `axis` across all workers.
    """
    x = np.asarray(x)
    assert x.ndim > 0
    if comm is None: comm = MPI.COMM_WORLD
    xsum = x.sum(axis=axis, keepdims=keepdims)
    n = xsum.size
    # Pack the per-worker sum and element count into one buffer so a single
    # allreduce suffices.
    localsum = np.zeros(n+1, x.dtype)
    localsum[:n] = xsum.ravel()
    localsum[n] = x.shape[axis]
    # globalsum = np.zeros_like(localsum)
    # comm.Allreduce(localsum, globalsum, op=MPI.SUM)
    globalsum = comm.allreduce(localsum, op=MPI.SUM)
    return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n]
|
| 19 |
+
|
| 20 |
+
def mpi_moments(x, axis=0, comm=None, keepdims=False):
    """Mean, std, and total count of x along `axis`, pooled over MPI workers."""
    x = np.asarray(x)
    assert x.ndim > 0
    mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True)
    # Global variance = pooled mean of squared deviations from the global mean.
    sqdiffs = np.square(x - mean)
    meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True)
    assert count1 == count
    std = np.sqrt(meansqdiff)
    if not keepdims:
        # Drop the reduced axis to mirror numpy's keepdims=False shape.
        newshape = mean.shape[:axis] + mean.shape[axis+1:]
        mean = mean.reshape(newshape)
        std = std.reshape(newshape)
    return mean, std, count
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_runningmeanstd():
    """Launch the multi-process moments check under 3 MPI workers."""
    import subprocess
    helper = 'from baselines.common.mpi_moments import _helper_runningmeanstd; _helper_runningmeanstd()'
    cmd = ['mpirun', '-np', '3', 'python', '-c', helper]
    subprocess.check_call(cmd)
|
| 40 |
+
|
| 41 |
+
def _helper_runningmeanstd():
    """Worker body for test_runningmeanstd; expects to run under mpirun -np 3."""
    comm = MPI.COMM_WORLD
    np.random.seed(0)
    for (triple,axis) in [
        ((np.random.randn(3), np.random.randn(4), np.random.randn(5)),0),
        ((np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),0),
        ((np.random.randn(2,3), np.random.randn(2,4), np.random.randn(2,4)),1),
        ]:

        # Reference: moments of the full concatenation on a single process.
        x = np.concatenate(triple, axis=axis)
        ms1 = [x.mean(axis=axis), x.std(axis=axis), x.shape[axis]]

        # Each of the 3 workers holds one chunk; pooled moments must match.
        ms2 = mpi_moments(triple[comm.Get_rank()],axis=axis)

        for (a1,a2) in zipsame(ms1, ms2):
            print(a1, a2)
            assert np.allclose(a1, a2)
        print("ok!")
|
| 61 |
+
|
baselines/common/mpi_running_mean_std.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
from mpi4py import MPI
|
| 3 |
+
except ImportError:
|
| 4 |
+
MPI = None
|
| 5 |
+
|
| 6 |
+
import tensorflow as tf, baselines.common.tf_util as U, numpy as np
|
| 7 |
+
|
| 8 |
+
class RunningMeanStd(object):
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    # Maintains a running mean/std as TF variables holding the running sum,
    # sum of squares, and element count; update() pools batches over MPI.
    def __init__(self, epsilon=1e-2, shape=()):
        # epsilon seeds the count and sum-of-squares so the very first
        # mean/std reads are well-defined; shape is the per-element stats shape.
        self._sum = tf.compat.v1.get_variable(
            dtype=tf.float64,
            shape=shape,
            initializer=tf.compat.v1.constant_initializer(0.0),
            name="runningsum", trainable=False)
        self._sumsq = tf.compat.v1.get_variable(
            dtype=tf.float64,
            shape=shape,
            initializer=tf.compat.v1.constant_initializer(epsilon),
            name="runningsumsq", trainable=False)
        self._count = tf.compat.v1.get_variable(
            dtype=tf.float64,
            shape=(),
            initializer=tf.compat.v1.constant_initializer(epsilon),
            name="count", trainable=False)
        self.shape = shape

        self.mean = tf.cast(self._sum / self._count, dtype=tf.float32)
        # Variance is clamped at 1e-2 so std never collapses toward zero.
        self.std = tf.sqrt( tf.maximum( tf.cast(self._sumsq / self._count, dtype=tf.float32) - tf.square(self.mean) , 1e-2 ))

        newsum = tf.compat.v1.placeholder(shape=self.shape, dtype=tf.float64, name='sum')
        newsumsq = tf.compat.v1.placeholder(shape=self.shape, dtype=tf.float64, name='var')
        newcount = tf.compat.v1.placeholder(shape=[], dtype=tf.float64, name='count')
        # incfiltparams(sum, sumsq, count) adds the given increments to the variables.
        self.incfiltparams = U.function([newsum, newsumsq, newcount], [],
            updates=[tf.compat.v1.assign_add(self._sum, newsum),
                     tf.compat.v1.assign_add(self._sumsq, newsumsq),
                     tf.compat.v1.assign_add(self._count, newcount)])


    def update(self, x):
        """Fold a batch x (leading axis = batch) into the running statistics,
        summed across all MPI workers when MPI is available."""
        x = x.astype('float64')
        n = int(np.prod(self.shape))
        totalvec = np.zeros(n*2+1, 'float64')
        # Pack [sum, sum of squares, batch size] into one vector -> one allreduce.
        addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')])
        if MPI is not None:
            MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
        else:
            # BUGFIX: previously totalvec stayed all-zero when MPI was absent,
            # silently turning update() into a no-op. Use the local sums directly.
            totalvec = addvec
        self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n])
|
| 49 |
+
|
| 50 |
+
@U.in_session
def test_runningmeanstd():
    """Single-process check: RunningMeanStd matches numpy mean/std after 3 updates."""
    for (x1, x2, x3) in [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),
        ]:

        # epsilon=0 so the seeded count/sumsq do not bias the comparison.
        rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:])
        U.initialize()

        # Reference: stats of the full concatenation.
        x = np.concatenate([x1, x2, x3], axis=0)
        ms1 = [x.mean(axis=0), x.std(axis=0)]
        rms.update(x1)
        rms.update(x2)
        rms.update(x3)
        ms2 = [rms.mean.eval(), rms.std.eval()]

        assert np.allclose(ms1, ms2)
|
| 68 |
+
|
| 69 |
+
@U.in_session
def test_dist():
    """Two-worker check: pooled RunningMeanStd equals stats of all data combined.

    Must be launched with exactly 2 MPI workers (asserted below).
    """
    np.random.seed(0)
    p1,p2,p3=(np.random.randn(3,1), np.random.randn(4,1), np.random.randn(5,1))
    q1,q2,q3=(np.random.randn(6,1), np.random.randn(7,1), np.random.randn(8,1))

    # p1,p2,p3=(np.random.randn(3), np.random.randn(4), np.random.randn(5))
    # q1,q2,q3=(np.random.randn(6), np.random.randn(7), np.random.randn(8))

    comm = MPI.COMM_WORLD
    assert comm.Get_size()==2
    # Rank 0 feeds the p-chunks, rank 1 the q-chunks.
    if comm.Get_rank()==0:
        x1,x2,x3 = p1,p2,p3
    elif comm.Get_rank()==1:
        x1,x2,x3 = q1,q2,q3
    else:
        assert False

    rms = RunningMeanStd(epsilon=0.0, shape=(1,))
    U.initialize()

    rms.update(x1)
    rms.update(x2)
    rms.update(x3)

    # The allreduce inside update() means both workers accumulated all data,
    # so the stats must equal those of the full concatenation.
    bigvec = np.concatenate([p1,p2,p3,q1,q2,q3])

    def checkallclose(x,y):
        print(x,y)
        return np.allclose(x,y)

    assert checkallclose(
        bigvec.mean(axis=0),
        rms.mean.eval(),
    )
    assert checkallclose(
        bigvec.std(axis=0),
        rms.std.eval(),
    )
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
if __name__ == "__main__":
    # Run with mpirun -np 2 python <filename>
    # (test_dist asserts exactly 2 MPI workers).
    test_dist()
|
baselines/common/mpi_util.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import os, numpy as np
|
| 3 |
+
import platform
|
| 4 |
+
import shutil
|
| 5 |
+
import subprocess
|
| 6 |
+
import warnings
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from mpi4py import MPI
|
| 11 |
+
except ImportError:
|
| 12 |
+
MPI = None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def sync_from_root(sess, variables, comm=None):
    """
    Send the root node's parameters to every worker.
    Arguments:
        sess: the TensorFlow session.
        variables: all parameter variables including optimizer's
        comm: MPI communicator (defaults to MPI.COMM_WORLD when None)
    """
    if comm is None: comm = MPI.COMM_WORLD
    import tensorflow as tf
    # bcast distributes rank 0's values; every rank then assigns them locally.
    values = comm.bcast(sess.run(variables))
    sess.run([tf.compat.v1.assign(var, val)
                for (var, val) in zip(variables, values)])
|
| 27 |
+
|
| 28 |
+
def gpu_count():
    """
    Count the GPUs on this machine.
    """
    # No nvidia-smi on the PATH means no NVIDIA GPUs we can see.
    if shutil.which('nvidia-smi') is None:
        return 0
    listing = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv'])
    # The csv output has one header row plus a trailing newline, hence -2.
    n_gpus = len(listing.split(b'\n')) - 2
    return n_gpus if n_gpus > 0 else 0
|
| 36 |
+
|
| 37 |
+
def setup_mpi_gpus():
    """
    Set CUDA_VISIBLE_DEVICES to MPI rank if not already set
    """
    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
        if sys.platform == 'darwin': # This Assumes if you're on OSX you're just
            ids = []                 # doing a smoke test and don't want GPUs
        else:
            # One GPU per process: device index = this process's local rank.
            lrank, _lsize = get_local_rank_size(MPI.COMM_WORLD)
            ids = [lrank]
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, ids))
|
| 48 |
+
|
| 49 |
+
def get_local_rank_size(comm):
    """
    Returns the rank of each process on its machine
    The processes on a given machine will be assigned ranks
        0, 1, 2, ..., N-1,
    where N is the number of processes on this machine.

    Useful if you want to assign one gpu per machine

    Returns:
        (local_rank, local_size) for the calling process.
    """
    this_node = platform.node()
    # Every process announces (global_rank, hostname).
    ranks_nodes = comm.allgather((comm.Get_rank(), this_node))
    node2rankssofar = defaultdict(int)
    local_rank = None
    for (rank, node) in ranks_nodes:
        if rank == comm.Get_rank():
            # Our local rank = number of processes on this node seen before us.
            local_rank = node2rankssofar[node]
        node2rankssofar[node] += 1
    assert local_rank is not None
    return local_rank, node2rankssofar[this_node]
|
| 68 |
+
|
| 69 |
+
def share_file(comm, path):
    """
    Copies the file from rank 0 to all other ranks
    Puts it in the same place on all machines
    """
    localrank, _ = get_local_rank_size(comm)
    if comm.Get_rank() == 0:
        with open(path, 'rb') as fh:
            data = fh.read()
        comm.bcast(data)
    else:
        data = comm.bcast(None)
        # Only one process per machine (local rank 0) writes the file to disk.
        if localrank == 0:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, 'wb') as fh:
                fh.write(data)
    # Nobody proceeds until the file exists everywhere.
    comm.Barrier()
|
| 86 |
+
|
| 87 |
+
def dict_gather(comm, d, op='mean', assert_all_have_data=True):
    """
    Perform a reduction operation over dicts
    """
    # Without a communicator there is nothing to reduce over.
    if comm is None: return d
    gathered = comm.allgather(d)
    nworkers = comm.size
    # Group each key's values from every worker.
    key_to_values = defaultdict(list)
    for worker_dict in gathered:
        for (key, value) in worker_dict.items():
            key_to_values[key].append(value)
    reduced = {}
    for (key, values) in key_to_values.items():
        if assert_all_have_data:
            assert len(values)==nworkers, "only %i out of %i MPI workers have sent '%s'" % (len(values), nworkers, key)
        if op=='mean':
            reduced[key] = np.mean(values, axis=0)
        elif op=='sum':
            reduced[key] = np.sum(values, axis=0)
        else:
            assert 0, op
    return reduced
|
| 109 |
+
|
| 110 |
+
def mpi_weighted_mean(comm, local_name2valcount):
    """
    Perform a weighted average over dicts that are each on a different node
    Input: local_name2valcount: dict mapping key -> (value, count)
    Returns: key -> mean on rank 0; an empty dict on every other rank
    """
    all_name2valcount = comm.gather(local_name2valcount)
    if comm.rank == 0:
        name2sum = defaultdict(float)
        name2count = defaultdict(float)
        for n2vc in all_name2valcount:
            for (name, (val, count)) in n2vc.items():
                try:
                    val = float(val)
                except ValueError:
                    # Non-numeric values are skipped (the `else` below doesn't run).
                    if comm.rank == 0:  # NOTE(review): redundant -- already inside the rank-0 branch above
                        warnings.warn('WARNING: tried to compute mean on non-float {}={}'.format(name, val))
                else:
                    name2sum[name] += val * count
                    name2count[name] += count
        return {name : name2sum[name] / name2count[name] for name in name2sum}
    else:
        return {}
|
| 133 |
+
|
baselines/common/plot_util.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas
|
| 7 |
+
from collections import defaultdict, namedtuple
|
| 8 |
+
from baselines.bench import monitor
|
| 9 |
+
from baselines.logger import read_json, read_csv
|
| 10 |
+
|
| 11 |
+
def smooth(y, radius, mode='two_sided', valid_only=False):
    '''
    Smooth signal y, where radius determines the size of the window

    mode='two_sided':
        average over the window [max(index - radius, 0), min(index + radius, len(y)-1)]
    mode='causal':
        average over the window [max(index - radius, 0), index]

    valid_only: put nan in entries where the full-sized window is not available

    Returns an array with the same length as y.
    '''
    assert mode in ('two_sided', 'causal')
    if len(y) < 2*radius+1:
        # Too short for a full window: fall back to a constant (the mean).
        return np.ones_like(y) * y.mean()
    elif mode == 'two_sided':
        convkernel = np.ones(2 * radius+1)
        # Normalize by the actual number of in-window samples at each index.
        out = np.convolve(y, convkernel,mode='same') / np.convolve(np.ones_like(y), convkernel, mode='same')
        if valid_only:
            out[:radius] = out[-radius:] = np.nan
    elif mode == 'causal':
        convkernel = np.ones(radius)
        out = np.convolve(y, convkernel,mode='full') / np.convolve(np.ones_like(y), convkernel, mode='full')
        # 'full' convolution yields len(y) + radius - 1 entries; keep the first
        # len(y). BUGFIX: the previous slice `out[:-radius+1]` returned an
        # empty array when radius == 1 (it evaluated to out[:0]).
        out = out[:len(y)]
        if valid_only:
            out[:radius] = np.nan
    return out
|
| 38 |
+
|
| 39 |
+
def one_sided_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1., low_counts_threshold=1e-8):
    '''
    perform one-sided (causal) EMA (exponential moving average)
    smoothing and resampling to an even grid with n points.
    Does not do extrapolation, so we assume
    xolds[0] <= low && high <= xolds[-1]

    Arguments:

    xolds: array or list - x values of data. Needs to be sorted in ascending order
    yolds: array or list - y values of data. Has to have the same length as xolds

    low: float - min value of the new x grid. By default equals to xolds[0]
    high: float - max value of the new x grid. By default equals to xolds[-1]

    n: int - number of points in new x grid

    decay_steps: float - EMA decay factor, expressed in new x grid steps.

    low_counts_threshold: float or int
        - y values with counts less than this value will be set to NaN

    Returns:
        tuple (xs, ys, count_ys) where
            xs - array with new x grid
            ys - array of EMA of y at each point of the new x grid
            count_ys - array of EMA of y counts at each point of the new x grid

    '''

    low = xolds[0] if low is None else low
    high = xolds[-1] if high is None else high

    assert xolds[0] <= low, 'low = {} < xolds[0] = {} - extrapolation not permitted!'.format(low, xolds[0])
    assert xolds[-1] >= high, 'high = {} > xolds[-1] = {} - extrapolation not permitted!'.format(high, xolds[-1])
    assert len(xolds) == len(yolds), 'length of xolds ({}) and yolds ({}) do not match!'.format(len(xolds), len(yolds))


    xolds = xolds.astype('float64')
    yolds = yolds.astype('float64')

    luoi = 0 # last unused old index
    sum_y = 0.
    count_y = 0.
    xnews = np.linspace(low, high, n)
    # Decay expressed in x units: one grid step scaled by decay_steps.
    decay_period = (high - low) / (n - 1) * decay_steps
    interstep_decay = np.exp(- 1. / decay_steps)
    sum_ys = np.zeros_like(xnews)
    count_ys = np.zeros_like(xnews)
    for i in range(n):
        xnew = xnews[i]
        # Age the running weighted sum and weight by one grid step.
        sum_y *= interstep_decay
        count_y *= interstep_decay
        # Absorb every remaining old sample at or before the current grid
        # point, each down-weighted by how far behind xnew it lies.
        while True:
            if luoi >= len(xolds):
                break
            xold = xolds[luoi]
            if xold <= xnew:
                decay = np.exp(- (xnew - xold) / decay_period)
                sum_y += decay * yolds[luoi]
                count_y += decay
                luoi += 1
            else:
                break
        sum_ys[i] = sum_y
        count_ys[i] = count_y

    ys = sum_ys / count_ys
    # Mask grid points supported by (almost) no data.
    ys[count_ys < low_counts_threshold] = np.nan

    return xnews, ys, count_ys
|
| 110 |
+
|
| 111 |
+
def symmetric_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1., low_counts_threshold=1e-8):
    '''
    perform symmetric EMA (exponential moving average)
    smoothing and resampling to an even grid with n points.
    Does not do extrapolation, so we assume
    xolds[0] <= low && high <= xolds[-1]

    Arguments:

    xolds: array or list - x values of data. Needs to be sorted in ascending order
    yolds: array or list - y values of data. Has to have the same length as xolds

    low: float - min value of the new x grid. By default equals to xolds[0]
    high: float - max value of the new x grid. By default equals to xolds[-1]

    n: int - number of points in new x grid

    decay_steps: float - EMA decay factor, expressed in new x grid steps.

    low_counts_threshold: float or int
        - y values with counts less than this value will be set to NaN

    Returns:
        tuple (xs, ys, count_ys) where
            xs - array with new x grid
            ys - array of EMA of y at each point of the new x grid
            count_ys - array of EMA of y counts at each point of the new x grid

    '''
    # Run one causal EMA forward and one backward (on mirrored x), then
    # combine the two, weighting each by its accumulated counts.
    xs, ys1, count_ys1 = one_sided_ema(xolds, yolds, low, high, n, decay_steps, low_counts_threshold=0)
    _, ys2, count_ys2 = one_sided_ema(-xolds[::-1], yolds[::-1], -high, -low, n, decay_steps, low_counts_threshold=0)
    ys2 = ys2[::-1]
    count_ys2 = count_ys2[::-1]
    count_ys = count_ys1 + count_ys2
    ys = (ys1 * count_ys1 + ys2 * count_ys2) / count_ys
    ys[count_ys < low_counts_threshold] = np.nan
    return xs, ys, count_ys
|
| 148 |
+
|
| 149 |
+
# Lightweight record describing one loaded experiment directory; every field
# defaults to None so partially-populated results are valid.
Result = namedtuple('Result', 'monitor progress dirname metadata')
Result.__new__.__defaults__ = (None,) * len(Result._fields)
|
| 151 |
+
|
| 152 |
+
def load_results(root_dir_or_dirs, enable_progress=True, enable_monitor=True, verbose=False):
    '''
    load summaries of runs from a list of directories (including subdirectories)
    Arguments:

    root_dir_or_dirs: str or list of str - directory or list of directories to search for runs in

    enable_progress: bool - if True, will attempt to load data from progress.csv files (data saved by logger). Default: True

    enable_monitor: bool - if True, will attempt to load data from monitor.csv files (data saved by Monitor environment wrapper). Default: True

    verbose: bool - if True, will print out list of directories from which the data is loaded. Default: False


    Returns:
    List of Result objects with the following fields:
         - dirname - path to the directory data was loaded from
         - metadata - run metadata (such as command-line arguments and anything else in metadata.json file
         - monitor - if enable_monitor is True, this field contains pandas dataframe with loaded monitor.csv file (or aggregate of all *.monitor.csv files in the directory)
         - progress - if enable_progress is True, this field contains pandas dataframe with loaded progress.csv file
    '''
    import re
    # Accept either a single root directory or a list of them.
    if isinstance(root_dir_or_dirs, str):
        rootdirs = [osp.expanduser(root_dir_or_dirs)]
    else:
        rootdirs = [osp.expanduser(d) for d in root_dir_or_dirs]
    allresults = []
    for rootdir in rootdirs:
        assert osp.exists(rootdir), "%s doesn't exist"%rootdir
        for dirname, dirs, files in os.walk(rootdir):
            # '-proc' directories hold per-process data; clearing `files`
            # skips loading anything from them.
            if '-proc' in dirname:
                files[:] = []
                continue
            monitor_re = re.compile(r'(\d+\.)?(\d+\.)?monitor\.csv')
            if set(['metadata.json', 'monitor.json', 'progress.json', 'progress.csv']).intersection(files) or \
               any([f for f in files if monitor_re.match(f)]):  # also match monitor files like 0.1.monitor.csv
                # used to be uncommented, which means do not go deeper than current directory if any of the data files
                # are found
                # dirs[:] = []
                result = {'dirname' : dirname}
                if "metadata.json" in files:
                    with open(osp.join(dirname, "metadata.json"), "r") as fh:
                        result['metadata'] = json.load(fh)
                progjson = osp.join(dirname, "progress.json")
                progcsv = osp.join(dirname, "progress.csv")
                if enable_progress:
                    # progress.json takes precedence over progress.csv when both exist.
                    if osp.exists(progjson):
                        result['progress'] = pandas.DataFrame(read_json(progjson))
                    elif osp.exists(progcsv):
                        try:
                            result['progress'] = read_csv(progcsv)
                        except pandas.errors.EmptyDataError:
                            # an empty progress.csv is not fatal; just skip it
                            print('skipping progress file in ', dirname, 'empty data')
                    else:
                        if verbose: print('skipping %s: no progress file'%dirname)

                if enable_monitor:
                    try:
                        result['monitor'] = pandas.DataFrame(monitor.load_results(dirname))
                    except monitor.LoadMonitorResultsError:
                        print('skipping %s: no monitor files'%dirname)
                    except Exception as e:
                        # best-effort: a corrupt monitor file should not abort the whole scan
                        print('exception loading monitor file in %s: %s'%(dirname, e))

                # keep the run only if at least one data source actually loaded
                if result.get('monitor') is not None or result.get('progress') is not None:
                    allresults.append(Result(**result))
                    if verbose:
                        print('successfully loaded %s'%dirname)

    if verbose: print('loaded %i results'%len(allresults))
    return allresults
|
| 221 |
+
|
| 222 |
+
# Palette cycled through when assigning a distinct color to each curve group
# (indexed modulo its length in plot_results).
COLORS = [
    'blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple',
    'pink', 'brown', 'orange', 'teal', 'lightblue', 'lime', 'lavender',
    'turquoise', 'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue',
]
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def default_xy_fn(r):
    """Default curve extractor: cumulative episode lengths on x,
    episode rewards smoothed with radius 10 on y."""
    xs = np.cumsum(r.monitor.l)
    ys = smooth(r.monitor.r, radius=10)
    return xs, ys
|
| 231 |
+
|
| 232 |
+
def default_split_fn(r):
    """Default split key: the part of r.dirname between the last '/' and a
    trailing -<digits> run suffix (either of which may be absent).
    Returns None when nothing matches."""
    import re
    m = re.search(r'[^/-]+(?=(-\d+)?\Z)', r.dirname)
    return m.group(0) if m else None
|
| 239 |
+
|
| 240 |
+
def plot_results(
    allresults, *,
    xy_fn=default_xy_fn,
    split_fn=default_split_fn,
    group_fn=default_split_fn,
    average_group=False,
    shaded_std=True,
    shaded_err=True,
    figsize=None,
    legend_outside=False,
    resample=0,
    smooth_step=1.0,
    tiling='vertical',
    xlabel=None,
    ylabel=None
):
    '''
    Plot multiple Results objects

    xy_fn: function Result -> x,y - function that converts results objects into tuple of x and y values.
                By default, x is cumsum of episode lengths, and y is episode rewards

    split_fn: function Result -> hashable - function that converts results objects into keys to split curves into sub-panels by.
                That is, the results r for which split_fn(r) is different will be put on different sub-panels.
                By default, the portion of r.dirname between last / and -<digits> is returned. The sub-panels are
                stacked vertically in the figure.

    group_fn: function Result -> hashable - function that converts results objects into keys to group curves by.
                That is, the results r for which group_fn(r) is the same will be put into the same group.
                Curves in the same group have the same color (if average_group is False), or averaged over
                (if average_group is True). The default value is the same as default value for split_fn

    average_group: bool - if True, will average the curves in the same group and plot the mean. Enables resampling
                (if resample = 0, will use 512 steps)

    shaded_std: bool - if True (default), the shaded region corresponding to standard deviation of the group of curves will be
                shown (only applicable if average_group = True)

    shaded_err: bool - if True (default), the shaded region corresponding to error in mean estimate of the group of curves
                (that is, standard deviation divided by square root of number of curves) will be
                shown (only applicable if average_group = True)

    figsize: tuple or None - size of the resulting figure (including sub-panels). By default, width is 6 and height is 6 times number of
                sub-panels.

    legend_outside: bool - if True, will place the legend outside of the sub-panels.

    resample: int - if not zero, size of the uniform grid in x direction to resample onto. Resampling is performed via symmetric
                EMA smoothing (see the docstring for symmetric_ema).
                Default is zero (no resampling). Note that if average_group is True, resampling is necessary; in that case, default
                value is 512.

    smooth_step: float - when resampling (i.e. when resample > 0 or average_group is True), use this EMA decay parameter (in units of the new grid step).
                See docstrings for decay_steps in symmetric_ema or one_sided_ema functions.

    tiling: str - 'vertical' (default), 'horizontal' or 'symmetric' sub-panel layout

    xlabel, ylabel: str or None - axis labels (applied to bottom row / left column only)

    Returns (figure, 2-d array of axes).
    '''

    # Treat missing split/group functions as "everything in one bucket".
    if split_fn is None: split_fn = lambda _ : ''
    if group_fn is None: group_fn = lambda _ : ''
    sk2r = defaultdict(list) # splitkey2results
    for result in allresults:
        splitkey = split_fn(result)
        sk2r[splitkey].append(result)
    assert len(sk2r) > 0
    assert isinstance(resample, int), "0: don't resample. <integer>: that many samples"
    # Decide the sub-panel grid shape from the tiling mode.
    if tiling == 'vertical' or tiling is None:
        nrows = len(sk2r)
        ncols = 1
    elif tiling == 'horizontal':
        ncols = len(sk2r)
        nrows = 1
    elif tiling == 'symmetric':
        # closest-to-square grid: ncols is the largest divisor of N up to sqrt(N)
        import math
        N = len(sk2r)
        largest_divisor = 1
        for i in range(1, int(math.sqrt(N))+1):
            if N % i == 0:
                largest_divisor = i
        ncols = largest_divisor
        nrows = N // ncols
    figsize = figsize or (6 * ncols, 6 * nrows)

    f, axarr = plt.subplots(nrows, ncols, sharex=False, squeeze=False, figsize=figsize)

    groups = list(set(group_fn(result) for result in allresults))

    default_samples = 512
    if average_group:
        # averaging requires a common x grid, so force resampling on
        resample = resample or default_samples

    for (isplit, sk) in enumerate(sorted(sk2r.keys())):
        g2l = {}                  # group -> plotted line (for the legend)
        g2c = defaultdict(int)    # group -> number of curves in it
        sresults = sk2r[sk]
        gresults = defaultdict(list)
        idx_row = isplit // ncols
        idx_col = isplit % ncols
        ax = axarr[idx_row][idx_col]
        for result in sresults:
            group = group_fn(result)
            g2c[group] += 1
            x, y = xy_fn(result)
            if x is None: x = np.arange(len(y))
            x, y = map(np.asarray, (x, y))
            if average_group:
                # defer plotting until all curves of the group are collected
                gresults[group].append((x,y))
            else:
                if resample:
                    x, y, counts = symmetric_ema(x, y, x[0], x[-1], resample, decay_steps=smooth_step)
                l, = ax.plot(x, y, color=COLORS[groups.index(group) % len(COLORS)])
                g2l[group] = l
        if average_group:
            for group in sorted(groups):
                xys = gresults[group]
                if not any(xys):
                    continue
                color = COLORS[groups.index(group) % len(COLORS)]
                origxs = [xy[0] for xy in xys]
                minxlen = min(map(len, origxs))
                def allequal(qs):
                    return all((q==qs[0]).all() for q in qs[1:])
                if resample:
                    # resample every curve onto the overlap [low, high] of all x ranges
                    low  = max(x[0] for x in origxs)
                    high = min(x[-1] for x in origxs)
                    usex = np.linspace(low, high, resample)
                    ys = []
                    for (x, y) in xys:
                        ys.append(symmetric_ema(x, y, low, high, resample, decay_steps=smooth_step)[1])
                else:
                    # without resampling the curves must share an x grid (up to truncation)
                    assert allequal([x[:minxlen] for x in origxs]),\
                        'If you want to average unevenly sampled data, set resample=<number of samples you want>'
                    usex = origxs[0]
                    ys = [xy[1][:minxlen] for xy in xys]
                ymean = np.mean(ys, axis=0)
                ystd = np.std(ys, axis=0)
                ystderr = ystd / np.sqrt(len(ys))
                l, = axarr[idx_row][idx_col].plot(usex, ymean, color=color)
                g2l[group] = l
                if shaded_err:
                    ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.4)
                if shaded_std:
                    ax.fill_between(usex, ymean - ystd, ymean + ystd, color=color, alpha=.2)

        # https://matplotlib.org/users/legend_guide.html
        plt.tight_layout()
        if any(g2l.keys()):
            ax.legend(
                g2l.values(),
                ['%s (%i)'%(g, g2c[g]) for g in g2l] if average_group else g2l.keys(),
                loc=2 if legend_outside else None,
                bbox_to_anchor=(1,1) if legend_outside else None)
        ax.set_title(sk)
    # add xlabels, but only to the bottom row
    if xlabel is not None:
        for ax in axarr[-1]:
            plt.sca(ax)
            plt.xlabel(xlabel)
    # add ylabels, but only to left column
    if ylabel is not None:
        for ax in axarr[:,0]:
            plt.sca(ax)
            plt.ylabel(ylabel)

    return f, axarr
|
| 406 |
+
|
| 407 |
+
def regression_analysis(df):
    """Regress 'score' on every other column of df with OLS and print the fit summary.

    df: pandas DataFrame containing a 'score' column plus feature columns.
    """
    feature_cols = list(df.columns.copy())
    feature_cols.remove('score')
    import statsmodels.api as sm
    # hasconst=False because add_constant already injects the intercept column
    model = sm.OLS(df[['score']], sm.add_constant(df[feature_cols]), hasconst=False)
    fit = model.fit()
    print(fit.summary())
|
| 415 |
+
|
| 416 |
+
def test_smooth():
    """Visual sanity check for symmetric_ema: resample a noisy sine onto finer,
    coarser, and same-size grids and plot all variants together."""
    norig, nup, ndown = 100, 300, 30
    xs = np.cumsum(np.random.rand(norig) * 10 / norig)
    yclean = np.sin(xs)
    ys = yclean + .1 * np.random.randn(yclean.size)
    lo, hi = xs.min(), xs.max()
    variants = [
        ('up',   symmetric_ema(xs, ys, lo, hi, nup,   decay_steps=nup / ndown)),
        ('down', symmetric_ema(xs, ys, lo, hi, ndown, decay_steps=ndown / ndown)),
        ('same', symmetric_ema(xs, ys, lo, hi, norig, decay_steps=norig / ndown)),
    ]
    plt.plot(xs, ys, label='orig', marker='x')
    for name, (gridx, gridy, _) in variants:
        plt.plot(gridx, gridy, label=name, marker='x')
    plt.plot(xs, yclean, label='clean', marker='x')
    plt.legend()
    plt.show()
|
| 433 |
+
|
| 434 |
+
|
baselines/common/policies.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from baselines.common import tf_util
|
| 3 |
+
from baselines.a2c.utils import fc
|
| 4 |
+
from baselines.common.distributions import make_pdtype
|
| 5 |
+
from baselines.common.input import observation_placeholder, encode_observation
|
| 6 |
+
from baselines.common.tf_util import adjust_shape
|
| 7 |
+
from baselines.common.mpi_running_mean_std import RunningMeanStd
|
| 8 |
+
from baselines.common.models import get_network_builder
|
| 9 |
+
|
| 10 |
+
import gym
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PolicyWithValue(object):
    """
    Encapsulates fields and methods for RL policy and value function estimation with shared parameters
    """

    def __init__(self, env, observations, latent, estimate_q=False, vf_latent=None, sess=None, **tensors):
        """
        Parameters:
        ----------
        env             RL environment

        observations    tensorflow placeholder in which the observations will be fed

        latent          latent state from which policy distribution parameters should be inferred

        estimate_q      if True, estimate a Q value per discrete action instead of a scalar V

        vf_latent       latent state from which value function should be inferred (if None, then latent is used)

        sess            tensorflow session to run calculations in (if None, default session is used)

        **tensors       tensorflow tensors for additional attributes such as state or mask
        """

        self.X = observations
        # empty-tensor sentinel: step() converts a size-0 state back to None
        self.state = tf.constant([])
        self.initial_state = None
        # expose extra tensors (e.g. recurrent state, mask) as attributes so
        # _evaluate can feed them by name
        self.__dict__.update(tensors)

        vf_latent = vf_latent if vf_latent is not None else latent

        vf_latent = tf.compat.v1.layers.flatten(vf_latent)
        latent = tf.compat.v1.layers.flatten(latent)

        # Based on the action space, will select what probability distribution type
        self.pdtype = make_pdtype(env.action_space)

        self.pd, self.pi = self.pdtype.pdfromlatent(latent, init_scale=0.01)

        # Take an action
        self.action = self.pd.sample()

        # Calculate the neg log of our probability
        self.neglogp = self.pd.neglogp(self.action)
        self.sess = sess or tf.compat.v1.get_default_session()

        if estimate_q:
            # Q-learning style head: one output per discrete action
            assert isinstance(env.action_space, gym.spaces.Discrete)
            self.q = fc(vf_latent, 'q', env.action_space.n)
            self.vf = self.q
        else:
            self.vf = fc(vf_latent, 'vf', 1)
            self.vf = self.vf[:,0]  # drop the trailing singleton dim -> shape (batch,)

    def _evaluate(self, variables, observation, **extra_feed):
        """Run `variables` in the session, feeding observation plus any named
        placeholder tensors found in extra_feed."""
        sess = self.sess
        feed_dict = {self.X: adjust_shape(self.X, observation)}
        for inpt_name, data in extra_feed.items():
            if inpt_name in self.__dict__.keys():
                inpt = self.__dict__[inpt_name]
                # only feed actual placeholders; ignore derived tensors
                if isinstance(inpt, tf.Tensor) and inpt._op.type == 'Placeholder':
                    feed_dict[inpt] = adjust_shape(inpt, data)

        return sess.run(variables, feed_dict)

    def step(self, observation, **extra_feed):
        """
        Compute next action(s) given the observation(s)

        Parameters:
        ----------

        observation     observation data (either single or a batch)

        **extra_feed    additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__)

        Returns:
        -------
        (action, value estimate, next state, negative log likelihood of the action under current policy parameters) tuple
        """

        a, v, state, neglogp = self._evaluate([self.action, self.vf, self.state, self.neglogp], observation, **extra_feed)
        if state.size == 0:
            # non-recurrent policy: report no state instead of the empty sentinel
            state = None
        return a, v, state, neglogp

    def value(self, ob, *args, **kwargs):
        """
        Compute value estimate(s) given the observation(s)

        Parameters:
        ----------

        ob              observation data (either single or a batch)

        **kwargs        additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__)

        Returns:
        -------
        value estimate
        """
        return self._evaluate(self.vf, ob, *args, **kwargs)

    def save(self, save_path):
        """Save the session's variables to save_path."""
        tf_util.save_state(save_path, sess=self.sess)

    def load(self, load_path):
        """Restore the session's variables from load_path."""
        tf_util.load_state(load_path, sess=self.sess)
|
| 120 |
+
|
| 121 |
+
def build_policy(env, policy_network, value_network=None, normalize_observations=False, estimate_q=False, **policy_kwargs):
    """Return a policy constructor function for the given environment.

    env                     RL environment (provides observation_space / action_space)
    policy_network          network builder callable, or the string name of a registered network
    value_network           None or 'shared' (share policy latent), 'copy' (duplicate the
                            policy network), or a callable building a separate value net
    normalize_observations  if True and observations are float32, normalize with a running mean/std
    estimate_q              passed through to PolicyWithValue (Q head instead of V head)
    **policy_kwargs         forwarded to the registered network builder when policy_network is a str

    Returns a function policy_fn(nbatch, nsteps, sess, observ_placeholder) -> PolicyWithValue.
    """
    if isinstance(policy_network, str):
        network_type = policy_network
        policy_network = get_network_builder(network_type)(**policy_kwargs)

    def policy_fn(nbatch=None, nsteps=None, sess=None, observ_placeholder=None):
        ob_space = env.observation_space

        X = observ_placeholder if observ_placeholder is not None else observation_placeholder(ob_space, batch_size=nbatch)

        extra_tensors = {}

        if normalize_observations and X.dtype == tf.float32:
            encoded_x, rms = _normalize_clip_observation(X)
            extra_tensors['rms'] = rms
        else:
            encoded_x = X

        encoded_x = encode_observation(ob_space, encoded_x)

        with tf.compat.v1.variable_scope('pi', reuse=tf.compat.v1.AUTO_REUSE):
            policy_latent = policy_network(encoded_x)
            # recurrent network builders return (latent, recurrent_tensors)
            if isinstance(policy_latent, tuple):
                policy_latent, recurrent_tensors = policy_latent

                if recurrent_tensors is not None:
                    # recurrent architecture, need a few more steps
                    nenv = nbatch // nsteps
                    assert nenv > 0, 'Bad input for recurrent policy: batch size {} smaller than nsteps {}'.format(nbatch, nsteps)
                    # rebuild with the number of parallel envs so state is shaped per-env
                    policy_latent, recurrent_tensors = policy_network(encoded_x, nenv)
                    extra_tensors.update(recurrent_tensors)


        _v_net = value_network

        if _v_net is None or _v_net == 'shared':
            # value function reads from the same latent as the policy
            vf_latent = policy_latent
        else:
            if _v_net == 'copy':
                _v_net = policy_network
            else:
                assert callable(_v_net)

            with tf.compat.v1.variable_scope('vf', reuse=tf.compat.v1.AUTO_REUSE):
                # TODO recurrent architectures are not supported with value_network=copy yet
                vf_latent = _v_net(encoded_x)

        policy = PolicyWithValue(
            env=env,
            observations=X,
            latent=policy_latent,
            vf_latent=vf_latent,
            sess=sess,
            estimate_q=estimate_q,
            **extra_tensors
        )
        return policy

    return policy_fn
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _normalize_clip_observation(x, clip_range=(-5.0, 5.0)):
    """Normalize observations with a running mean/std and clip to clip_range.

    x           observation tensor; statistics are tracked over all but the batch dim
    clip_range  pair of (low, high) clipping bounds
                (default changed from a mutable list to a tuple: mutable default
                arguments are shared across calls and a latent-bug hazard)

    Returns (normalized_and_clipped_tensor, RunningMeanStd instance).
    """
    rms = RunningMeanStd(shape=x.shape[1:])
    norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
    return norm_x, rms
|
| 186 |
+
|
baselines/common/retro_wrappers.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import deque
|
| 2 |
+
import cv2
|
| 3 |
+
cv2.ocl.setUseOpenCL(False)
|
| 4 |
+
from .atari_wrappers import WarpFrame, ClipRewardEnv, FrameStack, ScaledFloatFrame
|
| 5 |
+
from .wrappers import TimeLimit
|
| 6 |
+
import numpy as np
|
| 7 |
+
import gym
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class StochasticFrameSkip(gym.Wrapper):
    """Frame-skip wrapper with 'sticky' actions.

    Each agent action is applied for n consecutive env steps; on the first
    substep, with probability stickprob the previously applied action is kept
    instead (the new action then kicks in from the second substep at the
    latest). Rewards over the skipped substeps are summed.
    """
    def __init__(self, env, n, stickprob):
        gym.Wrapper.__init__(self, env)
        self.n = n                   # env steps per agent action
        self.stickprob = stickprob   # probability the old action sticks for one substep
        self.curac = None            # currently applied action; None right after reset
        self.rng = np.random.RandomState()
        # gym-retro envs can skip rendering on intermediate substeps
        self.supports_want_render = hasattr(env, "supports_want_render")

    def reset(self, **kwargs):
        self.curac = None
        return self.env.reset(**kwargs)

    def step(self, ac):
        done = False
        totrew = 0
        for i in range(self.n):
            # First step after reset, use action
            if self.curac is None:
                self.curac = ac
            # First substep, delay with probability=stickprob
            elif i==0:
                if self.rng.rand() > self.stickprob:
                    self.curac = ac
            # Second substep, new action definitely kicks in
            elif i==1:
                self.curac = ac
            if self.supports_want_render and i<self.n-1:
                # skip rendering for substeps whose frame the agent never sees
                ob, rew, done, info = self.env.step(self.curac, want_render=False)
            else:
                ob, rew, done, info = self.env.step(self.curac)
            totrew += rew
            if done: break
        return ob, totrew, done, info

    def seed(self, s):
        """Seed the sticky-action RNG (independent of the wrapped env's seed)."""
        self.rng.seed(s)
|
| 47 |
+
|
| 48 |
+
class PartialFrameStack(gym.Wrapper):
    """Stack a single channel from the previous k frames onto the current frame.

    The observation keeps all channels of the newest frame and appends one
    channel (the `channel` index) from each of the k-1 older frames, giving
    shape (H, W, C + k - 1).
    """
    def __init__(self, env, k, channel=1):
        """
        Stack one channel (channel keyword) from previous frames
        """
        gym.Wrapper.__init__(self, env)
        shp = env.observation_space.shape
        self.channel = channel
        self.observation_space = gym.spaces.Box(low=0, high=255,
                                                shape=(shp[0], shp[1], shp[2] + k - 1),
                                                dtype=env.observation_space.dtype)
        self.k = k
        self.frames = deque([], maxlen=k)
        # NOTE: removed a dead duplicate `shp = env.observation_space.shape`
        # that previously sat here (the value was never used again)

    def reset(self):
        ob = self.env.reset()
        assert ob.shape[2] > self.channel
        # prime the deque with k copies of the first frame
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, ac):
        ob, reward, done, info = self.env.step(ac)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        # newest frame keeps all channels; older frames contribute one channel each
        return np.concatenate([frame if i==self.k-1 else frame[:,:,self.channel:self.channel+1]
                               for (i, frame) in enumerate(self.frames)], axis=2)
|
| 79 |
+
|
| 80 |
+
class Downsample(gym.ObservationWrapper):
    """Downsample image observations by an integer factor using area interpolation."""
    def __init__(self, env, ratio):
        """
        Downsample images by a factor of ratio
        """
        gym.ObservationWrapper.__init__(self, env)
        (oldh, oldw, oldc) = env.observation_space.shape
        newshape = (oldh//ratio, oldw//ratio, oldc)
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=newshape, dtype=np.uint8)

    def observation(self, frame):
        height, width, _ = self.observation_space.shape
        # cv2.resize takes (width, height), i.e. reversed relative to numpy shape
        frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_AREA)
        if frame.ndim == 2:
            # resize drops a singleton channel dimension; restore it
            frame = frame[:,:,None]
        return frame
|
| 97 |
+
|
| 98 |
+
class Rgb2gray(gym.ObservationWrapper):
    """Convert RGB image observations to single-channel grayscale."""
    def __init__(self, env):
        """
        Convert RGB observations to grayscale, shape (H, W, 1).
        (Docstring previously said "Downsample images by a factor of ratio" --
        a copy-paste from the Downsample wrapper.)
        """
        gym.ObservationWrapper.__init__(self, env)
        (oldh, oldw, _oldc) = env.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0, high=255,
            shape=(oldh, oldw, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # keep a trailing channel dimension so downstream stacking still works
        return frame[:,:,None]
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class MovieRecord(gym.Wrapper):
    """Record a movie of every k-th episode by pointing the underlying env at savedir.

    NOTE(review): assumes the unwrapped env exposes `movie_path` / `movie`
    attributes (gym-retro does) -- confirm before using with other envs.
    """
    def __init__(self, env, savedir, k):
        gym.Wrapper.__init__(self, env)
        self.savedir = savedir
        self.k = k           # record one episode out of every k
        self.epcount = 0     # episodes started so far
    def reset(self):
        # enable recording only on every k-th episode
        if self.epcount % self.k == 0:
            self.env.unwrapped.movie_path = self.savedir
        else:
            self.env.unwrapped.movie_path = None
        self.env.unwrapped.movie = None
        self.epcount += 1
        return self.env.reset()
|
| 127 |
+
|
| 128 |
+
class AppendTimeout(gym.Wrapper):
    """Augment observations with 'value_estimation_timeout': the fraction of the
    episode's time budget remaining, so value functions can account for the
    TimeLimit truncation.

    For Dict observation spaces the key is added in place; otherwise the
    observation becomes {'original': ob, 'value_estimation_timeout': frac}.
    Requires a TimeLimit wrapper (providing _max_episode_steps) somewhere
    below in the wrapper stack.
    """
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.action_space = env.action_space
        self.timeout_space = gym.spaces.Box(low=np.array([0.0]), high=np.array([1.0]), dtype=np.float32)
        self.original_os = env.observation_space
        if isinstance(self.original_os, gym.spaces.Dict):
            import copy
            ordered_dict = copy.deepcopy(self.original_os.spaces)
            ordered_dict['value_estimation_timeout'] = self.timeout_space
            self.observation_space = gym.spaces.Dict(ordered_dict)
            self.dict_mode = True
        else:
            self.observation_space = gym.spaces.Dict({
                'original': self.original_os,
                'value_estimation_timeout': self.timeout_space
            })
            self.dict_mode = False
        self.ac_count = None
        # walk down the wrapper stack looking for a TimeLimit wrapper
        while 1:
            if not hasattr(env, "_max_episode_steps"):  # Looking for TimeLimit wrapper that has this field
                env = env.env
                continue
            break
        self.timeout = env._max_episode_steps

    def step(self, ac):
        self.ac_count += 1
        ob, rew, done, info = self.env.step(ac)
        return self._process(ob), rew, done, info

    def reset(self):
        self.ac_count = 0
        return self._process(self.env.reset())

    def _process(self, ob):
        # fraction of the time budget still remaining, in [0, 1]
        fracmissing = 1 - self.ac_count / self.timeout
        if self.dict_mode:
            ob['value_estimation_timeout'] = fracmissing
            # BUG FIX: previously this branch fell through and returned None,
            # so dict-mode envs lost the observation entirely
            return ob
        else:
            return { 'original': ob, 'value_estimation_timeout': fracmissing }
|
| 169 |
+
|
| 170 |
+
class StartDoingRandomActionsWrapper(gym.Wrapper):
    """
    Warning: can eat info dicts, not good if you depend on them

    Runs up to max_random_steps random actions after env reset (on startup
    and/or at every episode boundary) and caches the resulting observation.
    Note that reset() returns the cached observation without resetting the
    underlying env -- resets happen implicitly inside step()/some_random_steps().
    """
    def __init__(self, env, max_random_steps, on_startup=True, every_episode=False):
        gym.Wrapper.__init__(self, env)
        self.on_startup = on_startup
        self.every_episode = every_episode
        self.random_steps = max_random_steps   # upper bound on random warm-up steps
        self.last_obs = None                   # most recent observation, served by reset()
        if on_startup:
            self.some_random_steps()

    def some_random_steps(self):
        """Reset the env and take a random number (< random_steps) of random actions."""
        self.last_obs = self.env.reset()
        n = np.random.randint(self.random_steps)
        #print("running for random %i frames" % n)
        for _ in range(n):
            self.last_obs, _, done, _ = self.env.step(self.env.action_space.sample())
            if done: self.last_obs = self.env.reset()

    def reset(self):
        # intentionally does NOT reset the env; returns the cached observation
        return self.last_obs

    def step(self, a):
        self.last_obs, rew, done, info = self.env.step(a)
        if done:
            self.last_obs = self.env.reset()
            if self.every_episode:
                self.some_random_steps()
        return self.last_obs, rew, done, info
|
| 201 |
+
|
| 202 |
+
def make_retro(*, game, state=None, max_episode_steps=4500, **kwargs):
    """Create a gym-retro env with sticky-action frame skip and an optional time limit.

    game               retro game id
    state              retro save state (defaults to retro.State.DEFAULT)
    max_episode_steps  episode cap, or None to disable the TimeLimit wrapper
    **kwargs           forwarded to retro.make
    """
    import retro
    if state is None:
        state = retro.State.DEFAULT
    env = StochasticFrameSkip(retro.make(game, state, **kwargs), n=4, stickprob=0.25)
    if max_episode_steps is None:
        return env
    return TimeLimit(env, max_episode_steps=max_episode_steps)
|
| 211 |
+
|
| 212 |
+
def wrap_deepmind_retro(env, scale=True, frame_stack=4):
    """
    Configure environment for retro games, using config similar to DeepMind-style Atari in wrap_deepmind
    """
    wrapped = ClipRewardEnv(WarpFrame(env))
    if frame_stack > 1:
        wrapped = FrameStack(wrapped, frame_stack)
    if scale:
        wrapped = ScaledFloatFrame(wrapped)
    return wrapped
|
| 223 |
+
|
| 224 |
+
class SonicDiscretizer(gym.ActionWrapper):
    """
    Wrap a gym-retro environment and make it use discrete
    actions for the Sonic game.

    Maps a small Discrete action space onto the 12-button retro gamepad.
    """
    def __init__(self, env):
        super(SonicDiscretizer, self).__init__(env)
        buttons = ["B", "A", "MODE", "START", "UP", "DOWN", "LEFT", "RIGHT", "C", "Y", "X", "Z"]
        combos = [['LEFT'], ['RIGHT'], ['LEFT', 'DOWN'], ['RIGHT', 'DOWN'], ['DOWN'],
                  ['DOWN', 'B'], ['B']]
        self._actions = []
        for combo in combos:
            # boolean mask over the 12 buttons: True where the combo presses a button
            self._actions.append(np.array([button in combo for button in buttons]))
        self.action_space = gym.spaces.Discrete(len(self._actions))

    def action(self, a): # pylint: disable=W0221
        # copy so callers cannot mutate the stored button-mask templates
        return self._actions[a].copy()
|
| 244 |
+
|
| 245 |
+
class RewardScaler(gym.RewardWrapper):
    """
    Bring rewards to a reasonable scale for PPO.
    This is incredibly important and effects performance
    drastically.
    """

    def __init__(self, env, scale=0.01):
        super(RewardScaler, self).__init__(env)
        # Multiplicative factor applied to every reward.
        self.scale = scale

    def reward(self, reward):
        return self.scale * reward
|
| 257 |
+
|
| 258 |
+
class AllowBacktracking(gym.Wrapper):
    """
    Use deltas in max(X) as the reward, rather than deltas
    in X. This way, agents are not discouraged too heavily
    from exploring backwards if there is no way to advance
    head-on in the level.
    """

    def __init__(self, env):
        super(AllowBacktracking, self).__init__(env)
        self._cur_x = 0   # running sum of raw rewards (proxy for x position)
        self._max_x = 0   # best position reached so far this episode

    def reset(self, **kwargs):  # pylint: disable=E0202
        self._cur_x = 0
        self._max_x = 0
        return self.env.reset(**kwargs)

    def step(self, action):  # pylint: disable=E0202
        obs, raw_rew, done, info = self.env.step(action)
        self._cur_x += raw_rew
        # Reward only new progress beyond the previous maximum; never negative.
        shaped_rew = max(0, self._cur_x - self._max_x)
        self._max_x = max(self._max_x, self._cur_x)
        return obs, shaped_rew, done, info
|
baselines/common/runners.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
|
| 4 |
+
class AbstractEnvRunner(ABC):
    """Base class for runners that collect ``nsteps`` of experience from a
    (possibly vectorized) environment using a model.

    Subclasses implement :meth:`run` to produce a batch of rollout data.
    """

    def __init__(self, *, env, model, nsteps):
        self.env = env
        self.model = model
        # Vectorized envs expose num_envs; fall back to a single env.
        self.nenv = nenv = getattr(env, 'num_envs', 1)
        obs_shape = env.observation_space.shape
        self.batch_ob_shape = (nenv * nsteps,) + obs_shape
        self.obs = np.zeros((nenv,) + obs_shape, dtype=env.observation_space.dtype.name)
        self.obs[:] = env.reset()
        self.nsteps = nsteps
        # Recurrent policy state, if any (None for feed-forward models).
        self.states = model.initial_state
        self.dones = [False] * nenv

    @abstractmethod
    def run(self):
        """Collect and return one batch of experience; defined by subclasses."""
        raise NotImplementedError
|
| 19 |
+
|
baselines/common/running_mean_std.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import numpy as np
|
| 3 |
+
from baselines.common.tf_util import get_session
|
| 4 |
+
|
| 5 |
+
class RunningMeanStd(object):
    """Track the running mean and variance of a data stream.

    Uses the parallel (Chan et al.) variance-combination algorithm:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    The ``epsilon`` pseudo-count avoids division by zero before any update.
    """

    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        """Fold a batch ``x`` (leading axis = batch) into the running moments."""
        self.update_from_moments(np.mean(x, axis=0), np.var(x, axis=0), x.shape[0])

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge externally computed batch moments into the running moments."""
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
|
| 21 |
+
|
| 22 |
+
def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    """Combine running moments with batch moments (parallel-variance merge).

    Returns the merged ``(mean, var, count)`` as if the batch had been part
    of the original stream. See
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    """
    delta = batch_mean - mean
    total = count + batch_count

    merged_mean = mean + delta * batch_count / total
    # Sum of squared deviations of each part plus a cross term for the
    # difference between the two means.
    combined_m2 = (var * count
                   + batch_var * batch_count
                   + np.square(delta) * count * batch_count / total)
    return merged_mean, combined_m2 / total, total
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TfRunningMeanStd(object):
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    '''
    TensorFlow variables-based implementation of computing running mean and std.
    Benefit of this implementation is that it can be saved / loaded together with the tensorflow model.
    The actual moment math happens in numpy (update_mean_var_count_from_moments); TF only stores the state.
    '''
    def __init__(self, epsilon=1e-4, shape=(), scope=''):
        # Uses the process-wide session from baselines.common.tf_util.
        sess = get_session()

        # Placeholders used to push newly merged moments into the TF variables.
        self._new_mean = tf.compat.v1.placeholder(shape=shape, dtype=tf.float64)
        self._new_var = tf.compat.v1.placeholder(shape=shape, dtype=tf.float64)
        self._new_count = tf.compat.v1.placeholder(shape=(), dtype=tf.float64)

        # AUTO_REUSE lets two instances with the same scope share variables.
        with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
            self._mean = tf.compat.v1.get_variable('mean', initializer=np.zeros(shape, 'float64'), dtype=tf.float64)
            # NOTE(review): this variable is named 'std' but stores the VARIANCE — confirm before renaming.
            self._var = tf.compat.v1.get_variable('std', initializer=np.ones(shape, 'float64'), dtype=tf.float64)
            self._count = tf.compat.v1.get_variable('count', initializer=np.full((), epsilon, 'float64'), dtype=tf.float64)

        # Single op that writes all three pieces of state atomically (from TF's view).
        self.update_ops = tf.group([
            self._var.assign(self._new_var),
            self._mean.assign(self._new_mean),
            self._count.assign(self._new_count)
        ])

        # Initialize only this object's variables, not the whole graph.
        sess.run(tf.compat.v1.variables_initializer([self._mean, self._var, self._count]))
        self.sess = sess
        self._set_mean_var_count()

    def _set_mean_var_count(self):
        # Mirror the TF variables into plain numpy attributes for cheap reads.
        self.mean, self.var, self.count = self.sess.run([self._mean, self._var, self._count])

    def update(self, x):
        """Fold a batch ``x`` (leading axis = batch) into the stored running moments."""
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]

        # Merge in numpy using the cached mirror, then write the result back to TF.
        new_mean, new_var, new_count = update_mean_var_count_from_moments(self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

        self.sess.run(self.update_ops, feed_dict={
            self._new_mean: new_mean,
            self._new_var: new_var,
            self._new_count: new_count
        })

        # Refresh the numpy mirror so reads reflect the update.
        self._set_mean_var_count()
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def test_runningmeanstd():
    """RunningMeanStd over sequential chunks must match numpy on the concatenation."""
    cases = [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
    ]
    for chunks in cases:
        rms = RunningMeanStd(epsilon=0.0, shape=chunks[0].shape[1:])

        full = np.concatenate(chunks, axis=0)
        expected = [full.mean(axis=0), full.var(axis=0)]
        for chunk in chunks:
            rms.update(chunk)

        np.testing.assert_allclose(expected, [rms.mean, rms.var])
|
| 101 |
+
|
| 102 |
+
def test_tf_runningmeanstd():
    """TfRunningMeanStd over sequential chunks must match numpy on the concatenation."""
    cases = [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
    ]
    for chunks in cases:
        # Random scope suffix so repeated runs do not collide on variable names.
        rms = TfRunningMeanStd(epsilon=0.0, shape=chunks[0].shape[1:],
                               scope='running_mean_std' + str(np.random.randint(0, 128)))

        full = np.concatenate(chunks, axis=0)
        expected = [full.mean(axis=0), full.var(axis=0)]
        for chunk in chunks:
            rms.update(chunk)

        np.testing.assert_allclose(expected, [rms.mean, rms.var])
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def profile_tf_runningmeanstd():
    """Compare wall-clock cost of the numpy and TF running-mean-std implementations."""
    import time
    from baselines.common import tf_util

    # Force single-threaded TF so timings are comparable run-to-run.
    tf_util.get_session( config=tf.compat.v1.ConfigProto(
        inter_op_parallelism_threads=1,
        intra_op_parallelism_threads=1,
        allow_soft_placement=True
    ))

    x = np.random.random((376,))

    n_trials = 10000
    rms = RunningMeanStd()
    tfrms = TfRunningMeanStd()

    # --- time update() on both implementations ---
    tic1 = time.time()
    for _ in range(n_trials):
        rms.update(x)

    tic2 = time.time()
    for _ in range(n_trials):
        tfrms.update(x)

    tic3 = time.time()

    print('rms update time ({} trials): {} s'.format(n_trials, tic2 - tic1))
    print('tfrms update time ({} trials): {} s'.format(n_trials, tic3 - tic2))

    # --- time reading the mean back (attribute read vs cached TF mirror) ---
    tic1 = time.time()
    for _ in range(n_trials):
        z1 = rms.mean

    tic2 = time.time()
    for _ in range(n_trials):
        z2 = tfrms.mean

    # Both saw identical updates via the same merge function, so the scalar
    # means should agree exactly (default shape=() gives a scalar here).
    assert z1 == z2

    tic3 = time.time()

    print('rms get mean time ({} trials): {} s'.format(n_trials, tic2 - tic1))
    print('tfrms get mean time ({} trials): {} s'.format(n_trials, tic3 - tic2))


    # Commented-out chrome-trace profiling snippet, kept for manual use.
    '''
    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) #pylint: disable=E1101
    run_metadata = tf.RunMetadata()
    profile_opts = dict(options=options, run_metadata=run_metadata)



    from tensorflow.python.client import timeline
    fetched_timeline = timeline.Timeline(run_metadata.step_stats) #pylint: disable=E1101
    chrome_trace = fetched_timeline.generate_chrome_trace_format()
    outfile = '/tmp/timeline.json'
    with open(outfile, 'wt') as f:
        f.write(chrome_trace)
    print('Successfully saved profile to {}. Exiting.'.format(outfile))
    exit(0)
    '''
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# Running this module directly executes the timing comparison, not the tests.
if __name__ == '__main__':
    profile_tf_runningmeanstd()
|
baselines/common/schedules.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""This file is used for specifying various schedules that evolve over
|
| 2 |
+
time throughout the execution of the algorithm, such as:
|
| 3 |
+
- learning rate for the optimizer
|
| 4 |
+
- exploration epsilon for the epsilon greedy exploration strategy
|
| 5 |
+
- beta parameter for beta parameter in prioritized replay
|
| 6 |
+
|
| 7 |
+
Each schedule has a function `value(t)` which returns the current value
|
| 8 |
+
of the parameter given the timestep t of the optimization procedure.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Schedule(object):
    """Interface for time-varying hyperparameter schedules."""

    def value(self, t):
        """Return the schedule's value at timestep ``t``."""
        raise NotImplementedError()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ConstantSchedule(object):
    """Schedule whose value never changes.

    Parameters
    ----------
    value: float
        Constant value of the schedule
    """

    def __init__(self, value):
        self._v = value

    def value(self, t):
        """See Schedule.value"""
        return self._v
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def linear_interpolation(l, r, alpha):
    """Linearly interpolate from ``l`` (alpha=0) to ``r`` (alpha=1)."""
    return l + (r - l) * alpha
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class PiecewiseSchedule(object):
    """Schedule defined by interpolating between sorted ``(time, value)`` endpoints.

    endpoints: [(int, int)]
        list of pairs `(time, value)` meaning that the schedule outputs
        `value` when `t == time`. Times must be sorted increasingly. For t
        between `(time_a, value_a)` and `(time_b, value_b)` with
        `time_a <= t < time_b`, the output is
        `interpolation(value_a, value_b, alpha)` where alpha is the
        fraction of the way from time_a to time_b.
    interpolation: lambda float, float, float: float
        function of the left value, right value, and alpha; see
        linear_interpolation for an example.
    outside_value: float
        returned when t falls outside every interval; if None an
        AssertionError is raised instead.
    """

    def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None):
        times = [pt[0] for pt in endpoints]
        assert times == sorted(times)
        self._interpolation = interpolation
        self._outside_value = outside_value
        self._endpoints = endpoints

    def value(self, t):
        """See Schedule.value"""
        # zip truncates at the shorter argument, so this pairs consecutive endpoints.
        for (left_t, left_v), (right_t, right_v) in zip(self._endpoints, self._endpoints[1:]):
            if left_t <= t < right_t:
                alpha = float(t - left_t) / (right_t - left_t)
                return self._interpolation(left_v, right_v, alpha)

        # t is outside all pieces; only legal if a fallback was provided.
        assert self._outside_value is not None
        return self._outside_value
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class LinearSchedule(object):
    """Anneal linearly from ``initial_p`` to ``final_p`` over
    ``schedule_timesteps``; afterwards ``final_p`` is returned.

    Parameters
    ----------
    schedule_timesteps: int
        Number of timesteps for which to linearly anneal initial_p
        to final_p
    initial_p: float
        initial output value
    final_p: float
        final output value
    """

    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        """See Schedule.value"""
        # Clamp progress at 1 so the schedule holds final_p after the horizon.
        frac = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + frac * (self.final_p - self.initial_p)
|