Upload 36 files
Browse files- Makefile +30 -0
- README.md +358 -0
- environment.atari.yml +153 -0
- environment.procgen-v2.yml +135 -0
- environment.procgen.yml +135 -0
- requirements-v1.txt +76 -0
- requirements.txt +42 -0
- src/airstriker-genesis/__init__.py +0 -0
- src/airstriker-genesis/agent.py +400 -0
- src/airstriker-genesis/cartpole.py +353 -0
- src/airstriker-genesis/procgen_agent.py +400 -0
- src/airstriker-genesis/replay.py +66 -0
- src/airstriker-genesis/run-airstriker-ddqn.py +120 -0
- src/airstriker-genesis/run-airstriker-dqn.py +115 -0
- src/airstriker-genesis/run-cartpole.py +120 -0
- src/airstriker-genesis/test.py +405 -0
- src/airstriker-genesis/utils.py +22 -0
- src/airstriker-genesis/wrappers.py +213 -0
- src/lunar-lander/agent.py +1104 -0
- src/lunar-lander/params.py +12 -0
- src/lunar-lander/replay.py +67 -0
- src/lunar-lander/run-lunar-ddqn.py +45 -0
- src/lunar-lander/run-lunar-dqn.py +46 -0
- src/lunar-lander/run-lunar-dueling-ddqn.py +47 -0
- src/lunar-lander/run-lunar-dueling-dqn.py +46 -0
- src/lunar-lander/train.py +84 -0
- src/lunar-lander/wrappers.py +193 -0
- src/procgen/agent.py +664 -0
- src/procgen/run-starpilot-ddqn.py +45 -0
- src/procgen/run-starpilot-dqn.py +45 -0
- src/procgen/run-starpilot-dueling-ddqn.py +45 -0
- src/procgen/run-starpilot-dueling-dqn.py +45 -0
- src/procgen/test-procgen.py +12 -0
- src/procgen/train.py +48 -0
- src/procgen/wrappers.py +187 -0
- troubleshooting.md +37 -0
Makefile
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: create-atari-env
|
| 2 |
+
create-atari-env: ## Creates conda environment
|
| 3 |
+
conda env create -f environment.atari.yml --force
|
| 4 |
+
|
| 5 |
+
.PHONY: create-procgen-env
|
| 6 |
+
create-procgen-env: ## Creates conda environment
|
| 7 |
+
conda env create -f environment.procgen.yml --force
|
| 8 |
+
|
| 9 |
+
.PHONY: setup-env
|
| 10 |
+
setup-env: ## Sets up conda environment
|
| 11 |
+
conda install pytorch torchvision numpy -c pytorch -y
|
| 12 |
+
pip install gym-retro
|
| 13 |
+
pip install "gym[atari]==0.21.0"
|
| 14 |
+
pip install importlib-metadata==4.13.0
|
| 15 |
+
|
| 16 |
+
.PHONY: run-air-dqn
|
| 17 |
+
run-air-dqn: ## Runs
|
| 18 |
+
python ./src/airstriker-genesis/run-airstriker-dqn.py
|
| 19 |
+
|
| 20 |
+
.PHONY: run-air-ddqn
|
| 21 |
+
run-air-ddqn: ## Runs
|
| 22 |
+
python ./src/airstriker-genesis/run-airstriker-ddqn.py
|
| 23 |
+
|
| 24 |
+
.PHONY: run-starpilot-dqn
|
| 25 |
+
run-starpilot-dqn: ## Runs
|
| 26 |
+
python ./src/procgen/run-starpilot-dqn.py
|
| 27 |
+
|
| 28 |
+
.PHONY: run-starpilot-ddqn
|
| 29 |
+
run-starpilot-ddqn: ## Runs
|
| 30 |
+
python ./src/procgen/run-starpilot-ddqn.py
|
README.md
CHANGED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# **Abstract**
|
| 2 |
+
|
| 3 |
+
In December 2013, DeepMind published a paper called "Playing Atari
|
| 4 |
+
with Deep Reinforcement Learning" introducing their algorithm called
|
| 5 |
+
Deep Q-Network (DQN) which revolutionized the field of reinforcement
|
| 6 |
+
learning. For the first time they had brought together Deep Learning and
|
| 7 |
+
Q-learning and showed impressive results applying deep reinforcement
|
| 8 |
+
learning to Atari games with their agents performing at or over human
|
| 9 |
+
level expertise in almost all the games trained on.
|
| 10 |
+
A Deep Q-Network utilizes a deep neural network to estimate the q-values
|
| 11 |
+
for each action, allowing the policy to select the action with the
|
| 12 |
+
maximum q-values. This use of deep neural network to get q-values was
|
| 13 |
+
immensely superior to implementing q-table look-ups and widened the
|
| 14 |
+
applicability of q-learning to more complex reinforcement learning
|
| 15 |
+
environments.
|
| 16 |
+
While revolutionary, the original version of DQN had a few problems,
|
| 17 |
+
especially its slow/inefficient learning process. Over these past 9
|
| 18 |
+
years, a few improved versions of DQNs have become popular. This project
|
| 19 |
+
is an attempt to study the effectiveness of a few of these DQN flavors,
|
| 20 |
+
what problems they solve and compare their performance in the same
|
| 21 |
+
reinforcement learning environment.
|
| 22 |
+
|
| 23 |
+
# Deep Q-Networks and its flavors
|
| 24 |
+
|
| 25 |
+
- **Vanilla DQN**
|
| 26 |
+
|
| 27 |
+
The vanilla (original) DQN uses 2 neural networks: the **online**
|
| 28 |
+
network and the **target** network. The online network is the main
|
| 29 |
+
neural network that the agent uses to select the best action for a
|
| 30 |
+
given state. The target neural network is usually a copy of the
|
| 31 |
+
online network. It is used to get the "target" q-values for each
|
| 32 |
+
action for a particular state. i.e. During the learning phase, since
|
| 33 |
+
we don’t have actual ground truths for future q-values, these
|
| 34 |
+
q-values from the target network will be used as labels to optimize the
|
| 35 |
+
network.
|
| 36 |
+
|
| 37 |
+
The target network calculates the target q-values by using the
|
| 38 |
+
following Bellman equation: \[\begin{aligned}
|
| 39 |
+
Q(s_t, a_t) =
|
| 40 |
+
r_{t+1} + \gamma \max _{a_{t+1} \in A} Q(s_{t+1}, a_{t+1})
|
| 41 |
+
\end{aligned}\] where,
|
| 42 |
+
\(Q(s_t, a_t)\) = The target q-value (ground truth) for a past
|
| 43 |
+
experience in the replay memory
|
| 44 |
+
|
| 45 |
+
\(r_{t+1}\)= The reward that was obtained for taking the chosen
|
| 46 |
+
action in that particular experience
|
| 47 |
+
|
| 48 |
+
\(\gamma\)= The discount factor for future rewards
|
| 49 |
+
|
| 50 |
+
\(Q(s_{t+1}, a_{t+1})\) = The q-value for best action (based on the
|
| 51 |
+
policy) for the next state for that particular experience
|
| 52 |
+
|
| 53 |
+
- **Double DQN**
|
| 54 |
+
|
| 55 |
+
One of the problems with vanilla DQN is the way it calculates its
|
| 56 |
+
target values (ground-truth). We can see from the bellman equation
|
| 57 |
+
above that the target network uses the **max** q-value directly in
|
| 58 |
+
the equation. This is found to almost always overestimate the
|
| 59 |
+
q-value because using the **max** function introduces the
|
| 60 |
+
maximization-bias to our estimates. Using max will give the largest
|
| 61 |
+
value even if that specific max value was an outlier, thus skewing
|
| 62 |
+
our estimates.
|
| 63 |
+
The Double DQN solves this problem by changing the original
|
| 64 |
+
algorithm to the following:
|
| 65 |
+
|
| 66 |
+
1. Instead of using the **max** function, first use the online
|
| 67 |
+
network to estimate the best action for the next state
|
| 68 |
+
|
| 69 |
+
2. Calculate target q-values for the next state for each possible
|
| 70 |
+
action using the target network
|
| 71 |
+
|
| 72 |
+
3. From the q-values calculated by the target network, use the
|
| 73 |
+
q-value of the action chosen in step 1.
|
| 74 |
+
|
| 75 |
+
This can be represented by the following equation: \[\begin{aligned}
|
| 76 |
+
Q(s_t, a_t) =
|
| 77 |
+
r_{t+1} + \gamma Q_{target}(s_{t+1}, a'_{t+1})
|
| 78 |
+
\end{aligned}\] where, \[\begin{aligned}
|
| 79 |
+
a'_{t+1} = argmax({Q_{online}(s_{t+1})})
|
| 80 |
+
\end{aligned}\]
|
| 81 |
+
|
| 82 |
+
- **Dueling DQN**
|
| 83 |
+
|
| 84 |
+
The Dueling DQN algorithm was an attempt to improve upon the
|
| 85 |
+
original DQN algorithm by changing the architecture of the neural
|
| 86 |
+
network used in Deep Q-learning. The Duelling DQN algorithm splits
|
| 87 |
+
the last layer of the DQN into two parts, a **value stream** and an
|
| 88 |
+
**advantage stream**, the outputs of which are aggregated in an
|
| 89 |
+
aggregating layer that gives the final q-value. One of the main
|
| 90 |
+
problems with the original DQN algorithm was that the difference in
|
| 91 |
+
Q-values for the actions were often very close. Thus, selecting the
|
| 92 |
+
action with the max q-value might not always be the best action to
|
| 93 |
+
take. The Dueling DQN attempts to mitigate this by using advantage,
|
| 94 |
+
which is a measure of how better an action is compared to other
|
| 95 |
+
actions for a given state. The value stream, on the other hand,
|
| 96 |
+
learns how good/bad it is to be in a specific state. eg. Moving
|
| 97 |
+
straight towards an obstacle in a racing game, being in the path of
|
| 98 |
+
a projectile in Space Invaders, etc. Instead of learning to predict
|
| 99 |
+
a single q-value, by separating into value and advantage streams
|
| 100 |
+
helps the network generalize better.
|
| 101 |
+
|
| 102 |
+

|
| 103 |
+
Fig: The Dueling DQN architecture (Image taken from the original
|
| 104 |
+
paper by Wang et al.)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
The q-value in a Dueling DQN architecture is given by
|
| 108 |
+
\[\begin{aligned}
|
| 109 |
+
Q(s_t, a_t) = V(s_t) + A(a)
|
| 110 |
+
\end{aligned}\] where,
|
| 111 |
+
V(s\_t) = The value of the current state (how advantageous it is to
|
| 112 |
+
be in that state)
|
| 113 |
+
|
| 114 |
+
A(a) = The advantage of taking action a in that state
|
| 115 |
+
|
| 116 |
+
# About the project
|
| 117 |
+
|
| 118 |
+
My original goal for the project was to train an agent using DQN to
|
| 119 |
+
play **Airstriker Genesis**, a space shooting game and evaluate the
|
| 120 |
+
same agent’s performance on another similar game called
|
| 121 |
+
**Starpilot**. Unfortunately, I was unable to train a decent enough
|
| 122 |
+
agent in the first game, which made it meaningless to evaluate it’s
|
| 123 |
+
performance on yet another game.
|
| 124 |
+
|
| 125 |
+
Because I still want to do the original project some time in the
|
| 126 |
+
future, to prepare myself for that I thought it would be better to
|
| 127 |
+
first learn in-depth about how Deep Q-Networks work, what their
|
| 128 |
+
shortcomings are and how they can be improved. This, and for
|
| 129 |
+
time-constraint reasons, I have changed my project for this class to
|
| 130 |
+
a comparison of various DQN versions.
|
| 131 |
+
|
| 132 |
+
# Dataset
|
| 133 |
+
|
| 134 |
+
I used the excellent [Gym](https://github.com/openai/gym) library to
|
| 135 |
+
run my environment. A total of 9 agents, 1 in Airstriker Genesis, 4
|
| 136 |
+
in Starpilot and 4 in Lunar Lander were trained.
|
| 137 |
+
|
| 138 |
+
| **Game** | **Observation Space** | **Action Space** |
|
| 139 |
+
| :----------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| 140 |
+
| Airstriker Genesis | RGB values of each pixel of the game screen (255, 255, 3) | Discrete(12) representing each of the buttons on the old Atari controllers. But since only three of those buttons were used in the game the action space was reduced to 3 during training. ( Left, Right, Fire ) |
|
| 141 |
+
| Starpilot | RGB values of each pixel of the game screen (64, 64, 3) | Discrete(15) representing each of the button combos ( Left, Right, Up, Down, Up + Right, Up + Left, Down + Right, Down + Left, W, A, S, D, Q, E, Do nothing ) |
|
| 142 |
+
| Lunar Lander | 8-dimensional vector: ( X-coordinate, Y-coordinate, Linear velocity in X, Linear Velocity in Y, Angle, Angular Velocity, Boolean (Leg 1 in contact with ground), Boolean (Leg 2 in contact with ground) ) | Discrete(4)( Do nothing, Fire left engine, Fire main engine, Fire right engine ) |
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
**Environment/Libraries**:
|
| 146 |
+
Miniconda, Python 3.9, Gym, PyTorch, Numpy, Tensorboard on my
|
| 147 |
+
personal Macbook Pro (M1)
|
| 148 |
+
|
| 149 |
+
# ML Methodology
|
| 150 |
+
|
| 151 |
+
Each agent was trained using DQN or one of its flavors. Each agent
|
| 152 |
+
for a particular game was trained with the same hyperparameters with
|
| 153 |
+
just the underlying algorithm different. The following metrics for
|
| 154 |
+
each agent were used for evaluation:
|
| 155 |
+
|
| 156 |
+
- **Epsilon value over each episode** Shows what the exploration
|
| 157 |
+
rate was at the end of each episode.
|
| 158 |
+
|
| 159 |
+
- **Average Q-value for the last 100 episodes** A measure of the
|
| 160 |
+
average q-value (for the action chosen) for the last 100
|
| 161 |
+
episodes.
|
| 162 |
+
|
| 163 |
+
- **Average length for the last 100 episodes** A measure of the
|
| 164 |
+
average number of steps taken in each episode
|
| 165 |
+
|
| 166 |
+
- **Average loss for the last 100 episodes** A measure of loss
|
| 167 |
+
during learning in the last 100 episodes (A Huber Loss was used)
|
| 168 |
+
|
| 169 |
+
- **Average reward for the last 100 episodes** A measure of the
|
| 170 |
+
average reward the agent accumulated over the last 100 episodes
|
| 171 |
+
|
| 172 |
+
## Preprocessing
|
| 173 |
+
|
| 174 |
+
For the Airstriker and the Starpilot games:
|
| 175 |
+
|
| 176 |
+
1. Changed each frame to grayscale
|
| 177 |
+
Since the color shouldn’t matter to the agent, I decided to
|
| 178 |
+
change the RGB image to grayscale
|
| 179 |
+
|
| 180 |
+
2. Changed observation space shape from (height, width, channels)
|
| 181 |
+
to (channels, height, width) to make it compatible with
|
| 182 |
+
Pytorch
|
| 183 |
+
Apparently Pytorch uses a different format than the direct
|
| 184 |
+
output of the gym environment. For this reason, I had to reshape
|
| 185 |
+
each observation to match Pytorch’s scheme (this took me a very
|
| 186 |
+
long time to figure out, but had an "Aha\!" moment when I
|
| 187 |
+
remember you saying something similar in class).
|
| 188 |
+
|
| 189 |
+
3. Framestacking
|
| 190 |
+
Instead of processing 1 frame at a time, process 4 frames at a
|
| 191 |
+
time. This is because just 1 frame is not enough information for
|
| 192 |
+
the agent to decide what action to take.
|
| 193 |
+
|
| 194 |
+
For Lunar Lander, since the reward changes are very drastic (sudden
|
| 195 |
+
+100, -100, +200) rewards, I experimented with Reward Clipping
|
| 196 |
+
(clipping the rewards to \[-1, 1\] range) but this didn’t seem to
|
| 197 |
+
make much difference in my agent’s performance.
|
| 198 |
+
|
| 199 |
+
# Results
|
| 200 |
+
|
| 201 |
+
- **Airstriker Genesis**
|
| 202 |
+
The loss went down until about 5200 episodes but after that it
|
| 203 |
+
stopped going down any further. Consequently the average reward the
|
| 204 |
+
agent accumulated over the last 100 episodes pretty much plateaued
|
| 205 |
+
after about 5000 episodes. On analysis, I noticed that my
|
| 206 |
+
exploration rate at the end of the 7000th episode was still about
|
| 207 |
+
0.65, which means that the agent was taking random actions more than
|
| 208 |
+
half of the time. In hindsight, I feel like I should have trained
|
| 209 |
+
more, at least until the epsilon value (exploration rate) completely
|
| 210 |
+
decayed to 5%.
|
| 211 |
+
  
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
- **Starpilot**
|
| 215 |
+
|
| 216 |
+
I trained DQN, Double DQN, Dueling DQN and Dueling Double DQN
|
| 217 |
+
versions for this game to compare the different algorithms.
|
| 218 |
+
From the graph of mean q-values, we can tell that the Vanilla DQN
|
| 219 |
+
versions indeed give high q-values, and their Double-DQN counterparts
|
| 220 |
+
give lower values, which makes me think that my implementation of
|
| 221 |
+
the Double DQN algorithm was OK. I had expected the agent to
|
| 222 |
+
accumulate higher rewards starting much earlier for the Double and
|
| 223 |
+
Dueling versions, but since the average rewards were almost similar
|
| 224 |
+
for all the agents, I could not notice any stark differences between
|
| 225 |
+
the performance of each agent.
|
| 226 |
+
|
| 227 |
+

|
| 228 |
+
|
| 229 |
+

|
| 230 |
+
|
| 231 |
+
| | |
|
| 232 |
+
| :------------------ | :------------------ |
|
| 233 |
+
|  |  |
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
- **Lunar Lander**
|
| 237 |
+
|
| 238 |
+
Since I did not gain much insight from the agent in the Starpilot game,
|
| 239 |
+
I thought I was not training long enough. So I tried training the
|
| 240 |
+
same agents on Lunar Lander, which is a comparatively simpler game
|
| 241 |
+
with a smaller observation space and one that a DQN algorithm should
|
| 242 |
+
be able to converge pretty quickly to (based on comments by other
|
| 243 |
+
people in the RL community).
|
| 244 |
+

|
| 245 |
+
|
| 246 |
+

|
| 247 |
+
|
| 248 |
+
| | |
|
| 249 |
+
| :------------------- | :------------------- |
|
| 250 |
+
|  |  |
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
The results for this were interesting. Although I did not find any
|
| 255 |
+
vast difference between the different variations of the DQN
|
| 256 |
+
algorithm, I found that the performance of my agent suddenly got
|
| 257 |
+
worse at around 300 episodes. Upon researching on why this may have
|
| 258 |
+
happened, I learned that DQN agents suffer from **catastrophic
|
| 259 |
+
forgetting** i.e. after training extensively, the network suddenly
|
| 260 |
+
forgets what it has learned in the past and then starts performing
|
| 261 |
+
worse. Initially, I thought this might have been the case, but since
|
| 262 |
+
I haven’t trained long enough, and because all models started
|
| 263 |
+
performing worse at almost exactly the same episode number, I think
|
| 264 |
+
this might be a problem with my code or some hyperparameter that I
|
| 265 |
+
used.
|
| 266 |
+
|
| 267 |
+
Upon checking what the agent was doing in the actual game, I found
|
| 268 |
+
that it was playing it very safe and just constantly hovering in the
|
| 269 |
+
air, not attempting to land the spaceship (the goal of the agent is
|
| 270 |
+
to land within the yellow flags). I thought maybe penalizing the
|
| 271 |
+
rewards for taking too many steps in the episode would work, but
|
| 272 |
+
that didn’t help either.
|
| 273 |
+
|
| 274 |
+

|
| 275 |
+
|
| 276 |
+
# Problems Faced
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
Here are a few of the problems that I faced while training my agents:
|
| 280 |
+
|
| 281 |
+
- Understanding the various hyperparameters in the algorithm. DQN uses
|
| 282 |
+
a lot of moving parts and thus, tuning each parameter was a
|
| 283 |
+
difficult task. There were about 8 different hyperparameters (some
|
| 284 |
+
correlated) that impacted the agent’s training performance. I
|
| 285 |
+
struggled with understanding how each parameter impacted the agent
|
| 286 |
+
and also with figuring out how to find optimal values for those. I
|
| 287 |
+
ended up tuning them by trial and error.
|
| 288 |
+
|
| 289 |
+
- I got stuck for a long time figuring out why my convolutional layer
|
| 290 |
+
was not working. I didn’t realize that Pytorch has the channels in
|
| 291 |
+
the first dimension, and because of that, I was passing huge numbers
|
| 292 |
+
like 255 (the height of the image) into the input dimension for a
|
| 293 |
+
Conv2D layer.
|
| 294 |
+
|
| 295 |
+
- I struggled with knowing how long is long enough to realize that a
|
| 296 |
+
model is not working. I trained a model on Airstriker Genesis for 14
|
| 297 |
+
hours just to realize later that I had set a parameter incorrectly
|
| 298 |
+
and had to retrain all over again.
|
| 299 |
+
|
| 300 |
+
# What Next?
|
| 301 |
+
|
| 302 |
+
Although I didn’t get a final working agent for any of the games I
|
| 303 |
+
tried, I feel like I have learned a lot about reinforcement learning,
|
| 304 |
+
especially about Deep Q-learning. I plan to improve upon this further,
|
| 305 |
+
and hopefully get an agent to go far into at least one of the games.
|
| 306 |
+
Next time, I will start with first debugging my current code and see if
|
| 307 |
+
I have any implementation mistakes. Then I will train them a lot longer
|
| 308 |
+
than I did this time and see if it works. While learning about the
|
| 309 |
+
different flavors of DQN, I also learned a little about NoisyNet DQN,
|
| 310 |
+
Rainbow-DQN and Prioritized Experience Replay. I couldn’t implement these
|
| 311 |
+
for this project, but I would like to try them out some time soon.
|
| 312 |
+
|
| 313 |
+
# Lessons Learned
|
| 314 |
+
|
| 315 |
+
- Reinforcement learning is a very challenging problem. It takes a
|
| 316 |
+
substantially large amount of time to train, it is hard to debug and
|
| 317 |
+
it is very difficult to tune its hyperparameters just right. It is a
|
| 318 |
+
lot different from supervised learning in that there are no actual
|
| 319 |
+
labels and thus, this makes optimization very difficult.
|
| 320 |
+
|
| 321 |
+
- I tried training an agent on the Atari Airstriker Genesis and the
|
| 322 |
+
procgen Starpilot game using just the CPU, but this took a very long
|
| 323 |
+
time. This is understandable because the inputs are images and using
|
| 324 |
+
a GPU would have been obviously better. Next time, I will definitely
|
| 325 |
+
try using a GPU to make training faster.
|
| 326 |
+
|
| 327 |
+
- Upon being faced with the problem of my agent not learning, I went
|
| 328 |
+
into research mode and got to learn a lot about DQN and its improved
|
| 329 |
+
versions. I am not a master of the algorithms yet (I have yet to get
|
| 330 |
+
an agent to perform well in the game), but I feel like I understand
|
| 331 |
+
how each version works.
|
| 332 |
+
|
| 333 |
+
- Rather than just following someone’s tutorial, also reading the
|
| 334 |
+
actual papers for that particular algorithm helped me understand the
|
| 335 |
+
algorithm better and code it.
|
| 336 |
+
|
| 337 |
+
Doing this project reinforced in me that I love the concept of
|
| 338 |
+
reinforcement learning. It has made me even more interested in
|
| 339 |
+
exploring the field further and learn more.
|
| 340 |
+
|
| 341 |
+
# References / Resources
|
| 342 |
+
|
| 343 |
+
- [Reinforcement Learning (DQN) Tutorial, Adam
|
| 344 |
+
Paszke](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)
|
| 345 |
+
|
| 346 |
+
- [Train a mario-playing RL agent, Yuansong Feng, Suraj Subramanian,
|
| 347 |
+
Howard Wang, Steven
|
| 348 |
+
Guo](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html)
|
| 349 |
+
|
| 350 |
+
- [About Double DQN, Dueling
|
| 351 |
+
DQN](https://horomary.hatenablog.com/entry/2021/02/06/013412)
|
| 352 |
+
|
| 353 |
+
- [Dueling Network Architecture for Deep Reinforcement Learning (Wang
|
| 354 |
+
et al., 2015)](https://arxiv.org/abs/1511.06581)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
*(Final source code for the project can be found*
|
| 358 |
+
[*here*](https://github.com/00ber/ml-reinforcement-learning)*)*.
|
environment.atari.yml
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: mlrl
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- absl-py=1.3.0=py37hecd8cb5_0
|
| 7 |
+
- aiohttp=3.8.3=py37h6c40b1e_0
|
| 8 |
+
- aiosignal=1.2.0=pyhd3eb1b0_0
|
| 9 |
+
- appnope=0.1.2=py37hecd8cb5_1001
|
| 10 |
+
- async-timeout=4.0.2=py37hecd8cb5_0
|
| 11 |
+
- asynctest=0.13.0=py_0
|
| 12 |
+
- attrs=22.1.0=py37hecd8cb5_0
|
| 13 |
+
- backcall=0.2.0=pyhd3eb1b0_0
|
| 14 |
+
- blas=1.0=mkl
|
| 15 |
+
- blinker=1.4=py37hecd8cb5_0
|
| 16 |
+
- brotli=1.0.9=hca72f7f_7
|
| 17 |
+
- brotli-bin=1.0.9=hca72f7f_7
|
| 18 |
+
- brotlipy=0.7.0=py37h9ed2024_1003
|
| 19 |
+
- bzip2=1.0.8=h1de35cc_0
|
| 20 |
+
- c-ares=1.18.1=hca72f7f_0
|
| 21 |
+
- ca-certificates=2022.10.11=hecd8cb5_0
|
| 22 |
+
- cachetools=4.2.2=pyhd3eb1b0_0
|
| 23 |
+
- cairo=1.14.12=hc4e6be7_4
|
| 24 |
+
- certifi=2022.9.24=py37hecd8cb5_0
|
| 25 |
+
- cffi=1.15.0=py37hca72f7f_0
|
| 26 |
+
- charset-normalizer=2.0.4=pyhd3eb1b0_0
|
| 27 |
+
- click=8.0.4=py37hecd8cb5_0
|
| 28 |
+
- cryptography=38.0.1=py37hf6deb26_0
|
| 29 |
+
- cycler=0.11.0=pyhd3eb1b0_0
|
| 30 |
+
- dataclasses=0.8=pyh6d0b6a4_7
|
| 31 |
+
- decorator=5.1.1=pyhd3eb1b0_0
|
| 32 |
+
- expat=2.4.9=he9d5cce_0
|
| 33 |
+
- ffmpeg=4.0=h01ea3c9_0
|
| 34 |
+
- flit-core=3.6.0=pyhd3eb1b0_0
|
| 35 |
+
- fontconfig=2.14.1=hedf32ac_1
|
| 36 |
+
- fonttools=4.25.0=pyhd3eb1b0_0
|
| 37 |
+
- freetype=2.12.1=hd8bbffd_0
|
| 38 |
+
- frozenlist=1.3.3=py37h6c40b1e_0
|
| 39 |
+
- gettext=0.21.0=h7535e17_0
|
| 40 |
+
- giflib=5.2.1=haf1e3a3_0
|
| 41 |
+
- glib=2.63.1=hd977a24_0
|
| 42 |
+
- google-auth=2.6.0=pyhd3eb1b0_0
|
| 43 |
+
- google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
|
| 44 |
+
- graphite2=1.3.14=he9d5cce_1
|
| 45 |
+
- grpcio=1.42.0=py37ha29bfda_0
|
| 46 |
+
- harfbuzz=1.8.8=hb8d4a28_0
|
| 47 |
+
- hdf5=1.10.2=hfa1e0ec_1
|
| 48 |
+
- icu=58.2=h0a44026_3
|
| 49 |
+
- idna=3.4=py37hecd8cb5_0
|
| 50 |
+
- intel-openmp=2021.4.0=hecd8cb5_3538
|
| 51 |
+
- ipython=7.31.1=py37hecd8cb5_1
|
| 52 |
+
- jasper=2.0.14=h0129ec2_2
|
| 53 |
+
- jedi=0.18.1=py37hecd8cb5_1
|
| 54 |
+
- jpeg=9e=hca72f7f_0
|
| 55 |
+
- kiwisolver=1.4.2=py37he9d5cce_0
|
| 56 |
+
- lcms2=2.12=hf1fd2bf_0
|
| 57 |
+
- lerc=3.0=he9d5cce_0
|
| 58 |
+
- libbrotlicommon=1.0.9=hca72f7f_7
|
| 59 |
+
- libbrotlidec=1.0.9=hca72f7f_7
|
| 60 |
+
- libbrotlienc=1.0.9=hca72f7f_7
|
| 61 |
+
- libcxx=14.0.6=h9765a3e_0
|
| 62 |
+
- libdeflate=1.8=h9ed2024_5
|
| 63 |
+
- libedit=3.1.20221030=h6c40b1e_0
|
| 64 |
+
- libffi=3.2.1=h0a44026_1007
|
| 65 |
+
- libgfortran=3.0.1=h93005f0_2
|
| 66 |
+
- libiconv=1.16=hca72f7f_2
|
| 67 |
+
- libopencv=3.4.2=h7c891bd_1
|
| 68 |
+
- libopus=1.3.1=h1de35cc_0
|
| 69 |
+
- libpng=1.6.37=ha441bb4_0
|
| 70 |
+
- libprotobuf=3.20.1=h8346a28_0
|
| 71 |
+
- libtiff=4.4.0=h2cd0358_2
|
| 72 |
+
- libvpx=1.7.0=h378b8a2_0
|
| 73 |
+
- libwebp=1.2.4=h56c3ce4_0
|
| 74 |
+
- libwebp-base=1.2.4=hca72f7f_0
|
| 75 |
+
- libxml2=2.9.14=hbf8cd5e_0
|
| 76 |
+
- llvm-openmp=14.0.6=h0dcd299_0
|
| 77 |
+
- lz4-c=1.9.4=hcec6c5f_0
|
| 78 |
+
- markdown=3.3.4=py37hecd8cb5_0
|
| 79 |
+
- matplotlib=3.1.2=py37h9aa3819_0
|
| 80 |
+
- matplotlib-inline=0.1.6=py37hecd8cb5_0
|
| 81 |
+
- mkl=2021.4.0=hecd8cb5_637
|
| 82 |
+
- mkl-service=2.4.0=py37h9ed2024_0
|
| 83 |
+
- mkl_fft=1.3.1=py37h4ab4a9b_0
|
| 84 |
+
- mkl_random=1.2.2=py37hb2f4e1b_0
|
| 85 |
+
- multidict=6.0.2=py37hca72f7f_0
|
| 86 |
+
- munkres=1.1.4=py_0
|
| 87 |
+
- ncurses=6.3=hca72f7f_3
|
| 88 |
+
- numpy=1.21.5=py37h2e5f0a9_3
|
| 89 |
+
- numpy-base=1.21.5=py37h3b1a694_3
|
| 90 |
+
- oauthlib=3.2.1=py37hecd8cb5_0
|
| 91 |
+
- olefile=0.46=py37_0
|
| 92 |
+
- opencv=3.4.2=py37h6fd60c2_1
|
| 93 |
+
- openssl=1.1.1s=hca72f7f_0
|
| 94 |
+
- packaging=21.3=pyhd3eb1b0_0
|
| 95 |
+
- parso=0.8.3=pyhd3eb1b0_0
|
| 96 |
+
- pcre=8.45=h23ab428_0
|
| 97 |
+
- pexpect=4.8.0=pyhd3eb1b0_3
|
| 98 |
+
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
| 99 |
+
- pillow=6.1.0=py37hb68e598_0
|
| 100 |
+
- pip=22.3.1=py37hecd8cb5_0
|
| 101 |
+
- pixman=0.40.0=h9ed2024_1
|
| 102 |
+
- prompt-toolkit=3.0.20=pyhd3eb1b0_0
|
| 103 |
+
- protobuf=3.20.1=py37he9d5cce_0
|
| 104 |
+
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
| 105 |
+
- py-opencv=3.4.2=py37h7c891bd_1
|
| 106 |
+
- pyasn1=0.4.8=pyhd3eb1b0_0
|
| 107 |
+
- pyasn1-modules=0.2.8=py_0
|
| 108 |
+
- pycparser=2.21=pyhd3eb1b0_0
|
| 109 |
+
- pygments=2.11.2=pyhd3eb1b0_0
|
| 110 |
+
- pyjwt=2.4.0=py37hecd8cb5_0
|
| 111 |
+
- pyopenssl=22.0.0=pyhd3eb1b0_0
|
| 112 |
+
- pyparsing=3.0.9=py37hecd8cb5_0
|
| 113 |
+
- pysocks=1.7.1=py37hecd8cb5_0
|
| 114 |
+
- python=3.7.3=h359304d_0
|
| 115 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
| 116 |
+
- pytorch=1.13.1=py3.7_0
|
| 117 |
+
- readline=7.0=h1de35cc_5
|
| 118 |
+
- requests=2.28.1=py37hecd8cb5_0
|
| 119 |
+
- requests-oauthlib=1.3.0=py_0
|
| 120 |
+
- rsa=4.7.2=pyhd3eb1b0_1
|
| 121 |
+
- setuptools=65.5.0=py37hecd8cb5_0
|
| 122 |
+
- six=1.16.0=pyhd3eb1b0_1
|
| 123 |
+
- sqlite=3.33.0=hffcf06c_0
|
| 124 |
+
- tensorboard=2.9.0=py37hecd8cb5_0
|
| 125 |
+
- tensorboard-data-server=0.6.1=py37h7242b5c_0
|
| 126 |
+
- tensorboard-plugin-wit=1.6.0=py_0
|
| 127 |
+
- tk=8.6.12=h5d9f67b_0
|
| 128 |
+
- torchvision=0.2.2=py_3
|
| 129 |
+
- tornado=6.2=py37hca72f7f_0
|
| 130 |
+
- tqdm=4.64.1=py37hecd8cb5_0
|
| 131 |
+
- traitlets=5.7.1=py37hecd8cb5_0
|
| 132 |
+
- typing-extensions=4.4.0=py37hecd8cb5_0
|
| 133 |
+
- typing_extensions=4.4.0=py37hecd8cb5_0
|
| 134 |
+
- urllib3=1.26.13=py37hecd8cb5_0
|
| 135 |
+
- wcwidth=0.2.5=pyhd3eb1b0_0
|
| 136 |
+
- werkzeug=2.0.3=pyhd3eb1b0_0
|
| 137 |
+
- wheel=0.37.1=pyhd3eb1b0_0
|
| 138 |
+
- xz=5.2.8=h6c40b1e_0
|
| 139 |
+
- yarl=1.8.1=py37hca72f7f_0
|
| 140 |
+
- zlib=1.2.13=h4dc903c_0
|
| 141 |
+
- zstd=1.5.2=hcb37349_0
|
| 142 |
+
- pip:
|
| 143 |
+
- ale-py==0.7.5
|
| 144 |
+
- cloudpickle==2.2.0
|
| 145 |
+
- gym==0.21.0
|
| 146 |
+
- gym-notices==0.0.8
|
| 147 |
+
- gym-retro==0.8.0
|
| 148 |
+
- importlib-metadata==4.13.0
|
| 149 |
+
- importlib-resources==5.10.1
|
| 150 |
+
- pygame==2.1.0
|
| 151 |
+
- pyglet==1.5.27
|
| 152 |
+
- zipp==3.11.0
|
| 153 |
+
prefix: /Users/karkisushant/miniconda3/envs/mlrl
|
environment.procgen-v2.yml
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: procgen
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- absl-py=1.3.0=py39hecd8cb5_0
|
| 7 |
+
- aiohttp=3.8.3=py39h6c40b1e_0
|
| 8 |
+
- aiosignal=1.2.0=pyhd3eb1b0_0
|
| 9 |
+
- async-timeout=4.0.2=py39hecd8cb5_0
|
| 10 |
+
- attrs=22.1.0=py39hecd8cb5_0
|
| 11 |
+
- blas=1.0=mkl
|
| 12 |
+
- blinker=1.4=py39hecd8cb5_0
|
| 13 |
+
- brotli=1.0.9=hca72f7f_7
|
| 14 |
+
- brotli-bin=1.0.9=hca72f7f_7
|
| 15 |
+
- brotlipy=0.7.0=py39h9ed2024_1003
|
| 16 |
+
- bzip2=1.0.8=h1de35cc_0
|
| 17 |
+
- c-ares=1.18.1=hca72f7f_0
|
| 18 |
+
- ca-certificates=2022.10.11=hecd8cb5_0
|
| 19 |
+
- cachetools=4.2.2=pyhd3eb1b0_0
|
| 20 |
+
- certifi=2022.9.24=py39hecd8cb5_0
|
| 21 |
+
- cffi=1.15.1=py39h6c40b1e_3
|
| 22 |
+
- charset-normalizer=2.0.4=pyhd3eb1b0_0
|
| 23 |
+
- click=8.0.4=py39hecd8cb5_0
|
| 24 |
+
- contourpy=1.0.5=py39haf03e11_0
|
| 25 |
+
- cryptography=38.0.1=py39hf6deb26_0
|
| 26 |
+
- cycler=0.11.0=pyhd3eb1b0_0
|
| 27 |
+
- ffmpeg=4.3=h0a44026_0
|
| 28 |
+
- flit-core=3.6.0=pyhd3eb1b0_0
|
| 29 |
+
- fonttools=4.25.0=pyhd3eb1b0_0
|
| 30 |
+
- freetype=2.12.1=hd8bbffd_0
|
| 31 |
+
- frozenlist=1.3.3=py39h6c40b1e_0
|
| 32 |
+
- gettext=0.21.0=h7535e17_0
|
| 33 |
+
- giflib=5.2.1=haf1e3a3_0
|
| 34 |
+
- gmp=6.2.1=he9d5cce_3
|
| 35 |
+
- gnutls=3.6.15=hed9c0bf_0
|
| 36 |
+
- google-auth=2.6.0=pyhd3eb1b0_0
|
| 37 |
+
- google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
|
| 38 |
+
- grpcio=1.42.0=py39ha29bfda_0
|
| 39 |
+
- icu=58.2=h0a44026_3
|
| 40 |
+
- idna=3.4=py39hecd8cb5_0
|
| 41 |
+
- importlib-metadata=4.11.3=py39hecd8cb5_0
|
| 42 |
+
- intel-openmp=2021.4.0=hecd8cb5_3538
|
| 43 |
+
- jpeg=9e=hca72f7f_0
|
| 44 |
+
- kiwisolver=1.4.2=py39he9d5cce_0
|
| 45 |
+
- lame=3.100=h1de35cc_0
|
| 46 |
+
- lcms2=2.12=hf1fd2bf_0
|
| 47 |
+
- lerc=3.0=he9d5cce_0
|
| 48 |
+
- libbrotlicommon=1.0.9=hca72f7f_7
|
| 49 |
+
- libbrotlidec=1.0.9=hca72f7f_7
|
| 50 |
+
- libbrotlienc=1.0.9=hca72f7f_7
|
| 51 |
+
- libcxx=14.0.6=h9765a3e_0
|
| 52 |
+
- libdeflate=1.8=h9ed2024_5
|
| 53 |
+
- libffi=3.4.2=hecd8cb5_6
|
| 54 |
+
- libiconv=1.16=hca72f7f_2
|
| 55 |
+
- libidn2=2.3.2=h9ed2024_0
|
| 56 |
+
- libpng=1.6.37=ha441bb4_0
|
| 57 |
+
- libprotobuf=3.20.1=h8346a28_0
|
| 58 |
+
- libtasn1=4.16.0=h9ed2024_0
|
| 59 |
+
- libtiff=4.4.0=h2cd0358_2
|
| 60 |
+
- libunistring=0.9.10=h9ed2024_0
|
| 61 |
+
- libwebp=1.2.4=h56c3ce4_0
|
| 62 |
+
- libwebp-base=1.2.4=hca72f7f_0
|
| 63 |
+
- libxml2=2.9.14=hbf8cd5e_0
|
| 64 |
+
- llvm-openmp=14.0.6=h0dcd299_0
|
| 65 |
+
- lz4-c=1.9.4=hcec6c5f_0
|
| 66 |
+
- markdown=3.3.4=py39hecd8cb5_0
|
| 67 |
+
- markupsafe=2.1.1=py39hca72f7f_0
|
| 68 |
+
- matplotlib=3.6.2=py39hecd8cb5_0
|
| 69 |
+
- matplotlib-base=3.6.2=py39h220de94_0
|
| 70 |
+
- mkl=2021.4.0=hecd8cb5_637
|
| 71 |
+
- mkl-service=2.4.0=py39h9ed2024_0
|
| 72 |
+
- mkl_fft=1.3.1=py39h4ab4a9b_0
|
| 73 |
+
- mkl_random=1.2.2=py39hb2f4e1b_0
|
| 74 |
+
- multidict=6.0.2=py39hca72f7f_0
|
| 75 |
+
- munkres=1.1.4=py_0
|
| 76 |
+
- ncurses=6.3=hca72f7f_3
|
| 77 |
+
- nettle=3.7.3=h230ac6f_1
|
| 78 |
+
- numpy=1.23.4=py39he696674_0
|
| 79 |
+
- numpy-base=1.23.4=py39h9cd3388_0
|
| 80 |
+
- oauthlib=3.2.1=py39hecd8cb5_0
|
| 81 |
+
- openh264=2.1.1=h8346a28_0
|
| 82 |
+
- openssl=1.1.1s=hca72f7f_0
|
| 83 |
+
- packaging=21.3=pyhd3eb1b0_0
|
| 84 |
+
- pillow=9.2.0=py39hde71d04_1
|
| 85 |
+
- pip=22.3.1=py39hecd8cb5_0
|
| 86 |
+
- protobuf=3.20.1=py39he9d5cce_0
|
| 87 |
+
- pyasn1=0.4.8=pyhd3eb1b0_0
|
| 88 |
+
- pyasn1-modules=0.2.8=py_0
|
| 89 |
+
- pycparser=2.21=pyhd3eb1b0_0
|
| 90 |
+
- pyjwt=2.4.0=py39hecd8cb5_0
|
| 91 |
+
- pyopenssl=22.0.0=pyhd3eb1b0_0
|
| 92 |
+
- pyparsing=3.0.9=py39hecd8cb5_0
|
| 93 |
+
- pysocks=1.7.1=py39hecd8cb5_0
|
| 94 |
+
- python=3.9.15=h218abb5_2
|
| 95 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
| 96 |
+
- pytorch=1.13.1=py3.9_0
|
| 97 |
+
- readline=8.2=hca72f7f_0
|
| 98 |
+
- requests=2.28.1=py39hecd8cb5_0
|
| 99 |
+
- requests-oauthlib=1.3.0=py_0
|
| 100 |
+
- rsa=4.7.2=pyhd3eb1b0_1
|
| 101 |
+
- setuptools=65.5.0=py39hecd8cb5_0
|
| 102 |
+
- six=1.16.0=pyhd3eb1b0_1
|
| 103 |
+
- sqlite=3.40.0=h880c91c_0
|
| 104 |
+
- tensorboard=2.9.0=py39hecd8cb5_0
|
| 105 |
+
- tensorboard-data-server=0.6.1=py39h7242b5c_0
|
| 106 |
+
- tensorboard-plugin-wit=1.6.0=py_0
|
| 107 |
+
- tk=8.6.12=h5d9f67b_0
|
| 108 |
+
- torchvision=0.14.1=py39_cpu
|
| 109 |
+
- tornado=6.2=py39hca72f7f_0
|
| 110 |
+
- tqdm=4.64.1=py39hecd8cb5_0
|
| 111 |
+
- typing_extensions=4.4.0=py39hecd8cb5_0
|
| 112 |
+
- tzdata=2022g=h04d1e81_0
|
| 113 |
+
- urllib3=1.26.13=py39hecd8cb5_0
|
| 114 |
+
- werkzeug=2.2.2=py39hecd8cb5_0
|
| 115 |
+
- wheel=0.37.1=pyhd3eb1b0_0
|
| 116 |
+
- xz=5.2.8=h6c40b1e_0
|
| 117 |
+
- yarl=1.8.1=py39hca72f7f_0
|
| 118 |
+
- zipp=3.8.0=py39hecd8cb5_0
|
| 119 |
+
- zlib=1.2.13=h4dc903c_0
|
| 120 |
+
- zstd=1.5.2=hcb37349_0
|
| 121 |
+
- pip:
|
| 122 |
+
- cloudpickle==2.2.0
|
| 123 |
+
- filelock==3.8.2
|
| 124 |
+
- glcontext==2.3.7
|
| 125 |
+
- glfw==1.12.0
|
| 126 |
+
- gym==0.21.0
|
| 127 |
+
- gym-notices==0.0.8
|
| 128 |
+
- gym3==0.3.3
|
| 129 |
+
- imageio==2.22.4
|
| 130 |
+
- imageio-ffmpeg==0.3.0
|
| 131 |
+
- moderngl==5.7.4
|
| 132 |
+
- opencv-python==4.6.0.66
|
| 133 |
+
- procgen==0.10.7
|
| 134 |
+
- pyglet==1.5.27
|
| 135 |
+
prefix: /Users/karkisushant/miniconda3/envs/v2
|
environment.procgen.yml
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: procgen
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- absl-py=1.3.0=py39hecd8cb5_0
|
| 7 |
+
- aiohttp=3.8.3=py39h6c40b1e_0
|
| 8 |
+
- aiosignal=1.2.0=pyhd3eb1b0_0
|
| 9 |
+
- async-timeout=4.0.2=py39hecd8cb5_0
|
| 10 |
+
- attrs=22.1.0=py39hecd8cb5_0
|
| 11 |
+
- blas=1.0=mkl
|
| 12 |
+
- blinker=1.4=py39hecd8cb5_0
|
| 13 |
+
- brotli=1.0.9=hca72f7f_7
|
| 14 |
+
- brotli-bin=1.0.9=hca72f7f_7
|
| 15 |
+
- brotlipy=0.7.0=py39h9ed2024_1003
|
| 16 |
+
- bzip2=1.0.8=h1de35cc_0
|
| 17 |
+
- c-ares=1.18.1=hca72f7f_0
|
| 18 |
+
- ca-certificates=2022.10.11=hecd8cb5_0
|
| 19 |
+
- cachetools=4.2.2=pyhd3eb1b0_0
|
| 20 |
+
- certifi=2022.9.24=py39hecd8cb5_0
|
| 21 |
+
- cffi=1.15.1=py39h6c40b1e_3
|
| 22 |
+
- charset-normalizer=2.0.4=pyhd3eb1b0_0
|
| 23 |
+
- click=8.0.4=py39hecd8cb5_0
|
| 24 |
+
- contourpy=1.0.5=py39haf03e11_0
|
| 25 |
+
- cryptography=38.0.1=py39hf6deb26_0
|
| 26 |
+
- cycler=0.11.0=pyhd3eb1b0_0
|
| 27 |
+
- ffmpeg=4.3=h0a44026_0
|
| 28 |
+
- flit-core=3.6.0=pyhd3eb1b0_0
|
| 29 |
+
- fonttools=4.25.0=pyhd3eb1b0_0
|
| 30 |
+
- freetype=2.12.1=hd8bbffd_0
|
| 31 |
+
- frozenlist=1.3.3=py39h6c40b1e_0
|
| 32 |
+
- gettext=0.21.0=h7535e17_0
|
| 33 |
+
- giflib=5.2.1=haf1e3a3_0
|
| 34 |
+
- gmp=6.2.1=he9d5cce_3
|
| 35 |
+
- gnutls=3.6.15=hed9c0bf_0
|
| 36 |
+
- google-auth=2.6.0=pyhd3eb1b0_0
|
| 37 |
+
- google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
|
| 38 |
+
- grpcio=1.42.0=py39ha29bfda_0
|
| 39 |
+
- icu=58.2=h0a44026_3
|
| 40 |
+
- idna=3.4=py39hecd8cb5_0
|
| 41 |
+
- importlib-metadata=4.11.3=py39hecd8cb5_0
|
| 42 |
+
- intel-openmp=2021.4.0=hecd8cb5_3538
|
| 43 |
+
- jpeg=9e=hca72f7f_0
|
| 44 |
+
- kiwisolver=1.4.2=py39he9d5cce_0
|
| 45 |
+
- lame=3.100=h1de35cc_0
|
| 46 |
+
- lcms2=2.12=hf1fd2bf_0
|
| 47 |
+
- lerc=3.0=he9d5cce_0
|
| 48 |
+
- libbrotlicommon=1.0.9=hca72f7f_7
|
| 49 |
+
- libbrotlidec=1.0.9=hca72f7f_7
|
| 50 |
+
- libbrotlienc=1.0.9=hca72f7f_7
|
| 51 |
+
- libcxx=14.0.6=h9765a3e_0
|
| 52 |
+
- libdeflate=1.8=h9ed2024_5
|
| 53 |
+
- libffi=3.4.2=hecd8cb5_6
|
| 54 |
+
- libiconv=1.16=hca72f7f_2
|
| 55 |
+
- libidn2=2.3.2=h9ed2024_0
|
| 56 |
+
- libpng=1.6.37=ha441bb4_0
|
| 57 |
+
- libprotobuf=3.20.1=h8346a28_0
|
| 58 |
+
- libtasn1=4.16.0=h9ed2024_0
|
| 59 |
+
- libtiff=4.4.0=h2cd0358_2
|
| 60 |
+
- libunistring=0.9.10=h9ed2024_0
|
| 61 |
+
- libwebp=1.2.4=h56c3ce4_0
|
| 62 |
+
- libwebp-base=1.2.4=hca72f7f_0
|
| 63 |
+
- libxml2=2.9.14=hbf8cd5e_0
|
| 64 |
+
- llvm-openmp=14.0.6=h0dcd299_0
|
| 65 |
+
- lz4-c=1.9.4=hcec6c5f_0
|
| 66 |
+
- markdown=3.3.4=py39hecd8cb5_0
|
| 67 |
+
- markupsafe=2.1.1=py39hca72f7f_0
|
| 68 |
+
- matplotlib=3.6.2=py39hecd8cb5_0
|
| 69 |
+
- matplotlib-base=3.6.2=py39h220de94_0
|
| 70 |
+
- mkl=2021.4.0=hecd8cb5_637
|
| 71 |
+
- mkl-service=2.4.0=py39h9ed2024_0
|
| 72 |
+
- mkl_fft=1.3.1=py39h4ab4a9b_0
|
| 73 |
+
- mkl_random=1.2.2=py39hb2f4e1b_0
|
| 74 |
+
- multidict=6.0.2=py39hca72f7f_0
|
| 75 |
+
- munkres=1.1.4=py_0
|
| 76 |
+
- ncurses=6.3=hca72f7f_3
|
| 77 |
+
- nettle=3.7.3=h230ac6f_1
|
| 78 |
+
- numpy=1.23.4=py39he696674_0
|
| 79 |
+
- numpy-base=1.23.4=py39h9cd3388_0
|
| 80 |
+
- oauthlib=3.2.1=py39hecd8cb5_0
|
| 81 |
+
- openh264=2.1.1=h8346a28_0
|
| 82 |
+
- openssl=1.1.1s=hca72f7f_0
|
| 83 |
+
- packaging=21.3=pyhd3eb1b0_0
|
| 84 |
+
- pillow=9.2.0=py39hde71d04_1
|
| 85 |
+
- pip=22.3.1=py39hecd8cb5_0
|
| 86 |
+
- protobuf=3.20.1=py39he9d5cce_0
|
| 87 |
+
- pyasn1=0.4.8=pyhd3eb1b0_0
|
| 88 |
+
- pyasn1-modules=0.2.8=py_0
|
| 89 |
+
- pycparser=2.21=pyhd3eb1b0_0
|
| 90 |
+
- pyjwt=2.4.0=py39hecd8cb5_0
|
| 91 |
+
- pyopenssl=22.0.0=pyhd3eb1b0_0
|
| 92 |
+
- pyparsing=3.0.9=py39hecd8cb5_0
|
| 93 |
+
- pysocks=1.7.1=py39hecd8cb5_0
|
| 94 |
+
- python=3.9.15=h218abb5_2
|
| 95 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
| 96 |
+
- pytorch=1.13.1=py3.9_0
|
| 97 |
+
- readline=8.2=hca72f7f_0
|
| 98 |
+
- requests=2.28.1=py39hecd8cb5_0
|
| 99 |
+
- requests-oauthlib=1.3.0=py_0
|
| 100 |
+
- rsa=4.7.2=pyhd3eb1b0_1
|
| 101 |
+
- setuptools=65.5.0=py39hecd8cb5_0
|
| 102 |
+
- six=1.16.0=pyhd3eb1b0_1
|
| 103 |
+
- sqlite=3.40.0=h880c91c_0
|
| 104 |
+
- tensorboard=2.9.0=py39hecd8cb5_0
|
| 105 |
+
- tensorboard-data-server=0.6.1=py39h7242b5c_0
|
| 106 |
+
- tensorboard-plugin-wit=1.6.0=py_0
|
| 107 |
+
- tk=8.6.12=h5d9f67b_0
|
| 108 |
+
- torchvision=0.14.1=py39_cpu
|
| 109 |
+
- tornado=6.2=py39hca72f7f_0
|
| 110 |
+
- tqdm=4.64.1=py39hecd8cb5_0
|
| 111 |
+
- typing_extensions=4.4.0=py39hecd8cb5_0
|
| 112 |
+
- tzdata=2022g=h04d1e81_0
|
| 113 |
+
- urllib3=1.26.13=py39hecd8cb5_0
|
| 114 |
+
- werkzeug=2.2.2=py39hecd8cb5_0
|
| 115 |
+
- wheel=0.37.1=pyhd3eb1b0_0
|
| 116 |
+
- xz=5.2.8=h6c40b1e_0
|
| 117 |
+
- yarl=1.8.1=py39hca72f7f_0
|
| 118 |
+
- zipp=3.8.0=py39hecd8cb5_0
|
| 119 |
+
- zlib=1.2.13=h4dc903c_0
|
| 120 |
+
- zstd=1.5.2=hcb37349_0
|
| 121 |
+
- pip:
|
| 122 |
+
- cloudpickle==2.2.0
|
| 123 |
+
- filelock==3.8.2
|
| 124 |
+
- glcontext==2.3.7
|
| 125 |
+
- glfw==1.12.0
|
| 126 |
+
- gym==0.21.0
|
| 127 |
+
- gym-notices==0.0.8
|
| 128 |
+
- gym3==0.3.3
|
| 129 |
+
- imageio==2.22.4
|
| 130 |
+
- imageio-ffmpeg==0.3.0
|
| 131 |
+
- moderngl==5.7.4
|
| 132 |
+
- opencv-python==4.6.0.66
|
| 133 |
+
- procgen==0.10.7
|
| 134 |
+
- pyglet==1.5.27
|
| 135 |
+
prefix: /Users/karkisushant/miniconda3/envs/procgen
|
requirements-v1.txt
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==1.3.0
|
| 2 |
+
ale-py==0.7.5
|
| 3 |
+
astunparse==1.6.3
|
| 4 |
+
attrs==22.1.0
|
| 5 |
+
box2d-py==2.3.5
|
| 6 |
+
cachetools==5.2.0
|
| 7 |
+
certifi==2022.12.7
|
| 8 |
+
cffi==1.15.1
|
| 9 |
+
charset-normalizer==2.1.1
|
| 10 |
+
cloudpickle==2.2.0
|
| 11 |
+
cycler==0.11.0
|
| 12 |
+
Cython==0.29.32
|
| 13 |
+
fasteners==0.18
|
| 14 |
+
flatbuffers==22.12.6
|
| 15 |
+
fonttools==4.38.0
|
| 16 |
+
future==0.18.2
|
| 17 |
+
gast==0.4.0
|
| 18 |
+
glfw==2.5.5
|
| 19 |
+
google-auth==2.15.0
|
| 20 |
+
google-auth-oauthlib==0.4.6
|
| 21 |
+
google-pasta==0.2.0
|
| 22 |
+
grpcio==1.51.1
|
| 23 |
+
gym==0.21.0
|
| 24 |
+
gym-notices==0.0.8
|
| 25 |
+
gym-retro==0.8.0
|
| 26 |
+
h5py==3.7.0
|
| 27 |
+
idna==3.4
|
| 28 |
+
imageio==2.22.4
|
| 29 |
+
importlib-metadata==4.13.0
|
| 30 |
+
importlib-resources==5.10.1
|
| 31 |
+
iniconfig==1.1.1
|
| 32 |
+
keras==2.11.0
|
| 33 |
+
kiwisolver==1.4.4
|
| 34 |
+
libclang==14.0.6
|
| 35 |
+
lz4==4.0.2
|
| 36 |
+
Markdown==3.4.1
|
| 37 |
+
MarkupSafe==2.1.1
|
| 38 |
+
matplotlib==3.5.3
|
| 39 |
+
mujoco==2.2.0
|
| 40 |
+
mujoco-py==2.1.2.14
|
| 41 |
+
numpy==1.21.6
|
| 42 |
+
oauthlib==3.2.2
|
| 43 |
+
opencv-python==4.6.0.66
|
| 44 |
+
opt-einsum==3.3.0
|
| 45 |
+
packaging==22.0
|
| 46 |
+
Pillow==9.3.0
|
| 47 |
+
pluggy==1.0.0
|
| 48 |
+
protobuf==3.19.6
|
| 49 |
+
py==1.11.0
|
| 50 |
+
pyasn1==0.4.8
|
| 51 |
+
pyasn1-modules==0.2.8
|
| 52 |
+
pycparser==2.21
|
| 53 |
+
pygame==2.1.0
|
| 54 |
+
pyglet==1.5.11
|
| 55 |
+
PyOpenGL==3.1.6
|
| 56 |
+
pyparsing==3.0.9
|
| 57 |
+
pytest==7.0.1
|
| 58 |
+
python-dateutil==2.8.2
|
| 59 |
+
requests==2.28.1
|
| 60 |
+
requests-oauthlib==1.3.1
|
| 61 |
+
rsa==4.9
|
| 62 |
+
six==1.16.0
|
| 63 |
+
swig==4.1.1
|
| 64 |
+
tensorboard==2.11.0
|
| 65 |
+
tensorboard-data-server==0.6.1
|
| 66 |
+
tensorboard-plugin-wit==1.8.1
|
| 67 |
+
tensorflow==2.11.0
|
| 68 |
+
tensorflow-estimator==2.11.0
|
| 69 |
+
tensorflow-io-gcs-filesystem==0.28.0
|
| 70 |
+
termcolor==2.1.1
|
| 71 |
+
tomli==2.0.1
|
| 72 |
+
typing_extensions==4.4.0
|
| 73 |
+
urllib3==1.26.13
|
| 74 |
+
Werkzeug==2.2.2
|
| 75 |
+
wrapt==1.14.1
|
| 76 |
+
zipp==3.11.0
|
requirements.txt
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==1.3.0
|
| 2 |
+
ale-py==0.7.5
|
| 3 |
+
attrs==22.1.0
|
| 4 |
+
box2d-py==2.3.5
|
| 5 |
+
cffi==1.15.1
|
| 6 |
+
cloudpickle==2.2.0
|
| 7 |
+
cycler==0.11.0
|
| 8 |
+
Cython==0.29.32
|
| 9 |
+
fasteners==0.18
|
| 10 |
+
fonttools==4.38.0
|
| 11 |
+
future==0.18.2
|
| 12 |
+
glfw==2.5.5
|
| 13 |
+
gym==0.21.0
|
| 14 |
+
gym-notices==0.0.8
|
| 15 |
+
gym-retro==0.8.0
|
| 16 |
+
imageio==2.22.4
|
| 17 |
+
importlib-metadata==4.13.0
|
| 18 |
+
importlib-resources==5.10.1
|
| 19 |
+
iniconfig==1.1.1
|
| 20 |
+
kiwisolver==1.4.4
|
| 21 |
+
lz4==4.0.2
|
| 22 |
+
matplotlib==3.5.3
|
| 23 |
+
mujoco==2.2.0
|
| 24 |
+
mujoco-py==2.1.2.14
|
| 25 |
+
numpy==1.18.0
|
| 26 |
+
opencv-python==4.6.0.66
|
| 27 |
+
packaging==22.0
|
| 28 |
+
Pillow==9.3.0
|
| 29 |
+
pluggy==1.0.0
|
| 30 |
+
py==1.11.0
|
| 31 |
+
pycparser==2.21
|
| 32 |
+
pygame==2.1.0
|
| 33 |
+
pyglet==1.5.11
|
| 34 |
+
PyOpenGL==3.1.6
|
| 35 |
+
pyparsing==3.0.9
|
| 36 |
+
pytest==7.0.1
|
| 37 |
+
python-dateutil==2.8.2
|
| 38 |
+
six==1.16.0
|
| 39 |
+
swig==4.1.1
|
| 40 |
+
tomli==2.0.1
|
| 41 |
+
typing_extensions==4.4.0
|
| 42 |
+
zipp==3.11.0
|
src/airstriker-genesis/__init__.py
ADDED
|
File without changes
|
src/airstriker-genesis/agent.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import copy
|
| 6 |
+
import time, datetime
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from collections import deque
|
| 9 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 10 |
+
import pickle
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DQNet(nn.Module):
    """Mini CNN for DQN value estimation.

    Architecture: input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) -> dense.

    Holds two copies of the network:
      * ``online`` -- the trainable Q-network.
      * ``target`` -- a frozen deep copy, updated only via an explicit sync.
    """

    def __init__(self, input_dim, output_dim):
        """
        Args:
            input_dim: (channels, height, width) of an observation.
            output_dim: number of discrete actions (size of the Q-value vector).
        """
        super().__init__()
        c, h, w = input_dim

        # Convolutional feature extractor (classic DQN strided convs).
        # Built as a plain list so the final Sequential keeps the same flat
        # module indices (0..8) as before -- existing checkpoints still load.
        conv_layers = [
            nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]

        # Infer the flattened feature size from the actual input shape instead
        # of hard-coding it (the original used a fixed 17024, which only worked
        # for one specific frame size). Also drops the leftover debug prints.
        with torch.no_grad():
            n_flat = nn.Sequential(*conv_layers)(torch.zeros(1, c, h, w)).shape[1]

        self.online = nn.Sequential(
            *conv_layers,
            nn.Linear(n_flat, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        )

        self.target = copy.deepcopy(self.online)

        # Q_target parameters are frozen; they change only via sync_Q_target.
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        """Run a forward pass through the requested sub-network.

        Args:
            input: batch of observations, shape (B, C, H, W).
            model: "online" or "target".

        Returns:
            Q-values of shape (B, output_dim). Returns None for any other
            ``model`` value (preserved from the original implicit behavior).
        """
        if model == "online":
            return self.online(input)
        elif model == "target":
            return self.target(input)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class MetricLogger:
    """Logs per-step / per-episode training metrics to TensorBoard, a
    fixed-width text log, and moving-average plot images.

    Args:
        save_dir (pathlib.Path): directory for the TensorBoard run, the
            plain-text ``log`` file, and the ``*_plot.jpg`` images.
    """

    def __init__(self, save_dir):
        self.writer = SummaryWriter(log_dir=save_dir)
        self.save_log = save_dir / "log"
        # Write the fixed-width header row of the text log.
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = save_dir / "reward_plot.jpg"
        self.ep_lengths_plot = save_dir / "length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / "q_plot.jpg"

        # History metrics: one entry per finished episode.
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages, appended on every call to record().
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Current-episode accumulators.
        self.init_episode()

        # Timestamp used for the delta reported by record().
        self.record_time = time.time()

    def log_step(self, reward, loss, q):
        """Accumulate metrics for one environment step.

        ``loss``/``q`` may be None (no training step happened); they are only
        accumulated when a truthy loss was produced.
        """
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        if loss:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self, episode_number):
        "Mark end of episode: flush per-episode averages and reset accumulators."
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        if self.curr_ep_loss_length == 0:
            # No training step happened during this episode.
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)
        self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
        self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
        self.writer.flush()
        self.init_episode()

    def init_episode(self):
        """Reset the current-episode accumulators."""
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        """Report moving averages over the last 100 episodes.

        Prints a summary line, appends a row to the text log, writes
        TensorBoard scalars, and refreshes the moving-average plot images.
        """
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)

        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )
        self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
        self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
        self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
        # FIX: the original logged the mean-reward scalar a second time here;
        # the duplicate write was redundant and has been removed.
        self.writer.add_scalar("Epsilon value", epsilon, episode)
        self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
        self.writer.flush()
        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        # Refresh the moving-average plot image for each tracked metric.
        for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
            plt.plot(getattr(self, f"moving_avg_{metric}"))
            plt.savefig(getattr(self, f"{metric}_plot"))
            plt.clf()
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class DQNAgent:
|
| 173 |
+
def __init__(self,
             state_dim,
             action_dim,
             save_dir,
             checkpoint=None,
             learning_rate=0.00025,
             max_memory_size=100000,
             batch_size=32,
             exploration_rate=1,
             exploration_rate_decay=0.9999999,
             exploration_rate_min=0.1,
             training_frequency=1,
             learning_starts=1000,
             target_network_sync_frequency=500,
             reset_exploration_rate=False,
             save_frequency=100000,
             gamma=0.9,
             load_replay_buffer=True):
    """DQN agent: epsilon-greedy policy over a DQNet online/target pair.

    Args:
        state_dim: (channels, height, width) of observations, passed to DQNet.
        action_dim: number of discrete actions.
        save_dir: directory (pathlib.Path) where checkpoints are written.
        checkpoint: optional checkpoint to restore via self.load().
        learning_rate: AdamW learning rate.
        max_memory_size: replay-buffer capacity (deque maxlen).
        batch_size: minibatch size sampled in recall().
        exploration_rate: initial epsilon.
        exploration_rate_decay: per-step multiplicative epsilon decay.
        exploration_rate_min: epsilon floor.
        training_frequency: train every N environment steps.
        learning_starts: min. steps collected before training begins.
        target_network_sync_frequency: steps between target-network syncs.
        reset_exploration_rate: forwarded to self.load() -- presumably discards
            the checkpoint's saved epsilon; confirm in load() (defined elsewhere).
        save_frequency: steps between checkpoint saves.
        gamma: discount factor.
        load_replay_buffer: forwarded to self.load() -- presumably restores the
            replay buffer from the checkpoint; confirm in load().
    """
    self.state_dim = state_dim
    self.action_dim = action_dim
    self.max_memory_size = max_memory_size
    self.memory = deque(maxlen=max_memory_size)  # replay buffer; old transitions fall off
    self.batch_size = batch_size

    self.exploration_rate = exploration_rate
    self.exploration_rate_decay = exploration_rate_decay
    self.exploration_rate_min = exploration_rate_min
    self.gamma = gamma

    self.curr_step = 0
    self.learning_starts = learning_starts  # min. experiences before training

    self.training_frequency = training_frequency  # no. of experiences between updates to Q_online
    self.target_network_sync_frequency = target_network_sync_frequency  # no. of experiences between Q_target & Q_online sync

    self.save_every = save_frequency  # no. of experiences between checkpoint saves
    self.save_dir = save_dir

    self.use_cuda = torch.cuda.is_available()

    # Online/target Q-network pair (see DQNet).
    self.net = DQNet(self.state_dim, self.action_dim).float()
    if self.use_cuda:
        self.net = self.net.to(device='cuda')
    # NOTE: the checkpoint is loaded BEFORE the optimizer is created so the
    # optimizer binds to the restored parameters -- keep this ordering.
    if checkpoint:
        self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

    self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
    self.loss_fn = torch.nn.SmoothL1Loss()
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def act(self, state):
|
| 225 |
+
"""
|
| 226 |
+
Given a state, choose an epsilon-greedy action and update value of step.
|
| 227 |
+
|
| 228 |
+
Inputs:
|
| 229 |
+
state(LazyFrame): A single observation of the current state, dimension is (state_dim)
|
| 230 |
+
Outputs:
|
| 231 |
+
action_idx (int): An integer representing which action Mario will perform
|
| 232 |
+
"""
|
| 233 |
+
# EXPLORE
|
| 234 |
+
if np.random.rand() < self.exploration_rate:
|
| 235 |
+
action_idx = np.random.randint(self.action_dim)
|
| 236 |
+
|
| 237 |
+
# EXPLOIT
|
| 238 |
+
else:
|
| 239 |
+
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
|
| 240 |
+
state = state.unsqueeze(0)
|
| 241 |
+
action_values = self.net(state, model='online')
|
| 242 |
+
action_idx = torch.argmax(action_values, axis=1).item()
|
| 243 |
+
|
| 244 |
+
# decrease exploration_rate
|
| 245 |
+
self.exploration_rate *= self.exploration_rate_decay
|
| 246 |
+
self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
|
| 247 |
+
|
| 248 |
+
# increment step
|
| 249 |
+
self.curr_step += 1
|
| 250 |
+
return action_idx
|
| 251 |
+
|
| 252 |
+
def cache(self, state, next_state, action, reward, done):
|
| 253 |
+
"""
|
| 254 |
+
Store the experience to self.memory (replay buffer)
|
| 255 |
+
|
| 256 |
+
Inputs:
|
| 257 |
+
state (LazyFrame),
|
| 258 |
+
next_state (LazyFrame),
|
| 259 |
+
action (int),
|
| 260 |
+
reward (float),
|
| 261 |
+
done(bool))
|
| 262 |
+
"""
|
| 263 |
+
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
|
| 264 |
+
next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
|
| 265 |
+
action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
|
| 266 |
+
reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
|
| 267 |
+
done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
|
| 268 |
+
|
| 269 |
+
self.memory.append( (state, next_state, action, reward, done,) )
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def recall(self):
|
| 273 |
+
"""
|
| 274 |
+
Retrieve a batch of experiences from memory
|
| 275 |
+
"""
|
| 276 |
+
batch = random.sample(self.memory, self.batch_size)
|
| 277 |
+
state, next_state, action, reward, done = map(torch.stack, zip(*batch))
|
| 278 |
+
return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# def td_estimate(self, state, action):
|
| 282 |
+
# current_Q = self.net(state, model='online')[np.arange(0, self.batch_size), action] # Q_online(s,a)
|
| 283 |
+
# return current_Q
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# @torch.no_grad()
|
| 287 |
+
# def td_target(self, reward, next_state, done):
|
| 288 |
+
# next_state_Q = self.net(next_state, model='online')
|
| 289 |
+
# best_action = torch.argmax(next_state_Q, axis=1)
|
| 290 |
+
# next_Q = self.net(next_state, model='target')[np.arange(0, self.batch_size), best_action]
|
| 291 |
+
# return (reward + (1 - done.float()) * self.gamma * next_Q).float()
|
| 292 |
+
|
| 293 |
+
def td_estimate(self, states, actions):
|
| 294 |
+
actions = actions.reshape(-1, 1)
|
| 295 |
+
predicted_qs = self.net(states, model='online')# Q_online(s,a)
|
| 296 |
+
predicted_qs = predicted_qs.gather(1, actions)
|
| 297 |
+
return predicted_qs
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@torch.no_grad()
def td_target(self, rewards, next_states, dones):
    """Bellman target r + gamma * max_a' Q_target(s', a'); bootstrap is zeroed on terminal states."""
    best_next_q = self.net(next_states, model='target').max(dim=1).values.reshape(-1, 1)
    terminal_mask = dones.reshape(-1, 1)
    best_next_q[terminal_mask] = 0.0
    return rewards.reshape(-1, 1) + self.gamma * best_next_q
|
| 309 |
+
|
| 310 |
+
def update_Q_online(self, td_estimate, td_target):
    """Take one gradient step on the online network toward the TD target.

    Returns:
        float: the scalar loss value for logging.
    """
    self.optimizer.zero_grad()
    loss = self.loss_fn(td_estimate, td_target)
    loss.backward()
    self.optimizer.step()
    return loss.item()
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def sync_Q_target(self):
    """Copy the online network's weights into the frozen target network."""
    online_weights = self.net.online.state_dict()
    self.net.target.load_state_dict(online_weights)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def learn(self):
    """Run one training iteration if one is due at the current step.

    Side effects: periodically syncs the target network and checkpoints.
    Returns:
        (mean TD estimate, loss) after a training step, else (None, None).
    """
    step = self.curr_step

    if step % self.target_network_sync_frequency == 0:
        self.sync_Q_target()

    if step % self.save_every == 0:
        self.save()

    # Skip training until the warm-up buffer is filled, and only train
    # every `training_frequency` steps thereafter.
    if step < self.learning_starts or step % self.training_frequency != 0:
        return None, None

    states, next_states, actions, rewards, dones = self.recall()
    estimate = self.td_estimate(states, actions)
    target = self.td_target(rewards, next_states, dones)
    loss = self.update_Q_online(estimate, target)
    return estimate.mean().item(), loss
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def save(self):
    """Write a checkpoint (weights, exploration rate, replay buffer) into save_dir."""
    checkpoint_index = int(self.curr_step // self.save_every)
    save_path = self.save_dir / f"airstriker_net_{checkpoint_index}.chkpt"
    payload = dict(
        model=self.net.state_dict(),
        exploration_rate=self.exploration_rate,
        replay_memory=self.memory,
    )
    torch.save(payload, save_path)

    print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def load(self, load_path, reset_exploration_rate, load_replay_buffer):
    """Restore network weights (and optionally replay buffer / epsilon) from a checkpoint.

    Args:
        load_path (pathlib.Path): checkpoint file written by save().
        reset_exploration_rate (bool): if True, keep the current epsilon
            instead of the saved one.
        load_replay_buffer (bool): if True, restore the saved replay memory
            when the checkpoint contains one.

    Raises:
        ValueError: if load_path does not exist.
    """
    if not load_path.exists():
        raise ValueError(f"{load_path} does not exist")

    ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
    exploration_rate = ckp.get('exploration_rate')
    state_dict = ckp.get('model')

    print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
    self.net.load_state_dict(state_dict)

    if load_replay_buffer:
        replay_memory = ckp.get('replay_memory')
        print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
        self.memory = replay_memory if replay_memory else self.memory

    if reset_exploration_rate:
        print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
    else:
        # BUG FIX: the original message ("Setting exploration rate to X not
        # loaded.") was garbled and contradicted the action taken here.
        print(f"Restoring saved exploration rate {exploration_rate}.")
        self.exploration_rate = exploration_rate
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class DDQNAgent(DQNAgent):
    """Double-DQN variant of DQNAgent.

    The online network *selects* the next-state action and the target
    network *evaluates* it, which reduces the max-operator overestimation
    bias of vanilla DQN.
    """

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        """Double-DQN Bellman target; bootstrap zeroed on terminal states."""
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        # Action selection with the online network (decoupled from evaluation).
        online_qs = self.net(next_states, model='online')
        best_actions = torch.argmax(online_qs, axis=1).reshape(-1, 1)
        # BUG FIX: torch.gather takes (dim, index); the original passed the
        # arguments swapped (`gather(target_actions, 1)`), which raises a
        # TypeError at runtime. Also removed a leftover per-call debug print.
        target_qs = self.net(next_states, model='target').gather(1, best_actions)
        target_qs[dones] = 0.0
        return rewards + (self.gamma * target_qs)
|
src/airstriker-genesis/cartpole.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import copy
|
| 7 |
+
import time, datetime
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
from collections import deque
|
| 10 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 11 |
+
import pickle
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MyDQN(nn.Module):
    """Two-hidden-layer MLP Q-network with separate online and frozen target copies.

    input -> (linear + relu) x 2 -> linear -> one Q-value per action.
    (The original docstring incorrectly described a conv net.)
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        self.online = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

        # The target network starts as an exact copy of the online network.
        self.target = copy.deepcopy(self.online)

        # Q_target parameters are frozen; they are only refreshed via an
        # explicit weight sync from the online network.
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        """Run `input` through the requested sub-network ('online' or 'target')."""
        if model == "online":
            return self.online(input)
        elif model == "target":
            return self.target(input)
        # BUG FIX: previously an unknown model name silently returned None.
        raise ValueError(f"Unknown model: {model!r} (expected 'online' or 'target')")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class MetricLogger:
    """Logs training metrics to TensorBoard, a fixed-width text log and matplotlib plots.

    Per-step values are accumulated with log_step(), folded into per-episode
    averages by log_episode(), and 100-episode moving averages are emitted by
    record().
    """

    def __init__(self, save_dir):
        # save_dir is a pathlib.Path; all artifacts are written under it.
        self.writer = SummaryWriter(log_dir=save_dir)
        self.save_log = save_dir / "log"
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = save_dir / "reward_plot.jpg"
        self.ep_lengths_plot = save_dir / "length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / "q_plot.jpg"

        # History metrics, one entry per finished episode
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages, appended on every call to record()
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Current episode accumulators
        self.init_episode()

        # Timing
        self.record_time = time.time()

    def log_step(self, reward, loss, q):
        """Accumulate one environment step; loss/q only when a training update ran."""
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        if loss:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self, episode_number):
        "Mark end of episode"
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        # Guard against division by zero when no training step ran this episode.
        if self.curr_ep_loss_length == 0:
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)
        self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
        self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
        self.writer.flush()
        self.init_episode()

    def init_episode(self):
        """Reset the per-episode accumulators."""
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        """Emit 100-episode moving averages to stdout, TensorBoard, the text log and plots."""
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)

        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )
        # BUG FIX: "Mean reward last 100 episodes" was previously logged twice
        # per call; the duplicate add_scalar has been removed.
        self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
        self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
        self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
        self.writer.add_scalar("Epsilon value", epsilon, episode)
        self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
        self.writer.flush()
        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
            plt.plot(getattr(self, f"moving_avg_{metric}"))
            plt.savefig(getattr(self, f"{metric}_plot"))
            plt.clf()
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class MyAgent:
    """DQN agent for CartPole.

    Holds an epsilon-greedy policy over a MyDQN network, a bounded replay
    buffer, and the training machinery: TD estimate/target, online updates,
    periodic target-network sync and checkpointing. Removed dead
    commented-out code (old batch size, old optimizer, old TD functions)
    and fixed the garbled load() message.
    """

    def __init__(self, state_dim, action_dim, save_dir, checkpoint=None, reset_exploration_rate=False, max_memory_size=100000):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = 512

        self.exploration_rate = 1
        self.exploration_rate_decay = 0.9999999
        self.exploration_rate_min = 0.1
        self.gamma = 0.9

        self.curr_step = 0
        self.learning_start_threshold = 10000  # min. experiences before training

        self.learn_every = 5    # no. of experiences between updates to Q_online
        self.sync_every = 200   # no. of experiences between Q_target & Q_online sync

        self.save_every = 200000  # no. of experiences between checkpoints
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        # Q-network: online copy for acting/learning, frozen target copy for TD targets.
        self.net = MyDQN(self.state_dim, self.action_dim).float()
        if self.use_cuda:
            self.net = self.net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate)

        self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=0.00025, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()

    def act(self, state):
        """Choose an epsilon-greedy action, decay epsilon and advance the step counter.

        Inputs:
            state: a single observation of the current state, dimension is (state_dim)
        Outputs:
            action_idx (int): index of the chosen action
        """
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)
        # EXPLOIT
        else:
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            state = state.unsqueeze(0)
            action_values = self.net(state, model='online')
            action_idx = torch.argmax(action_values, axis=1).item()

        # decrease exploration_rate, never below its floor
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        """Store one (s, s', a, r, done) transition in the replay buffer as tensors."""
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])

        self.memory.append((state, next_state, action, reward, done,))

    def recall(self):
        """Sample a random minibatch of transitions and stack into batch tensors."""
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    def td_estimate(self, states, actions):
        """Q_online(s, a) for each transition in the batch; shape (batch, 1)."""
        actions = actions.reshape(-1, 1)
        predicted_qs = self.net(states, model='online')
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        """Bellman target r + gamma * max_a' Q_target(s', a'), zeroed on terminal states."""
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        target_qs = self.net(next_states, model='target')
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0
        return (rewards + (self.gamma * target_qs))

    def update_Q_online(self, td_estimate, td_target):
        """One optimizer step on the online network; returns the scalar loss."""
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        """Copy the online network's weights into the frozen target network."""
        self.net.target.load_state_dict(self.net.online.state_dict())

    def learn(self):
        """Run one training iteration if due; returns (mean TD estimate, loss) or (None, None)."""
        if self.curr_step % self.sync_every == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_start_threshold:
            return None, None

        if self.curr_step % self.learn_every != 0:
            return None, None

        # Sample from memory
        state, next_state, action, reward, done = self.recall()

        # Get TD Estimate / Target and backpropagate through Q_online
        td_est = self.td_estimate(state, action)
        td_tgt = self.td_target(reward, next_state, done)
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

    def save(self):
        """Checkpoint weights, exploration rate and replay buffer to save_dir."""
        save_path = self.save_dir / f"cartpole_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )

        print(f"Cartpole Net saved to {save_path} at step {self.curr_step}")

    def load(self, load_path, reset_exploration_rate=False):
        """Restore a checkpoint written by save(); raises ValueError if it does not exist."""
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')
        replay_memory = ckp.get('replay_memory')

        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.net.load_state_dict(state_dict)

        print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
        self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            # BUG FIX: the original message was garbled ("Setting ... not loaded.").
            print(f"Restoring saved exploration rate {exploration_rate}.")
            self.exploration_rate = exploration_rate
|
src/airstriker-genesis/procgen_agent.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import copy
|
| 6 |
+
import time, datetime
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from collections import deque
|
| 9 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 10 |
+
import pickle
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DQNet(nn.Module):
    """Three-conv-layer DQN-style Q-network with online and frozen target copies.

    input (c, h, w) -> (conv2d + relu) x 3 -> flatten -> (dense + relu) -> Q-values.

    Improvements over the original: removed six leftover debug prints from
    __init__, and the flattened conv-output size is computed from input_dim
    instead of being hard-coded to 7168 (which only fit one frame size).
    The nn.Sequential layout is unchanged, so existing checkpoints still load.
    """

    @staticmethod
    def _conv_output_size(size, kernel_size, stride):
        # Output length of a valid (no-padding) conv along one spatial dim.
        return (size - kernel_size) // stride + 1

    def __init__(self, input_dim, output_dim):
        super().__init__()
        c, h, w = input_dim

        # Spatial size after the three conv layers (8/4, 4/2, 3/1).
        conv_h = self._conv_output_size(self._conv_output_size(self._conv_output_size(h, 8, 4), 4, 2), 3, 1)
        conv_w = self._conv_output_size(self._conv_output_size(self._conv_output_size(w, 8, 4), 4, 2), 3, 1)
        flat_size = 64 * conv_h * conv_w

        self.online = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_size, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        )

        # The target network starts as an exact copy of the online network.
        self.target = copy.deepcopy(self.online)

        # Q_target parameters are frozen; refreshed only via explicit sync.
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        """Run `input` through the requested sub-network ('online' or 'target')."""
        if model == "online":
            return self.online(input)
        elif model == "target":
            return self.target(input)
        # BUG FIX: previously an unknown model name silently returned None.
        raise ValueError(f"Unknown model: {model!r} (expected 'online' or 'target')")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class MetricLogger:
    """Logs training metrics to TensorBoard, a fixed-width text log and matplotlib plots.

    Per-step values are accumulated with log_step(), folded into per-episode
    averages by log_episode(), and 100-episode moving averages are emitted by
    record().
    """

    def __init__(self, save_dir):
        # save_dir is a pathlib.Path; all artifacts are written under it.
        self.writer = SummaryWriter(log_dir=save_dir)
        self.save_log = save_dir / "log"
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = save_dir / "reward_plot.jpg"
        self.ep_lengths_plot = save_dir / "length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / "q_plot.jpg"

        # History metrics, one entry per finished episode
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages, appended on every call to record()
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Current episode accumulators
        self.init_episode()

        # Timing
        self.record_time = time.time()

    def log_step(self, reward, loss, q):
        """Accumulate one environment step; loss/q only when a training update ran."""
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        if loss:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self, episode_number):
        "Mark end of episode"
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        # Guard against division by zero when no training step ran this episode.
        if self.curr_ep_loss_length == 0:
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)
        self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
        self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
        self.writer.flush()
        self.init_episode()

    def init_episode(self):
        """Reset the per-episode accumulators."""
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        """Emit 100-episode moving averages to stdout, TensorBoard, the text log and plots."""
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)

        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )
        # BUG FIX: "Mean reward last 100 episodes" was previously logged twice
        # per call; the duplicate add_scalar has been removed.
        self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
        self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
        self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
        self.writer.add_scalar("Epsilon value", epsilon, episode)
        self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
        self.writer.flush()
        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
            plt.plot(getattr(self, f"moving_avg_{metric}"))
            plt.savefig(getattr(self, f"{metric}_plot"))
            plt.clf()
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class DQNAgent:
|
| 173 |
+
def __init__(self,
             state_dim,
             action_dim,
             save_dir,
             checkpoint=None,
             learning_rate=0.00025,
             max_memory_size=100000,
             batch_size=32,
             exploration_rate=1,
             exploration_rate_decay=0.9999999,
             exploration_rate_min=0.1,
             training_frequency=1,
             learning_starts=1000,
             target_network_sync_frequency=500,
             reset_exploration_rate=False,
             save_frequency=100000,
             gamma=0.9,
             load_replay_buffer=True):
    """Configure the DQN agent: replay buffer, epsilon schedule, networks and optimizer.

    Args:
        state_dim: observation shape passed to DQNet (presumably (c, h, w) — TODO confirm against caller).
        action_dim: number of discrete actions.
        save_dir: pathlib.Path directory used by save() for checkpoints.
        checkpoint: optional checkpoint path; when set, load() is called immediately.
        learning_rate: AdamW learning rate.
        max_memory_size: replay buffer capacity (deque maxlen).
        batch_size: minibatch size for training updates.
        exploration_rate / exploration_rate_decay / exploration_rate_min:
            epsilon-greedy schedule parameters (multiplicative decay per step).
        training_frequency: steps between Q_online updates.
        learning_starts: min. experiences collected before training begins.
        target_network_sync_frequency: steps between Q_target <- Q_online syncs.
        reset_exploration_rate: if True, a loaded checkpoint's epsilon is ignored.
        save_frequency: steps between checkpoints.
        gamma: discount factor.
        load_replay_buffer: if True, restore the replay buffer from the checkpoint.
    """
    self.state_dim = state_dim
    self.action_dim = action_dim
    self.max_memory_size = max_memory_size
    # Bounded replay buffer: oldest transitions are evicted once full.
    self.memory = deque(maxlen=max_memory_size)
    self.batch_size = batch_size

    self.exploration_rate = exploration_rate
    self.exploration_rate_decay = exploration_rate_decay
    self.exploration_rate_min = exploration_rate_min
    self.gamma = gamma

    self.curr_step = 0
    self.learning_starts = learning_starts # min. experiences before training

    self.training_frequency = training_frequency # no. of experiences between updates to Q_online
    self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync

    self.save_every = save_frequency # no. of experiences between checkpoints
    self.save_dir = save_dir

    self.use_cuda = torch.cuda.is_available()

    # Q-network: online copy for acting/learning, frozen target copy for TD targets.
    self.net = DQNet(self.state_dim, self.action_dim).float()
    if self.use_cuda:
        self.net = self.net.to(device='cuda')
    if checkpoint:
        self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

    self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
    self.loss_fn = torch.nn.SmoothL1Loss()
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def act(self, state):
    """
    Given a state, choose an epsilon-greedy action and update value of step.

    Inputs:
    state(LazyFrame): A single observation of the current state, dimension is (state_dim)
    Outputs:
    action_idx (int): An integer representing which action Mario will perform
    """
    if np.random.rand() < self.exploration_rate:
        # EXPLORE: uniform random action.
        action_idx = np.random.randint(self.action_dim)
    else:
        # EXPLOIT: greedy action w.r.t. the online network.
        obs = torch.FloatTensor(state)
        if self.use_cuda:
            obs = obs.cuda()
        q_values = self.net(obs.unsqueeze(0), model='online')
        action_idx = torch.argmax(q_values, axis=1).item()

    # Anneal epsilon multiplicatively, clamped at its floor.
    self.exploration_rate = max(
        self.exploration_rate_min,
        self.exploration_rate * self.exploration_rate_decay,
    )

    # One global step per decision.
    self.curr_step += 1
    return action_idx
|
| 251 |
+
|
| 252 |
+
def cache(self, state, next_state, action, reward, done):
    """
    Store the experience to self.memory (replay buffer)

    Inputs:
    state (LazyFrame),
    next_state (LazyFrame),
    action (int),
    reward (float),
    done(bool))
    """
    # Convert each field to a tensor, moving to GPU when available.
    def _tensor(ctor, value):
        t = ctor(value)
        return t.cuda() if self.use_cuda else t

    transition = (
        _tensor(torch.FloatTensor, state),
        _tensor(torch.FloatTensor, next_state),
        _tensor(torch.LongTensor, [action]),
        _tensor(torch.DoubleTensor, [reward]),
        _tensor(torch.BoolTensor, [done]),
    )
    self.memory.append(transition)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def recall(self):
    """
    Retrieve a batch of experiences from memory.

    Returns stacked tensors (state, next_state, action, reward, done);
    the singleton dims of action/reward/done are squeezed away.
    """
    sampled = random.sample(self.memory, self.batch_size)
    states, next_states, actions, rewards, dones = (
        torch.stack(field) for field in zip(*sampled)
    )
    return states, next_states, actions.squeeze(), rewards.squeeze(), dones.squeeze()
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# def td_estimate(self, state, action):
|
| 282 |
+
# current_Q = self.net(state, model='online')[np.arange(0, self.batch_size), action] # Q_online(s,a)
|
| 283 |
+
# return current_Q
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# @torch.no_grad()
|
| 287 |
+
# def td_target(self, reward, next_state, done):
|
| 288 |
+
# next_state_Q = self.net(next_state, model='online')
|
| 289 |
+
# best_action = torch.argmax(next_state_Q, axis=1)
|
| 290 |
+
# next_Q = self.net(next_state, model='target')[np.arange(0, self.batch_size), best_action]
|
| 291 |
+
# return (reward + (1 - done.float()) * self.gamma * next_Q).float()
|
| 292 |
+
|
| 293 |
+
def td_estimate(self, states, actions):
    """Q_online(s, a): the online network's value for each taken action.

    Returns a (batch, 1) tensor of predicted Q-values.
    """
    q_all = self.net(states, model='online')
    # Pick out, per row, the Q-value of the action actually taken.
    return q_all.gather(1, actions.reshape(-1, 1))
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
@torch.no_grad()
def td_target(self, rewards, next_states, dones):
    """Bellman target: r + gamma * max_a' Q_target(s', a').

    The bootstrap term is zeroed on terminal transitions.
    Returns a (batch, 1) tensor.
    """
    rewards = rewards.reshape(-1, 1)
    terminal = dones.reshape(-1, 1)
    # Greedy next-state value under the (frozen) target network.
    next_q = self.net(next_states, model='target').max(dim=1).values.reshape(-1, 1)
    next_q[terminal] = 0.0  # no bootstrapping past episode end
    return rewards + self.gamma * next_q
|
| 309 |
+
|
| 310 |
+
def update_Q_online(self, td_estimate, td_target):
    """Take one gradient step on the online network.

    Computes the Huber loss between TD estimate and TD target,
    backpropagates, and applies the optimizer update.

    Returns the scalar loss value.
    """
    self.optimizer.zero_grad()
    loss = self.loss_fn(td_estimate, td_target)
    loss.backward()
    self.optimizer.step()
    return loss.item()
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def sync_Q_target(self):
    """Copy the online network's weights into the frozen target network."""
    online_weights = self.net.online.state_dict()
    self.net.target.load_state_dict(online_weights)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def learn(self):
    """Run one learning iteration.

    Handles periodic housekeeping (target-network sync, checkpointing),
    then samples a minibatch and takes one gradient step.

    Returns:
        (mean TD estimate, loss), or (None, None) while warming up or
        on off-schedule steps.
    """
    # Periodic housekeeping keyed off the global step counter.
    if self.curr_step % self.target_network_sync_frequency == 0:
        self.sync_Q_target()
    if self.curr_step % self.save_every == 0:
        self.save()

    # Skip updates until enough experience has accumulated, and only
    # train every `training_frequency` steps.
    warming_up = self.curr_step < self.learning_starts
    off_schedule = self.curr_step % self.training_frequency != 0
    if warming_up or off_schedule:
        return None, None

    # Sample a minibatch and take one TD step.
    state, next_state, action, reward, done = self.recall()
    td_est = self.td_estimate(state, action)
    td_tgt = self.td_target(reward, next_state, done)
    loss = self.update_Q_online(td_est, td_tgt)

    return (td_est.mean().item(), loss)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def save(self):
    """Write a numbered checkpoint containing weights, epsilon, and replay memory."""
    # One file per `save_every` window, so earlier checkpoints are kept.
    save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
    # Persist everything needed to resume training exactly where it stopped:
    # network weights, current exploration rate, and the full replay buffer.
    # NOTE: saving the whole replay buffer can make checkpoints very large.
    torch.save(
        dict(
            model=self.net.state_dict(),
            exploration_rate=self.exploration_rate,
            replay_memory=self.memory
        ),
        save_path
    )

    print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def load(self, load_path, reset_exploration_rate, load_replay_buffer):
    """Restore model weights (and optionally epsilon / replay memory) from a checkpoint.

    Args:
        load_path (Path): checkpoint file previously written by `save`.
        reset_exploration_rate (bool): if True, keep the agent's current
            epsilon instead of the saved one.
        load_replay_buffer (bool): if True, restore the saved replay memory
            when the checkpoint contains one.

    Raises:
        ValueError: if `load_path` does not exist.
    """
    if not load_path.exists():
        raise ValueError(f"{load_path} does not exist")

    # Remap saved tensors onto whichever device this process is using.
    ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
    exploration_rate = ckp.get('exploration_rate')
    state_dict = ckp.get('model')

    print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
    self.net.load_state_dict(state_dict)

    if load_replay_buffer:
        replay_memory = ckp.get('replay_memory')
        print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
        # Keep the current (fresh) buffer when the checkpoint has none.
        self.memory = replay_memory if replay_memory else self.memory

    if reset_exploration_rate:
        print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
    else:
        # BUGFIX: the previous message said "... not loaded." on the branch
        # that actually restores the saved exploration rate.
        print(f"Setting exploration rate to the saved value {exploration_rate}.")
        self.exploration_rate = exploration_rate
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class DDQNAgent(DQNAgent):
    """Double DQN: the online net selects a', the target net evaluates it.

    Decoupling action selection from action evaluation reduces the
    maximization bias of vanilla DQN (van Hasselt et al., 2016).
    """

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        """Double-DQN Bellman target: r + gamma * Q_target(s', argmax_a Q_online(s', a))."""
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        # Select the greedy next action with the ONLINE network...
        q_vals = self.net(next_states, model='online')
        target_actions = torch.argmax(q_vals, axis=1).reshape(-1, 1)
        # ...and evaluate it with the TARGET network.
        # BUGFIX: Tensor.gather takes (dim, index); the original called
        # `gather(target_actions, 1)` with the arguments swapped, which
        # raises a TypeError at runtime. Also removed the debug print that
        # ran on every training step.
        target_qs = self.net(next_states, model='target').gather(1, target_actions)
        target_qs[dones] = 0.0  # no bootstrapping past terminal states
        return (rewards + (self.gamma * target_qs))
|
src/airstriker-genesis/replay.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Replay script: load a trained checkpoint and watch the agent play (near-)greedily."""
import datetime
from pathlib import Path
from itertools import count
from agent import DQNAgent, MetricLogger
from wrappers import make_env, make_starpilot


# Build the wrapped Starpilot environment used for the replay run.
env = make_starpilot()

env.reset()

# Each replay run gets its own timestamped checkpoint directory.
save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)

# Checkpoint to visualize; must exist on disk.
checkpoint = Path('checkpoints/procgen-starpilot-dqn/airstriker_net_3.chkpt')

agent = DQNAgent(
    state_dim=(1, 64, 64),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=256,
    checkpoint=checkpoint,
    reset_exploration_rate=True,
    exploration_rate_decay=0.999999,
    training_frequency=10,
    target_network_sync_frequency=200,
    max_memory_size=3000,
    learning_rate=0.001,
    save_frequency=2000

)
# Replay should be (almost) fully greedy: pin epsilon at its floor.
agent.exploration_rate = agent.exploration_rate_min

# logger = MetricLogger(save_dir)

episodes = 100

for e in range(episodes):

    state = env.reset()

    while True:

        # Draw the frame so the run can be watched live.
        env.render()

        action = agent.act(state)

        next_state, reward, done, info = env.step(action)

        # Transitions are still cached, but no learning happens here.
        agent.cache(state, next_state, action, reward, done)

        # logger.log_step(reward, None, None)

        state = next_state

        if done:
            break

    # logger.log_episode()

    # if e % 20 == 0:
    #     logger.record(
    #         episode=e,
    #         epsilon=agent.exploration_rate,
    #         step=agent.curr_step
    #     )
|
src/airstriker-genesis/run-airstriker-ddqn.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train a Double-DQN agent on Airstriker-Genesis."""
import os
import torch
import matplotlib
import matplotlib.pyplot as plt

from pathlib import Path
from tqdm import trange
from agent import DQNAgent, DDQNAgent, MetricLogger
from wrappers import make_env


# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()


# Wrapped Airstriker-Genesis environment (frame skip / warp / stack, etc.).
env = make_env()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")


# Set to a checkpoint Path to resume training from a saved run.
checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

path = "checkpoints/airstriker-ddqn"
save_dir = Path(path)

# Create the checkpoint directory on first use.
isExist = os.path.exists(path)
if not isExist:
    os.makedirs(path)

# Vanilla DQN
# NOTE(review): leftover message from the DQN variant of this script —
# the agent actually trained below is the DDQN one.
print("Training Vanilla DQN Agent!")
# agent = DQNAgent(
#     state_dim=(1, 84, 84),
#     action_dim=env.action_space.n,
#     save_dir=save_dir,
#     batch_size=128,
#     checkpoint=checkpoint,
#     exploration_rate_decay=0.995,
#     exploration_rate_min=0.05,
#     training_frequency=1,
#     target_network_sync_frequency=500,
#     max_memory_size=50000,
#     learning_rate=0.0005,
#
# )

# Double DQN
print("Training DDQN Agent!")
agent = DDQNAgent(
    state_dim=(1, 84, 84),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=128,
    checkpoint=checkpoint,
    exploration_rate_decay=0.995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=500,
    max_memory_size=50000,
    learning_rate=0.0005,
)

logger = MetricLogger(save_dir)

def fill_memory(agent: DQNAgent, num_episodes=1000):
    # Pre-populate the replay buffer before training begins.
    # NOTE(review): uses agent.act(), so epsilon decays and agent.curr_step
    # advances during the prefill — confirm that is intended.
    print("Filling up memory....")
    for _ in trange(num_episodes):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.cache(state, next_state, action, reward, done)
            state = next_state


def train(agent: DQNAgent):
    # Main training loop; effectively runs until interrupted.
    episodes = 10000000
    for e in range(episodes):

        state = env.reset()
        # Play the game!
        while True:

            # print(state.shape)
            # Run agent on the state
            action = agent.act(state)

            # Agent performs action
            next_state, reward, done, info = env.step(action)

            # Remember
            agent.cache(state, next_state, action, reward, done)

            # Learn
            q, loss = agent.learn()

            # Logging
            logger.log_step(reward, loss, q)

            # Update state
            state = next_state

            # Check if end of game
            if done or info["gameover"] == 1:
                break

        logger.log_episode(e)

        if e % 20 == 0:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)

fill_memory(agent)
train(agent)
|
src/airstriker-genesis/run-airstriker-dqn.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train a vanilla DQN agent on Airstriker-Genesis."""
import os
import torch
import matplotlib
import matplotlib.pyplot as plt

from pathlib import Path
from tqdm import trange
from agent import DQNAgent, DDQNAgent, MetricLogger
from wrappers import make_env


# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()


# Wrapped Airstriker-Genesis environment (frame skip / warp / stack, etc.).
env = make_env()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")


# Set to a checkpoint Path to resume training from a saved run.
checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

path = "checkpoints/airstriker-dqn-new"
save_dir = Path(path)

# Create the checkpoint directory on first use.
isExist = os.path.exists(path)
if not isExist:
    os.makedirs(path)

# Vanilla DQN
print("Training Vanilla DQN Agent!")
agent = DQNAgent(
    state_dim=(1, 84, 84),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=128,
    checkpoint=checkpoint,
    exploration_rate_decay=0.995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=500,
    max_memory_size=50000,
    learning_rate=0.0005,

)

# Double DQN
# print("Training DDQN Agent!")
# agent = DDQNAgent(
#     state_dim=(1, 84, 84),
#     action_dim=env.action_space.n,
#     save_dir=save_dir,
#     checkpoint=checkpoint,
#     reset_exploration_rate=True,
#     max_memory_size=max_memory_size
# )

logger = MetricLogger(save_dir)

def fill_memory(agent: DQNAgent, num_episodes=1000):
    # Pre-populate the replay buffer before training begins.
    # NOTE(review): uses agent.act(), so epsilon decays and agent.curr_step
    # advances during the prefill — confirm that is intended.
    print("Filling up memory....")
    for _ in trange(num_episodes):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.cache(state, next_state, action, reward, done)
            state = next_state


def train(agent: DQNAgent):
    # Main training loop; effectively runs until interrupted.
    episodes = 10000000
    for e in range(episodes):

        state = env.reset()
        # Play the game!
        while True:

            # print(state.shape)
            # Run agent on the state
            action = agent.act(state)

            # Agent performs action
            next_state, reward, done, info = env.step(action)

            # Remember
            agent.cache(state, next_state, action, reward, done)

            # Learn
            q, loss = agent.learn()

            # Logging
            logger.log_step(reward, loss, q)

            # Update state
            state = next_state

            # Check if end of game
            if done or info["gameover"] == 1:
                break

        logger.log_episode(e)

        if e % 20 == 0:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)

fill_memory(agent)
train(agent)
|
src/airstriker-genesis/run-cartpole.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train the cartpole agent (MyAgent) on CartPole-v1."""
import os
import random, datetime
from pathlib import Path
# NOTE(review): `retro` is aliased to `gym` here and then immediately
# shadowed by `import gym` below — the retro import is effectively unused.
import retro as gym
from collections import namedtuple, deque
from itertools import count

import torch
import matplotlib
import matplotlib.pyplot as plt
# from agent import MyAgent, MyDQN, MetricLogger
from cartpole import MyAgent, MetricLogger
from wrappers import make_env
import pickle
import gym
from tqdm import trange

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()


# env = make_env()
env = gym.make('CartPole-v1')

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}")
print()

path = "checkpoints/cartpole/latest"
save_dir = Path(path)

# Create the checkpoint directory on first use.
isExist = os.path.exists(path)
if not isExist:
    os.makedirs(path)

# save_dir.mkdir(parents=True)


# Set to a checkpoint Path to resume training from a saved run.
checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

# For cartpole: a flat observation vector and a small discrete action set.
n_actions = env.action_space.n
state = env.reset()
n_observations = len(state)
max_memory_size=100000
agent = MyAgent(
    state_dim=n_observations,
    action_dim=n_actions,
    save_dir=save_dir,
    checkpoint=checkpoint,
    reset_exploration_rate=True,
    max_memory_size=max_memory_size
)

# For airstriker
# agent = MyAgent(state_dim=(1, 84, 84), action_dim=env.action_space.n, save_dir=save_dir, checkpoint=checkpoint, reset_exploration_rate=True)


logger = MetricLogger(save_dir)


def fill_memory(agent: MyAgent):
    # Pre-populate the replay buffer before training.
    # NOTE(review): this loops `max_memory_size` *episodes*, not
    # transitions — far more data than the buffer holds. Confirm intent.
    print("Filling up memory....")
    for _ in trange(max_memory_size):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, info = env.step(action)
            agent.cache(state, next_state, action, reward, done)
            state = next_state

def train(agent: MyAgent):
    # Main training loop; effectively runs until interrupted.
    episodes = 10000000
    for e in range(episodes):

        state = env.reset()
        # Play the game!
        while True:

            # print(state.shape)
            # Run agent on the state
            action = agent.act(state)

            # Agent performs action
            next_state, reward, done, info = env.step(action)

            # Remember
            agent.cache(state, next_state, action, reward, done)

            # Learn
            q, loss = agent.learn()

            # Logging
            logger.log_step(reward, loss, q)

            # Update state
            state = next_state

            # # Check if end of game (for airstriker)
            # if done or info["gameover"] == 1:
            #     break
            # Check if end of game (for cartpole)
            if done:
                break

        logger.log_episode(e)

        if e % 20 == 0:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)


fill_memory(agent)
train(agent)
|
src/airstriker-genesis/test.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import retro
|
| 2 |
+
import gym
|
| 3 |
+
import math
|
| 4 |
+
import random
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from collections import namedtuple, deque
|
| 9 |
+
from itertools import count
|
| 10 |
+
from gym import spaces
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.optim as optim
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import cv2
|
| 17 |
+
import torch
|
| 18 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            # Keep only the last two raw frames; max-pooling them removes
            # the sprite flicker that alternates between consecutive frames.
            if i == self._skip - 2: self._obs_buffer[0] = obs
            if i == self._skip - 1: self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        # Pass straight through; the buffer is overwritten on the next step().
        return self.env.reset(**kwargs)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class LazyFrames(object):
    """Memory-saving wrapper around a list of stacked observation frames.

    Frames shared between consecutive observations are stored only once;
    the concatenated array is materialized lazily on first access and then
    cached. Convert to a numpy array only right before feeding the model.
    """

    def __init__(self, frames):
        self._frames = frames
        self._out = None

    def _force(self):
        # Materialize (and cache) the concatenated array on first use,
        # dropping the frame list so its memory can be reclaimed.
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        result = self._force()
        return result if dtype is None else result.astype(dtype)

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.

        Returns lazy array, which is much more memory efficient.

        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        # Rolling window of the k most recent observations.
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        # The observation space widens along the channel axis: (H, W, C*k).
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)

    def reset(self):
        # Seed the window by repeating the first observation k times.
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        # LazyFrames defers concatenation so shared frames are stored once.
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))
|
| 107 |
+
|
| 108 |
+
class ClipRewardEnv(gym.RewardWrapper):
    """Clip per-step rewards to their sign, as in the DQN Nature paper."""

    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class ImageToPyTorch(gym.ObservationWrapper):
    """Move the channel axis first: (H, W, C) -> (C, H, W) for PyTorch."""

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        # NOTE(review): the declared space is float32 in [0, 1], but
        # `observation` only transposes — it does not rescale pixel values.
        # Confirm a scaling wrapper is applied elsewhere if normalized
        # input is expected.
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(self.height, self.width, 1), dtype=np.uint8)

    def observation(self, frame):
        # Grayscale, downscale with INTER_AREA (recommended for shrinking),
        # then re-add a trailing channel axis so the shape stays (H, W, 1).
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]
|
| 140 |
+
|
| 141 |
+
class AirstrikerDiscretizer(gym.ActionWrapper):
    """Reduce the Genesis 12-button space to 3 discrete actions: LEFT, RIGHT, fire (B)."""

    # Initialization
    def __init__(self, env):
        super(AirstrikerDiscretizer, self).__init__(env)
        buttons = ['B', 'A', 'MODE', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'C', 'Y', 'X', 'Z']
        actions = [['LEFT'], ['RIGHT'], ['B']]
        self._actions = []
        # Build one 12-long boolean button mask per discrete action.
        for action in actions:
            arr = np.array([False] * 12)
            for button in action:
                arr[buttons.index(button)] = True
            self._actions.append(arr)
        self.action_space = gym.spaces.Discrete(len(self._actions))

    # Map a discrete action index to its button mask
    def action(self, a):
        # Copy so callers cannot mutate the cached mask.
        return self._actions[a].copy()
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# Build the Airstriker-Genesis environment with its preprocessing stack.
# Wrapper order matters: each wrapper transforms the previous one's output.
env = retro.make(game='Airstriker-Genesis')
env = MaxAndSkipEnv(env) ## Return only every `skip`-th frame
env = WarpFrame(env) ## Reshape image
env = ImageToPyTorch(env) ## Invert shape
# NOTE(review): ImageToPyTorch runs before FrameStack, so FrameStack's axis-2
# concatenation stacks frames along width (1, 84, 336) rather than along a
# channel axis — confirm intended (DQN's 17024 flatten size matches this layout).
env = FrameStack(env, 4) ## Stack last 4 frames
# env = ScaledFloatFrame(env) ## Scale frames
env = AirstrikerDiscretizer(env)
env = ClipRewardEnv(env)
|
| 168 |
+
|
| 169 |
+
# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()  # interactive mode so training plots refresh without blocking

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A single replay-buffer entry: (state, action, next_state, reward).
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class ReplayMemory(object):
    """Fixed-capacity FIFO buffer of Transition tuples for experience replay."""

    def __init__(self, capacity):
        # deque discards the oldest transition automatically once full
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Draw `batch_size` transitions uniformly, without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class DQN(nn.Module):
    """Nature-DQN style convolutional Q-network.

    Expects NCHW float input; `n_observations` is the number of input
    channels and `n_actions` the number of Q-values to output.

    NOTE(review): layer4's input size (17024 = 64*7*38) matches an 84x336
    input — i.e. the frame-stacked observation produced by this script's
    wrapper chain. Confirm if the preprocessing ever changes.
    """

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Conv2d(in_channels=n_observations, out_channels=32, kernel_size=8, stride=4)
        self.layer2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        # Final conv + ReLU + flatten in one module so forward() can feed the
        # result straight into the fully connected head.
        self.layer3 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), nn.ReLU(), nn.Flatten())
        self.layer4 = nn.Linear(17024, 512)
        self.layer5 = nn.Linear(512, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # layer3 already applies ReLU internally; the original wrapped it in
        # F.relu again, which was a no-op (relu is idempotent) — removed.
        x = self.layer3(x)
        x = F.relu(self.layer4(x))
        return self.layer5(x)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the AdamW optimizer
BATCH_SIZE = 512
GAMMA = 0.99
EPS_START = 1
EPS_END = 0.01
EPS_DECAY = 10000
TAU = 0.005
# LR = 1e-4
LR = 0.00025

# Get number of actions from gym action space
n_actions = env.action_space.n
state = env.reset()
# NOTE(review): len(state) is the length of the observation's leading axis
# (LazyFrames.__len__ forces the stacked array), i.e. the channel count after
# ImageToPyTorch — not the flattened observation size. Confirm this matches
# DQN's expected in_channels.
n_observations = len(state)

# policy_net is trained directly; target_net trails it via the soft updates
# performed in the training loop, starting from identical weights.
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


# Global step counter driving select_action's epsilon schedule.
steps_done = 0
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def select_action(state):
    """Epsilon-greedy action selection; returns (action tensor, epsilon used)."""
    global steps_done
    sample = random.random()
    # Exponentially anneal epsilon from EPS_START toward EPS_END; EPS_DECAY
    # sets the time constant in steps.
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1), eps_threshold
    else:
        # Explore: uniformly random action from the discretized action space.
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long), eps_threshold
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# Episode lengths recorded at each episode end; consumed by plot_durations().
episode_durations = []


def plot_durations(show_result=False):
    """Plot episode durations and, once available, their 100-episode moving average."""
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        # During training, redraw the same figure in place.
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        # unfold builds overlapping 100-long windows; pad with zeros so the
        # average curve aligns with the raw-durations x-axis.
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        # In notebooks, replace the previous output instead of stacking figures.
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def optimize_model():
    """Sample a batch from replay memory and take one gradient step on policy_net."""
    # Wait until the buffer holds at least one full batch.
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    # NOTE(review): in this script the training loop breaks before pushing the
    # terminal transition, so next_state is never None here and the mask is
    # always all-True.
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                                if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
# Main training loop: runs num_episodes episodes, optimizing policy_net each
# step and soft-updating target_net; logs reward and epsilon to TensorBoard.
with SummaryWriter() as writer:
    if torch.cuda.is_available():
        num_episodes = 600
    else:
        num_episodes = 50
    epsilon = 1
    episode_rewards = []
    for i_episode in range(num_episodes):

        # Initialize the environment and get it's state
        state = env.reset()
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        episode_reward = 0
        for t in count():
            action, epsilon = select_action(state)
            observation, reward, done, info = env.step(action.item())
            reward = torch.tensor([reward], device=device)

            # Treat the retro "gameover" flag as an episode end as well.
            done = done or info["gameover"] == 1
            if done:
                episode_durations.append(t + 1)
                print(f"Episode {i_episode} done")
                # plot_durations()
                # NOTE(review): breaking here means the terminal transition is
                # never pushed to memory, so the replay buffer contains no
                # next_state=None entries (see optimize_model's mask).
                break
            # if done:
            #     next_state = None
            # else:
            #     next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

            # Store the transition in memory
            memory.push(state, action, next_state, reward)
            # NOTE(review): reward is a 1-element tensor, so episode_reward
            # becomes a tensor after the first addition; add_scalar accepts it.
            episode_reward += reward
            # Move to the next state
            state = next_state

            # Perform one step of the optimization (on the policy network)
            optimize_model()

            # Soft update of the target network's weights
            # θ′ ← τ θ + (1 −τ )θ′
            target_net_state_dict = target_net.state_dict()
            policy_net_state_dict = policy_net.state_dict()
            for key in policy_net_state_dict:
                target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
            target_net.load_state_dict(target_net_state_dict)
        # if done:
        #     episode_durations.append(t + 1)
        #     # plot_durations()
        #     break
        # episode_rewards.append(episode_reward)
        writer.add_scalar("Rewards/Episode", episode_reward, i_episode)
        writer.add_scalar("Epsilon", epsilon, i_episode)
        writer.flush()
    print('Complete')
    plot_durations(show_result=True)
    plt.ioff()
    plt.show()
|
src/airstriker-genesis/utils.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gym
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Airstrikerラッパー
|
| 6 |
+
# Airstriker wrapper
class AirstrikerDiscretizer(gym.ActionWrapper):
    """Restrict Airstriker-Genesis to a small discrete action set.

    Maps a Discrete(3) action space — LEFT, RIGHT, B (shoot) — onto the
    12-button boolean array that gym-retro expects.
    """

    # Initialization
    def __init__(self, env):
        super(AirstrikerDiscretizer, self).__init__(env)
        buttons = ['B', 'A', 'MODE', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'C', 'Y', 'X', 'Z']
        actions = [['LEFT'], ['RIGHT'], ['B']]
        self._actions = []
        for combo in actions:
            # Derive the mask length from `buttons` instead of hard-coding 12,
            # so the table stays consistent if the button list ever changes.
            mask = np.zeros(len(buttons), dtype=bool)
            for button in combo:
                mask[buttons.index(button)] = True
            self._actions.append(mask)
        self.action_space = gym.spaces.Discrete(len(self._actions))

    # Action lookup
    def action(self, a):
        """Translate discrete index `a` into a fresh copy of its button mask."""
        return self._actions[a].copy()
|
src/airstriker-genesis/wrappers.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
from collections import deque
|
| 4 |
+
import gym
|
| 5 |
+
from gym import spaces
|
| 6 |
+
import cv2
|
| 7 |
+
import retro
|
| 8 |
+
from utils import AirstrikerDiscretizer
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
'''
|
| 12 |
+
Atari Wrapper copied from https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
|
| 13 |
+
'''
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class LazyFrames(object):
    """Memory-saving view over a list of stacked frames.

    Frames shared between consecutive observations are stored only once,
    which matters for DQN-scale (1M-entry) replay buffers. The concatenated
    array is materialized lazily on first access and then cached. Convert to
    a numpy array only right before feeding the model.
    """

    def __init__(self, frames):
        self._frames = frames
        self._out = None  # cache for the materialized array

    def _force(self):
        # Materialize (and cache) the concatenated array on first use,
        # releasing the per-frame references afterwards.
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        out = self._force()
        return out if dtype is None else out.astype(dtype)

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]
|
| 43 |
+
|
| 44 |
+
class FireResetEnv(gym.Wrapper):
    """Press FIRE after reset, for games that wait for it before starting."""

    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        meanings = env.unwrapped.get_action_meanings()
        assert meanings[1] == 'FIRE'
        assert len(meanings) >= 3

    def reset(self, **kwargs):
        """Reset, then press FIRE (1) and action 2, retrying if either ends the episode."""
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        # Only reset() is augmented; stepping passes straight through.
        return self.env.step(ac)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            # Keep only the last two raw frames; their pixel-wise max removes
            # flicker from sprites drawn on alternating frames.
            if i == self._skip - 2: self._obs_buffer[0] = obs
            if i == self._skip - 1: self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        # NOTE(review): if the episode ends before the last two skip steps, the
        # buffer still holds zeros/stale frames — same quirk as the upstream
        # baselines implementation this is copied from.
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        # Skipping only affects step(); reset passes straight through.
        return self.env.reset(**kwargs)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class WarpFrame(gym.ObservationWrapper):
    """Convert observations to 84x84 grayscale, as in the DQN Nature paper."""

    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.height, self.width, 1),
            dtype=np.uint8,
        )

    def observation(self, frame):
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(gray, (self.width, self.height),
                             interpolation=cv2.INTER_AREA)
        # cvtColor drops the channel axis; restore it: (84, 84) -> (84, 84, 1)
        return resized[:, :, None]
|
| 108 |
+
|
| 109 |
+
class WarpFrameNoResize(gym.ObservationWrapper):
    def __init__(self, env):
        """Convert frames to single-channel grayscale without resizing."""
        gym.ObservationWrapper.__init__(self, env)
        # NOTE(review): observation_space is NOT updated here, so it still
        # advertises the wrapped env's original (RGB) shape while observations
        # are actually (H, W, 1) — downstream wrappers that read the declared
        # shape (e.g. ImageToPyTorch) will report a wrong channel count.

    def observation(self, frame):
        # Grayscale only; the resize step of WarpFrame is intentionally skipped.
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class FrameStack(gym.Wrapper):
    """Stack the k most recent frames along the channel axis.

    Observations are returned as LazyFrames, which avoids duplicating
    shared frames in the replay buffer.
    """

    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(shp[0], shp[1], shp[2] * k),
            dtype=env.observation_space.dtype,
        )

    def reset(self):
        obs = self.env.reset()
        # Seed the deque with k copies of the first frame.
        self.frames.extend([obs] * self.k)
        return self._get_ob()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.frames.append(obs)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class ImageToPyTorch(gym.ObservationWrapper):
    """Move the channel axis first so observations use PyTorch's CHW layout.

    NOTE(review): the declared Box is float32 in [0, 1], but `observation`
    only transposes axes — values keep the wrapped env's dtype/range.
    Confirm a scaling wrapper is applied elsewhere in the chain.
    """

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        shp = self.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=(shp[-1], shp[0], shp[1]),
            dtype=np.float32,
        )

    def observation(self, observation):
        # (H, W, C) -> (C, H, W)
        return np.moveaxis(observation, 2, 0)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# class ImageToPyTorch(gym.ObservationWrapper):
|
| 162 |
+
# def __init__(self, env):
|
| 163 |
+
# super(ImageToPyTorch, self).__init__(env)
|
| 164 |
+
# old_shape = self.observation_space.shape
|
| 165 |
+
# new_shape = (old_shape[-1], old_shape[0], old_shape[1])
|
| 166 |
+
# print("Old: ", old_shape)
|
| 167 |
+
# print("New: ", new_shape)
|
| 168 |
+
# self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=new_shape, dtype=np.float32)
|
| 169 |
+
|
| 170 |
+
# def observation(self, observation):
|
| 171 |
+
# return np.moveaxis(observation, 2, 0)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class ScaledFloatFrame(gym.ObservationWrapper):
    """Scale uint8 pixel observations into float32 values in [0, 1]."""

    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=env.observation_space.shape,
            dtype=np.float32,
        )

    def observation(self, observation):
        # careful! This undoes the memory optimization, use
        # with smaller replay buffers only.
        return np.asarray(observation, dtype=np.float32) / 255.0
|
| 183 |
+
|
| 184 |
+
class ClipRewardEnv(gym.RewardWrapper):
    """Clip every reward to its sign, keeping learning scale-invariant."""

    def __init__(self, env):
        super(ClipRewardEnv, self).__init__(env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def make_env():
    """Build the fully wrapped Airstriker-Genesis environment.

    Wrapper order matters: each wrapper transforms the previous one's output.
    """
    env = retro.make(game='Airstriker-Genesis')
    env = MaxAndSkipEnv(env) ## Return only every `skip`-th frame
    env = WarpFrame(env) ## Reshape image
    env = ImageToPyTorch(env) ## Invert shape
    # NOTE(review): ImageToPyTorch runs before FrameStack, so FrameStack's
    # axis-2 concatenation stacks frames along width, not channels — confirm
    # this matches the model's expected input layout.
    env = FrameStack(env, 4) ## Stack last 4 frames
    env = ScaledFloatFrame(env) ## Scale frames
    env = AirstrikerDiscretizer(env)
    env = ClipRewardEnv(env)
    return env
|
| 204 |
+
|
| 205 |
+
def make_starpilot(render=False):
    """Build a preprocessed procgen Starpilot environment.

    Args:
        render: if True, ask procgen for a human-viewable render window.

    Returns:
        The wrapped environment (grayscale, channels-first, 4-frame stack).
    """
    # Build kwargs once instead of duplicating the gym.make call per branch.
    kwargs = {"distribution_mode": "easy"}
    if render:
        kwargs["render_mode"] = "human"
    env = gym.make("procgen:procgen-starpilot-v0", **kwargs)
    env = WarpFrameNoResize(env)  ## Reshape image
    env = ImageToPyTorch(env)     ## Invert shape
    env = FrameStack(env, 4)      ## Stack last 4 frames
    return env
|
src/lunar-lander/agent.py
ADDED
|
@@ -0,0 +1,1104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import copy
|
| 6 |
+
import time, datetime
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from collections import deque
|
| 9 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DQNet(nn.Module):
    """Two-headed MLP Q-network: a trainable online net plus a frozen target copy."""

    def __init__(self, input_dim, output_dim):
        super().__init__()

        self.online = nn.Sequential(
            nn.Linear(input_dim, 150),
            nn.ReLU(),
            nn.Linear(150, 120),
            nn.ReLU(),
            nn.Linear(120, output_dim),
        )

        # The target starts as an exact copy of the online network.
        self.target = copy.deepcopy(self.online)

        # Q_target parameters are frozen.
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        """Run a forward pass through the head selected by `model` ("online" or "target")."""
        if model == "online":
            return self.online(input)
        elif model == "target":
            return self.target(input)
        # The original silently returned None for an unknown `model`,
        # which surfaces later as a confusing downstream error; fail loudly.
        raise ValueError(f"model must be 'online' or 'target', got {model!r}")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class MetricLogger:
|
| 42 |
+
def __init__(self, save_dir):
|
| 43 |
+
self.writer = SummaryWriter(log_dir=save_dir)
|
| 44 |
+
self.save_log = save_dir / "log"
|
| 45 |
+
with open(self.save_log, "w") as f:
|
| 46 |
+
f.write(
|
| 47 |
+
f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
|
| 48 |
+
f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
|
| 49 |
+
f"{'TimeDelta':>15}{'Time':>20}\n"
|
| 50 |
+
)
|
| 51 |
+
self.ep_rewards_plot = save_dir / "reward_plot.jpg"
|
| 52 |
+
self.ep_lengths_plot = save_dir / "length_plot.jpg"
|
| 53 |
+
self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
|
| 54 |
+
self.ep_avg_qs_plot = save_dir / "q_plot.jpg"
|
| 55 |
+
|
| 56 |
+
# History metrics
|
| 57 |
+
self.ep_rewards = []
|
| 58 |
+
self.ep_lengths = []
|
| 59 |
+
self.ep_avg_losses = []
|
| 60 |
+
self.ep_avg_qs = []
|
| 61 |
+
|
| 62 |
+
# Moving averages, added for every call to record()
|
| 63 |
+
self.moving_avg_ep_rewards = []
|
| 64 |
+
self.moving_avg_ep_lengths = []
|
| 65 |
+
self.moving_avg_ep_avg_losses = []
|
| 66 |
+
self.moving_avg_ep_avg_qs = []
|
| 67 |
+
|
| 68 |
+
# Current episode metric
|
| 69 |
+
self.init_episode()
|
| 70 |
+
|
| 71 |
+
# Timing
|
| 72 |
+
self.record_time = time.time()
|
| 73 |
+
|
| 74 |
+
def log_step(self, reward, loss, q):
|
| 75 |
+
self.curr_ep_reward += reward
|
| 76 |
+
self.curr_ep_length += 1
|
| 77 |
+
if loss:
|
| 78 |
+
self.curr_ep_loss += loss
|
| 79 |
+
self.curr_ep_q += q
|
| 80 |
+
self.curr_ep_loss_length += 1
|
| 81 |
+
|
| 82 |
+
def log_episode(self, episode_number):
|
| 83 |
+
"Mark end of episode"
|
| 84 |
+
self.ep_rewards.append(self.curr_ep_reward)
|
| 85 |
+
self.ep_lengths.append(self.curr_ep_length)
|
| 86 |
+
if self.curr_ep_loss_length == 0:
|
| 87 |
+
ep_avg_loss = 0
|
| 88 |
+
ep_avg_q = 0
|
| 89 |
+
else:
|
| 90 |
+
ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
|
| 91 |
+
ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
|
| 92 |
+
self.ep_avg_losses.append(ep_avg_loss)
|
| 93 |
+
self.ep_avg_qs.append(ep_avg_q)
|
| 94 |
+
self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
|
| 95 |
+
self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
|
| 96 |
+
self.writer.flush()
|
| 97 |
+
self.init_episode()
|
| 98 |
+
|
| 99 |
+
def init_episode(self):
|
| 100 |
+
self.curr_ep_reward = 0.0
|
| 101 |
+
self.curr_ep_length = 0
|
| 102 |
+
self.curr_ep_loss = 0.0
|
| 103 |
+
self.curr_ep_q = 0.0
|
| 104 |
+
self.curr_ep_loss_length = 0
|
| 105 |
+
|
| 106 |
+
def record(self, episode, epsilon, step):
|
| 107 |
+
mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
|
| 108 |
+
mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
|
| 109 |
+
mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
|
| 110 |
+
mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
|
| 111 |
+
self.moving_avg_ep_rewards.append(mean_ep_reward)
|
| 112 |
+
self.moving_avg_ep_lengths.append(mean_ep_length)
|
| 113 |
+
self.moving_avg_ep_avg_losses.append(mean_ep_loss)
|
| 114 |
+
self.moving_avg_ep_avg_qs.append(mean_ep_q)
|
| 115 |
+
|
| 116 |
+
last_record_time = self.record_time
|
| 117 |
+
self.record_time = time.time()
|
| 118 |
+
time_since_last_record = np.round(self.record_time - last_record_time, 3)
|
| 119 |
+
|
| 120 |
+
print(
|
| 121 |
+
f"Episode {episode} - "
|
| 122 |
+
f"Step {step} - "
|
| 123 |
+
f"Epsilon {epsilon} - "
|
| 124 |
+
f"Mean Reward {mean_ep_reward} - "
|
| 125 |
+
f"Mean Length {mean_ep_length} - "
|
| 126 |
+
f"Mean Loss {mean_ep_loss} - "
|
| 127 |
+
f"Mean Q Value {mean_ep_q} - "
|
| 128 |
+
f"Time Delta {time_since_last_record} - "
|
| 129 |
+
f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
|
| 130 |
+
)
|
| 131 |
+
self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
|
| 132 |
+
self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
|
| 133 |
+
self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
|
| 134 |
+
self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
|
| 135 |
+
self.writer.add_scalar("Epsilon value", epsilon, episode)
|
| 136 |
+
self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
|
| 137 |
+
self.writer.flush()
|
| 138 |
+
with open(self.save_log, "a") as f:
|
| 139 |
+
f.write(
|
| 140 |
+
f"{episode:8d}{step:8d}{epsilon:10.3f}"
|
| 141 |
+
f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
|
| 142 |
+
f"{time_since_last_record:15.3f}"
|
| 143 |
+
f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
|
| 147 |
+
plt.plot(getattr(self, f"moving_avg_{metric}"))
|
| 148 |
+
plt.savefig(getattr(self, f"{metric}_plot"))
|
| 149 |
+
plt.clf()
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class DQNAgent:
    """Vanilla DQN agent: epsilon-greedy acting over a uniform replay
    buffer, trained with a periodically synced target network.

    ``self.net`` is a DQNet bundling the trainable online network and a
    frozen target network; the branch is selected with the ``model=``
    keyword (``'online'`` / ``'target'``).
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts  # min. experiences before training

        self.training_frequency = training_frequency  # no. of experiences between updates to Q_online
        self.target_network_sync_frequency = target_network_sync_frequency  # no. of experiences between Q_target & Q_online sync

        self.save_every = save_frequency  # no. of experiences between saving the network
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.net = DQNet(self.state_dim, self.action_dim).float()
        if self.use_cuda:
            self.net = self.net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()
        # self.optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate)
        # self.loss_fn = torch.nn.MSELoss()

    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update value of step.

        Inputs:
            state(LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
            action_idx (int): An integer representing which action the agent will perform
        """
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)

        # EXPLOIT
        else:
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            state = state.unsqueeze(0)  # add batch dimension
            action_values = self.net(state, model='online')
            action_idx = torch.argmax(action_values, axis=1).item()

        # Decay exploration rate toward its floor (every step, even exploit steps).
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        """
        Store the experience to self.memory (replay buffer)

        Inputs:
            state (LazyFrame),
            next_state (LazyFrame),
            action (int),
            reward (float),
            done(bool))
        """
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])

        self.memory.append((state, next_state, action, reward, done,))

    def recall(self):
        """
        Retrieve a uniformly sampled batch of experiences from memory.
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    def td_estimate(self, states, actions):
        """Q_online(s, a) for each (state, action) pair; shape (batch, 1)."""
        actions = actions.reshape(-1, 1)
        predicted_qs = self.net(states, model='online')
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        """r + gamma * max_a Q_target(s', a), with terminal next-states zeroed."""
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        target_qs = self.net(next_states, model='target')
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0  # no bootstrap beyond episode end
        return (rewards + (self.gamma * target_qs))

    def update_Q_online(self, td_estimate, td_target):
        """One gradient step of the online network; returns the scalar loss."""
        loss = self.loss_fn(td_estimate.float(), td_target.float())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        """Copy online-network weights into the frozen target network."""
        self.net.target.load_state_dict(self.net.online.state_dict())

    def learn(self):
        """Run one training iteration if one is due at the current step.

        Returns:
            (mean TD estimate, loss) after an update, else (None, None).
        """
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        state, next_state, action, reward, done = self.recall()
        td_est = self.td_estimate(state, action)
        td_tgt = self.td_target(reward, next_state, done)
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

    def save(self):
        """Checkpoint model weights, exploration rate and replay buffer."""
        save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )

        print(f"Airstriker model saved to {save_path} at step {self.curr_step}")

    def load(self, load_path, reset_exploration_rate, load_replay_buffer):
        """Restore a checkpoint produced by save().

        Raises:
            ValueError: if ``load_path`` does not exist.
        """
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')

        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.net.load_state_dict(state_dict)

        if load_replay_buffer:
            replay_memory = ckp.get('replay_memory')
            print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
            self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            # Fix: the old message said the rate was "not loaded" while the
            # code restores it; the message now matches the behavior.
            print(f"Restoring saved exploration rate {exploration_rate}.")
            self.exploration_rate = exploration_rate
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class DDQNAgent(DQNAgent):
    """Double-DQN variant: the online network selects the greedy next
    action and the target network evaluates it, reducing the
    overestimation bias of plain DQN's max over the target network."""

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)

        # Action selection with the online network...
        online_qs = self.net(next_states, model='online')
        greedy_actions = torch.argmax(online_qs, axis=1).reshape(-1, 1)

        # ...action evaluation with the target network.
        next_qs = self.net(next_states, model='target')
        next_qs = next_qs.gather(1, greedy_actions).reshape(-1, 1)
        next_qs[dones] = 0.0  # no bootstrap beyond episode end
        return rewards + (self.gamma * next_qs)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
class DuelingDQNet(nn.Module):
    """Dueling-architecture Q-network: a shared feature trunk feeds a
    scalar state-value head and a per-action advantage head, combined as
    Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Layer order matches the rest of the file so state_dict keys and
        # parameter-init order stay stable.
        self.feature_layer = nn.Sequential(
            nn.Linear(input_dim, 150),
            nn.ReLU(),
            nn.Linear(150, 120),
            nn.ReLU(),
        )
        self.value_layer = nn.Sequential(
            nn.Linear(120, 120),
            nn.ReLU(),
            nn.Linear(120, 1),
        )
        self.advantage_layer = nn.Sequential(
            nn.Linear(120, 120),
            nn.ReLU(),
            nn.Linear(120, output_dim),
        )

    def forward(self, state):
        features = self.feature_layer(state)
        value = self.value_layer(features)
        advantage = self.advantage_layer(features)
        # Centering the advantages keeps V and A identifiable.
        return value + (advantage - advantage.mean())
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class DuelingDQNAgent:
    """DQN agent using a dueling network architecture.

    Unlike DQNAgent (which wraps both branches in one DQNet), this agent
    keeps two separate DuelingDQNet modules: a trainable ``online_net``
    and a frozen ``target_net`` that is periodically synced.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts  # min. experiences before training

        self.training_frequency = training_frequency  # no. of experiences between updates to Q_online
        self.target_network_sync_frequency = target_network_sync_frequency  # no. of experiences between Q_target & Q_online sync

        self.save_every = save_frequency  # no. of experiences between saving the network
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.online_net = DuelingDQNet(self.state_dim, self.action_dim).float()
        self.target_net = copy.deepcopy(self.online_net)
        # Q_target parameters are frozen.
        for p in self.target_net.parameters():
            p.requires_grad = False

        if self.use_cuda:
            # Fix: previously `self.online_net(device='cuda')` *called* the
            # module (raising at runtime) instead of moving it to the GPU.
            self.online_net = self.online_net.to(device='cuda')
            self.target_net = self.target_net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        self.optimizer = torch.optim.AdamW(self.online_net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()
        # self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=learning_rate)
        # self.loss_fn = torch.nn.MSELoss()

    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update value of step.

        Inputs:
            state(LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
            action_idx (int): An integer representing which action the agent will perform
        """
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)

        # EXPLOIT
        else:
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            state = state.unsqueeze(0)  # add batch dimension
            action_values = self.online_net(state)
            action_idx = torch.argmax(action_values, axis=1).item()

        # Decay exploration rate toward its floor.
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        """
        Store the experience to self.memory (replay buffer)

        Inputs:
            state (LazyFrame),
            next_state (LazyFrame),
            action (int),
            reward (float),
            done(bool))
        """
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])

        self.memory.append((state, next_state, action, reward, done,))

    def recall(self):
        """
        Retrieve a uniformly sampled batch of experiences from memory.
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    def td_estimate(self, states, actions):
        """Q_online(s, a) for each (state, action) pair; shape (batch, 1)."""
        actions = actions.reshape(-1, 1)
        predicted_qs = self.online_net(states)
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        """r + gamma * max_a Q_target(s', a), with terminal next-states zeroed."""
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        target_qs = self.target_net.forward(next_states)
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0  # no bootstrap beyond episode end
        return (rewards + (self.gamma * target_qs))

    def update_Q_online(self, td_estimate, td_target):
        """One gradient step of the online network; returns the scalar loss."""
        loss = self.loss_fn(td_estimate.float(), td_target.float())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        """Copy online-network weights into the frozen target network."""
        self.target_net.load_state_dict(self.online_net.state_dict())

    def learn(self):
        """Run one training iteration if one is due at the current step.

        Returns:
            (mean TD estimate, loss) after an update, else (None, None).
        """
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        state, next_state, action, reward, done = self.recall()
        td_est = self.td_estimate(state, action)
        td_tgt = self.td_target(reward, next_state, done)
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

    def save(self):
        """Checkpoint online weights, exploration rate and replay buffer."""
        save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.online_net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )

        print(f"Airstriker model saved to {save_path} at step {self.curr_step}")

    def load(self, load_path, reset_exploration_rate, load_replay_buffer):
        """Restore a checkpoint produced by save().

        Raises:
            ValueError: if ``load_path`` does not exist.
        """
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')

        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.online_net.load_state_dict(state_dict)
        self.target_net = copy.deepcopy(self.online_net)
        self.sync_Q_target()

        if load_replay_buffer:
            replay_memory = ckp.get('replay_memory')
            print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
            self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            # Fix: the old message said the rate was "not loaded" while the
            # code restores it; the message now matches the behavior.
            print(f"Restoring saved exploration rate {exploration_rate}.")
            self.exploration_rate = exploration_rate
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
class DuelingDDQNAgent(DuelingDQNAgent):
    """Dueling Double-DQN: greedy next action chosen by the online
    network, evaluated by the target network."""

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)

        # Action selection with the online network...
        greedy_actions = torch.argmax(self.online_net.forward(next_states), axis=1)
        greedy_actions = greedy_actions.reshape(-1, 1)

        # ...action evaluation with the target network.
        next_qs = self.target_net.forward(next_states)
        next_qs = next_qs.gather(1, greedy_actions).reshape(-1, 1)
        next_qs[dones] = 0.0  # no bootstrap beyond episode end
        return rewards + (self.gamma * next_qs)
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
class DQNAgentWithStepDecay:
    """DQN agent whose TD target caps the bootstrapped value by the
    episode's remaining-step budget.

    Experiences additionally store the step number at which they occurred;
    ``td_target`` clamps ``gamma * max Q_target`` by ``(200 - step) / 200``
    (assumes episodes cap at 200 steps -- TODO confirm against the env).
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts  # min. experiences before training

        self.training_frequency = training_frequency  # no. of experiences between updates to Q_online
        self.target_network_sync_frequency = target_network_sync_frequency  # no. of experiences between Q_target & Q_online sync

        self.save_every = save_frequency  # no. of experiences between saving the network
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.net = DQNet(self.state_dim, self.action_dim).float()
        if self.use_cuda:
            self.net = self.net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()
        # self.optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate)
        # self.loss_fn = torch.nn.MSELoss()

    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update value of step.

        Inputs:
            state(LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
            action_idx (int): An integer representing which action the agent will perform
        """
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)

        # EXPLOIT
        else:
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            state = state.unsqueeze(0)  # add batch dimension
            action_values = self.net(state, model='online')
            action_idx = torch.argmax(action_values, axis=1).item()

        # Decay exploration rate toward its floor.
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done, stepnumber):
        """
        Store the experience to self.memory (replay buffer)

        Inputs:
            state (LazyFrame),
            next_state (LazyFrame),
            action (int),
            reward (float),
            done (bool),
            stepnumber (int): step index within the episode
        """
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
        stepnumber = torch.LongTensor([stepnumber]).cuda() if self.use_cuda else torch.LongTensor([stepnumber])

        self.memory.append((state, next_state, action, reward, done, stepnumber))

    def recall(self):
        """
        Retrieve a uniformly sampled batch of experiences from memory.
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done, stepnumber = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze(), stepnumber.squeeze()

    def td_estimate(self, states, actions):
        """Q_online(s, a) for each (state, action) pair; shape (batch, 1)."""
        actions = actions.reshape(-1, 1)
        predicted_qs = self.net(states, model='online')
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones, stepnumbers):
        """r + min((200 - step)/200, gamma * max_a Q_target(s', a))."""
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        stepnumbers = stepnumbers.reshape(-1, 1)
        target_qs = self.net(next_states, model='target')
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0  # no bootstrap beyond episode end
        # Remaining-step budget (assumes a 200-step episode cap -- TODO confirm).
        discount = ((200 - stepnumbers) / 200)
        # Fix: torch.minimum instead of np.minimum -- np.minimum silently
        # converts CPU tensors to NumPy arrays and raises on CUDA tensors.
        val = torch.minimum(discount, self.gamma * target_qs)
        return (rewards + val)

    def update_Q_online(self, td_estimate, td_target):
        """One gradient step of the online network; returns the scalar loss."""
        loss = self.loss_fn(td_estimate.float(), td_target.float())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        """Copy online-network weights into the frozen target network."""
        self.net.target.load_state_dict(self.net.online.state_dict())

    def learn(self):
        """Run one training iteration if one is due at the current step.

        Returns:
            (mean TD estimate, loss) after an update, else (None, None).
        """
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        state, next_state, action, reward, done, stepnumber = self.recall()
        td_est = self.td_estimate(state, action)
        td_tgt = self.td_target(reward, next_state, done, stepnumber)
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

    def save(self):
        """Checkpoint model weights, exploration rate and replay buffer."""
        save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )

        print(f"Airstriker model saved to {save_path} at step {self.curr_step}")

    def load(self, load_path, reset_exploration_rate, load_replay_buffer):
        """Restore a checkpoint produced by save().

        Raises:
            ValueError: if ``load_path`` does not exist.
        """
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')

        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.net.load_state_dict(state_dict)

        if load_replay_buffer:
            replay_memory = ckp.get('replay_memory')
            print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
            self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            # Fix: the old message said the rate was "not loaded" while the
            # code restores it; the message now matches the behavior.
            print(f"Restoring saved exploration rate {exploration_rate}.")
            self.exploration_rate = exploration_rate
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
class DDQNAgentWithStepDecay(DQNAgentWithStepDecay):
    """Double-DQN variant of the step-decay agent: the online network
    picks the greedy next action, the target network evaluates it, and
    the bootstrapped value is capped by the remaining-step budget."""

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones, stepnumbers):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        stepnumbers = stepnumbers.reshape(-1, 1)

        # Action selection with the online network...
        q_vals = self.net(next_states, model='online')
        target_actions = torch.argmax(q_vals, axis=1)
        target_actions = target_actions.reshape(-1, 1)

        # ...action evaluation with the target network.
        target_qs = self.net(next_states, model='target')
        target_qs = target_qs.gather(1, target_actions)
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0  # no bootstrap beyond episode end

        # Remaining-step budget (assumes a 200-step episode cap -- TODO confirm).
        discount = ((200 - stepnumbers) / 200)
        # Fix: torch.minimum instead of np.minimum -- np.minimum silently
        # converts CPU tensors to NumPy arrays and raises on CUDA tensors.
        val = torch.minimum(discount, self.gamma * target_qs)
        return (rewards + val)
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
class DuelingDQNAgentWithStepDecay:
    """Dueling-DQN agent whose TD target is capped by a per-step decaying bound.

    Maintains separate online/target dueling networks, an epsilon-greedy policy
    with multiplicative decay, and a deque replay buffer.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts  # min. experiences before training

        self.training_frequency = training_frequency  # no. of experiences between updates to Q_online
        self.target_network_sync_frequency = target_network_sync_frequency  # no. of experiences between Q_target & Q_online sync

        self.save_every = save_frequency  # no. of experiences between saving the network
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.online_net = DuelingDQNet(self.state_dim, self.action_dim).float()
        self.target_net = copy.deepcopy(self.online_net)
        # Q_target parameters are frozen.
        for p in self.target_net.parameters():
            p.requires_grad = False

        if self.use_cuda:
            # BUG FIX: the original called the modules — self.online_net(device='cuda')
            # — which runs a forward pass (and crashes); .to(...) moves them instead.
            self.online_net = self.online_net.to(device='cuda')
            self.target_net = self.target_net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        self.optimizer = torch.optim.AdamW(self.online_net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()
        # self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=learning_rate)
        # self.loss_fn = torch.nn.MSELoss()

    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update value of step.

        Inputs:
        state(LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
        action_idx (int): An integer representing which action the agent will perform
        """
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)
        # EXPLOIT
        else:
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            state = state.unsqueeze(0)
            action_values = self.online_net(state)
            action_idx = torch.argmax(action_values, axis=1).item()

        # decrease exploration_rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        # increment step
        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done, stepnumber):
        """
        Store the experience to self.memory (replay buffer)

        Inputs:
        state (LazyFrame),
        next_state (LazyFrame),
        action (int),
        reward (float),
        done(bool))
        """
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
        stepnumber = torch.LongTensor([stepnumber]).cuda() if self.use_cuda else torch.LongTensor([stepnumber])

        self.memory.append((state, next_state, action, reward, done, stepnumber))

    def recall(self):
        """
        Retrieve a batch of experiences from memory
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done, stepnumber = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze(), stepnumber.squeeze()

    def td_estimate(self, states, actions):
        """Predicted Q(s, a) from the online network for the taken actions."""
        actions = actions.reshape(-1, 1)
        predicted_qs = self.online_net(states)  # Q_online(s,a)
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones, stepnumbers):
        """Max-Q TD target with a step-based cap on the bootstrap term."""
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        stepnumbers = stepnumbers.reshape(-1, 1)
        target_qs = self.target_net.forward(next_states)
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        # Terminal states bootstrap nothing.
        target_qs[dones] = 0.0
        # Cap the bootstrapped value by a bound decaying over an assumed
        # 200-step horizon (TODO confirm horizon with env).
        # BUG FIX: torch.minimum instead of np.minimum (np.minimum raises for
        # CUDA tensors).
        discount = (200 - stepnumbers) / 200
        val = torch.minimum(discount, self.gamma * target_qs)
        return rewards + val

    def update_Q_online(self, td_estimate, td_target):
        """One gradient step on the Huber loss between estimate and target; returns the scalar loss."""
        loss = self.loss_fn(td_estimate.float(), td_target.float())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        """Copy online-network weights into the frozen target network."""
        self.target_net.load_state_dict(self.online_net.state_dict())

    def learn(self):
        """Run one learning iteration; returns (mean TD estimate, loss) or (None, None) when skipped."""
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        # Warm-up: do not train until enough experience has accumulated.
        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        # Sample from memory
        state, next_state, action, reward, done, stepnumbers = self.recall()

        # Get TD Estimate
        td_est = self.td_estimate(state, action)

        # Get TD Target
        td_tgt = self.td_target(reward, next_state, done, stepnumbers)

        # Backpropagate loss through Q_online
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

    def save(self):
        """Checkpoint the online network, exploration rate, and replay memory."""
        save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.online_net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )
        print(f"Airstriker model saved to {save_path} at step {self.curr_step}")

    def load(self, load_path, reset_exploration_rate, load_replay_buffer):
        """Restore online/target networks (and optionally replay memory / exploration rate)."""
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')

        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.online_net.load_state_dict(state_dict)
        self.target_net = copy.deepcopy(self.online_net)
        self.sync_Q_target()

        if load_replay_buffer:
            replay_memory = ckp.get('replay_memory')
            print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
            self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            # BUG FIX: original message said the rate was "not loaded" even
            # though this branch restores it.
            print(f"Restoring saved exploration rate {exploration_rate}.")
            self.exploration_rate = exploration_rate
|
| 1086 |
+
|
| 1087 |
+
|
| 1088 |
+
class DuelingDDQNAgentWithStepDecay(DuelingDQNAgentWithStepDecay):
    """Dueling Double-DQN: the online net selects the greedy action, the target net evaluates it."""

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones, stepnumbers):
        """Compute the Double-DQN TD target with a step-based cap on the bootstrap term."""
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        stepnumbers = stepnumbers.reshape(-1, 1)

        # Action selection via the online network.
        q_vals = self.online_net.forward(next_states)
        target_actions = torch.argmax(q_vals, axis=1).reshape(-1, 1)

        # Action evaluation via the target network.
        target_qs = self.target_net.forward(next_states)
        target_qs = target_qs.gather(1, target_actions).reshape(-1, 1)
        # Terminal states bootstrap nothing.
        target_qs[dones] = 0.0

        # Cap the bootstrapped value by a bound decaying over an assumed
        # 200-step horizon (TODO confirm horizon with env).
        # BUG FIX: torch.minimum instead of np.minimum (np.minimum raises for
        # CUDA tensors).
        discount = (200 - stepnumbers) / 200
        val = torch.minimum(discount, self.gamma * target_qs)
        return rewards + val
|
src/lunar-lander/params.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Shared training hyper-parameters for the Lunar Lander agents.
hyperparams = {
    "batch_size": 128,
    "exploration_rate": 1,
    "exploration_rate_decay": 0.99999,
    "exploration_rate_min": 0.01,
    "training_frequency": 1,
    "target_network_sync_frequency": 20,
    "max_memory_size": 1000000,
    "learning_rate": 0.001,
    "learning_starts": 128,
    "save_frequency": 100000,
}
|
src/lunar-lander/replay.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from agent import DQNAgent, DDQNAgent, MetricLogger
|
| 4 |
+
from wrappers import make_lunar
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
env = make_lunar()
|
| 8 |
+
|
| 9 |
+
env.reset()
|
| 10 |
+
|
| 11 |
+
save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
| 12 |
+
save_dir.mkdir(parents=True)
|
| 13 |
+
|
| 14 |
+
# checkpoint = Path('checkpoints/lunar-lander-dueling-ddqn/airstriker_net_2.chkpt')
|
| 15 |
+
checkpoint = Path('checkpoints/lunar-lander-dqn-rc/airstriker_net_1.chkpt')
|
| 16 |
+
|
| 17 |
+
logger = MetricLogger(save_dir)
|
| 18 |
+
|
| 19 |
+
print("Testing Double DQN Agent!")
|
| 20 |
+
agent = DDQNAgent(
|
| 21 |
+
state_dim=8,
|
| 22 |
+
action_dim=env.action_space.n,
|
| 23 |
+
save_dir=save_dir,
|
| 24 |
+
batch_size=512,
|
| 25 |
+
checkpoint=checkpoint,
|
| 26 |
+
exploration_rate_decay=0.999995,
|
| 27 |
+
exploration_rate_min=0.05,
|
| 28 |
+
training_frequency=1,
|
| 29 |
+
target_network_sync_frequency=200,
|
| 30 |
+
max_memory_size=50000,
|
| 31 |
+
learning_rate=0.0005,
|
| 32 |
+
load_replay_buffer=False
|
| 33 |
+
|
| 34 |
+
)
|
| 35 |
+
agent.exploration_rate = agent.exploration_rate_min
|
| 36 |
+
|
| 37 |
+
episodes = 100
|
| 38 |
+
|
| 39 |
+
for e in range(episodes):
|
| 40 |
+
|
| 41 |
+
state = env.reset()
|
| 42 |
+
|
| 43 |
+
while True:
|
| 44 |
+
|
| 45 |
+
env.render()
|
| 46 |
+
|
| 47 |
+
action = agent.act(state)
|
| 48 |
+
|
| 49 |
+
next_state, reward, done, info = env.step(action)
|
| 50 |
+
|
| 51 |
+
# agent.cache(state, next_state, action, reward, done)
|
| 52 |
+
|
| 53 |
+
# logger.log_step(reward, None, None)
|
| 54 |
+
|
| 55 |
+
state = next_state
|
| 56 |
+
|
| 57 |
+
if done:
|
| 58 |
+
break
|
| 59 |
+
|
| 60 |
+
# logger.log_episode()
|
| 61 |
+
|
| 62 |
+
# if e % 20 == 0:
|
| 63 |
+
# logger.record(
|
| 64 |
+
# episode=e,
|
| 65 |
+
# epsilon=agent.exploration_rate,
|
| 66 |
+
# step=agent.curr_step
|
| 67 |
+
# )
|
src/lunar-lander/run-lunar-ddqn.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Entry script: train a Double-DQN (step-decay) agent on Lunar Lander."""
import os
from pathlib import Path

import torch

from agent import DDQNAgent, DDQNAgentWithStepDecay, MetricLogger
from wrappers import make_lunar
from train import train, fill_memory
from params import hyperparams

env = make_lunar()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

path = "checkpoints/lunar-lander-ddqn-rc"
save_dir = Path(path)

# exist_ok avoids the racy exists()-then-makedirs() check of the original.
os.makedirs(path, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training DDQN Agent!")
agent = DDQNAgentWithStepDecay(
    state_dim=8,
    action_dim=env.action_space.n,
    save_dir=save_dir,
    checkpoint=checkpoint,
    **hyperparams
)
# Swap in DDQNAgent (imported above) for the no-step-decay baseline.

# fill_memory(agent, env, 5000)
train(agent, env, logger)
|
src/lunar-lander/run-lunar-dqn.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from agent import DQNAgent, DQNAgentWithStepDecay, MetricLogger
|
| 6 |
+
from wrappers import make_lunar
|
| 7 |
+
import os
|
| 8 |
+
from train import train, fill_memory
|
| 9 |
+
from params import hyperparams
|
| 10 |
+
|
| 11 |
+
env = make_lunar()
|
| 12 |
+
|
| 13 |
+
use_cuda = torch.cuda.is_available()
|
| 14 |
+
print(f"Using CUDA: {use_cuda}\n")
|
| 15 |
+
|
| 16 |
+
checkpoint = None
|
| 17 |
+
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
|
| 18 |
+
|
| 19 |
+
path = "checkpoints/lunar-lander-dqn-rc"
|
| 20 |
+
save_dir = Path(path)
|
| 21 |
+
|
| 22 |
+
isExist = os.path.exists(path)
|
| 23 |
+
if not isExist:
|
| 24 |
+
os.makedirs(path)
|
| 25 |
+
|
| 26 |
+
logger = MetricLogger(save_dir)
|
| 27 |
+
|
| 28 |
+
print("Training Vanilla DQN Agent with decay!")
|
| 29 |
+
agent = DQNAgentWithStepDecay(
|
| 30 |
+
state_dim=8,
|
| 31 |
+
action_dim=env.action_space.n,
|
| 32 |
+
save_dir=save_dir,
|
| 33 |
+
checkpoint=checkpoint,
|
| 34 |
+
**hyperparams
|
| 35 |
+
)
|
| 36 |
+
# print("Training Vanilla DQN Agent!")
|
| 37 |
+
# agent = DQNAgent(
|
| 38 |
+
# state_dim=8,
|
| 39 |
+
# action_dim=env.action_space.n,
|
| 40 |
+
# save_dir=save_dir,
|
| 41 |
+
# checkpoint=checkpoint,
|
| 42 |
+
# **hyperparams
|
| 43 |
+
# )
|
| 44 |
+
|
| 45 |
+
# fill_memory(agent, env, 5000)
|
| 46 |
+
train(agent, env, logger)
|
src/lunar-lander/run-lunar-dueling-ddqn.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from agent import DuelingDDQNAgent, DuelingDDQNAgentWithStepDecay,MetricLogger
|
| 6 |
+
from wrappers import make_lunar
|
| 7 |
+
import os
|
| 8 |
+
from train import train, fill_memory
|
| 9 |
+
from params import hyperparams
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
env = make_lunar()
|
| 13 |
+
|
| 14 |
+
use_cuda = torch.cuda.is_available()
|
| 15 |
+
print(f"Using CUDA: {use_cuda}\n")
|
| 16 |
+
|
| 17 |
+
checkpoint = None
|
| 18 |
+
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
|
| 19 |
+
|
| 20 |
+
path = "checkpoints/lunar-lander-dueling-ddqn-rc"
|
| 21 |
+
save_dir = Path(path)
|
| 22 |
+
|
| 23 |
+
isExist = os.path.exists(path)
|
| 24 |
+
if not isExist:
|
| 25 |
+
os.makedirs(path)
|
| 26 |
+
|
| 27 |
+
logger = MetricLogger(save_dir)
|
| 28 |
+
|
| 29 |
+
print("Training Dueling DDQN Agent with step decay!")
|
| 30 |
+
agent = DuelingDDQNAgentWithStepDecay(
|
| 31 |
+
state_dim=8,
|
| 32 |
+
action_dim=env.action_space.n,
|
| 33 |
+
save_dir=save_dir,
|
| 34 |
+
checkpoint=checkpoint,
|
| 35 |
+
**hyperparams
|
| 36 |
+
)
|
| 37 |
+
# print("Training Dueling DDQN Agent!")
|
| 38 |
+
# agent = DuelingDDQNAgent(
|
| 39 |
+
# state_dim=8,
|
| 40 |
+
# action_dim=env.action_space.n,
|
| 41 |
+
# save_dir=save_dir,
|
| 42 |
+
# checkpoint=checkpoint,
|
| 43 |
+
# **hyperparams
|
| 44 |
+
# )
|
| 45 |
+
|
| 46 |
+
# fill_memory(agent, env, 5000)
|
| 47 |
+
train(agent, env, logger)
|
src/lunar-lander/run-lunar-dueling-dqn.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from agent import DuelingDQNAgent, DuelingDQNAgentWithStepDecay, MetricLogger
|
| 6 |
+
from wrappers import make_lunar
|
| 7 |
+
import os
|
| 8 |
+
from train import train, fill_memory
|
| 9 |
+
from params import hyperparams
|
| 10 |
+
|
| 11 |
+
env = make_lunar()
|
| 12 |
+
|
| 13 |
+
use_cuda = torch.cuda.is_available()
|
| 14 |
+
print(f"Using CUDA: {use_cuda}\n")
|
| 15 |
+
|
| 16 |
+
checkpoint = None
|
| 17 |
+
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
|
| 18 |
+
|
| 19 |
+
path = "checkpoints/lunar-lander-dueling-dqn-rc"
|
| 20 |
+
save_dir = Path(path)
|
| 21 |
+
|
| 22 |
+
isExist = os.path.exists(path)
|
| 23 |
+
if not isExist:
|
| 24 |
+
os.makedirs(path)
|
| 25 |
+
|
| 26 |
+
logger = MetricLogger(save_dir)
|
| 27 |
+
|
| 28 |
+
print("Training Dueling DQN Agent with step decay!")
|
| 29 |
+
agent = DuelingDQNAgentWithStepDecay(
|
| 30 |
+
state_dim=8,
|
| 31 |
+
action_dim=env.action_space.n,
|
| 32 |
+
save_dir=save_dir,
|
| 33 |
+
checkpoint=checkpoint,
|
| 34 |
+
**hyperparams
|
| 35 |
+
)
|
| 36 |
+
# print("Training Dueling DQN Agent!")
|
| 37 |
+
# agent = DuelingDQNAgent(
|
| 38 |
+
# state_dim=8,
|
| 39 |
+
# action_dim=env.action_space.n,
|
| 40 |
+
# save_dir=save_dir,
|
| 41 |
+
# checkpoint=checkpoint,
|
| 42 |
+
# **hyperparams
|
| 43 |
+
# )
|
| 44 |
+
|
| 45 |
+
# fill_memory(agent, env, 5000)
|
| 46 |
+
train(agent, env, logger)
|
src/lunar-lander/train.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import trange
|
| 2 |
+
|
| 3 |
+
def fill_memory(agent, env, num_episodes=500):
    """Pre-populate the agent's replay buffer by running `num_episodes` full episodes.

    BUG FIX: the original looped over a hard-coded `trange(500)`, silently
    ignoring the `num_episodes` argument.

    Args:
        agent: object exposing `act(state)` and `cache(state, next_state, action, reward, done)`.
        env: gym-style environment with `reset()` and `step(action)`.
        num_episodes: number of episodes of experience to collect.
    """
    print("Filling up memory....")
    for _ in trange(num_episodes):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.cache(state, next_state, action, reward, done)
            state = next_state
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# def train(agent, env, logger):
|
| 16 |
+
# episodes = 5000
|
| 17 |
+
# for e in range(episodes):
|
| 18 |
+
|
| 19 |
+
# state = env.reset()
|
| 20 |
+
# # Play the game!
|
| 21 |
+
# while True:
|
| 22 |
+
|
| 23 |
+
# # Run agent on the state
|
| 24 |
+
# action = agent.act(state)
|
| 25 |
+
|
| 26 |
+
# # Agent performs action
|
| 27 |
+
# next_state, reward, done, info = env.step(action)
|
| 28 |
+
|
| 29 |
+
# # Remember
|
| 30 |
+
# agent.cache(state, next_state, action, reward, done)
|
| 31 |
+
|
| 32 |
+
# # Learn
|
| 33 |
+
# q, loss = agent.learn()
|
| 34 |
+
|
| 35 |
+
# # Logging
|
| 36 |
+
# logger.log_step(reward, loss, q)
|
| 37 |
+
|
| 38 |
+
# # Update state
|
| 39 |
+
# state = next_state
|
| 40 |
+
|
| 41 |
+
# # Check if end of game
|
| 42 |
+
# if done:
|
| 43 |
+
# break
|
| 44 |
+
|
| 45 |
+
# logger.log_episode(e)
|
| 46 |
+
|
| 47 |
+
# if e % 20 == 0:
|
| 48 |
+
# logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def train(agent, env, logger, episodes=5000):
    """Train `agent` on `env`, logging every step and every episode.

    Generalized: the episode count, previously hard-coded to 5000, is now a
    backward-compatible `episodes` parameter (default unchanged).

    Args:
        agent: exposes `act`, `cache(state, next_state, action, reward, done, step_idx)`,
            `learn() -> (q, loss)`, plus `exploration_rate` / `curr_step` attributes.
        env: gym-style environment (`reset`, `render`, `step`).
        logger: exposes `log_step`, `log_episode`, `record`.
        episodes: number of training episodes to run.
    """
    for e in range(episodes):
        state = env.reset()
        # Each episode is capped at 1000 steps; the in-episode step index is
        # passed to cache() so step-decay agents can bound the bootstrap term.
        for i in range(1000):
            # Run agent on the state
            action = agent.act(state)
            env.render()  # NOTE(review): rendering every training step is slow — confirm it is wanted
            # Agent performs action
            next_state, reward, done, info = env.step(action)

            # Remember
            agent.cache(state, next_state, action, reward, done, i)

            # Learn
            q, loss = agent.learn()

            # Logging
            logger.log_step(reward, loss, q)

            # Update state
            state = next_state

            # Check if end of game
            if done:
                break

        logger.log_episode(e)

        if e % 20 == 0:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
|
src/lunar-lander/wrappers.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
from collections import deque
|
| 4 |
+
import gym
|
| 5 |
+
from gym import spaces
|
| 6 |
+
import cv2
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
'''
|
| 10 |
+
Atari Wrapper copied from https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
|
| 11 |
+
'''
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LazyFrames(object):
    """Memory-saving view over a list of stacked observation frames.

    Frames shared between successive observations are stored once; the
    channel-concatenated array is materialized lazily on first access and
    cached. Convert to a numpy array only right before feeding the model.
    """

    def __init__(self, frames):
        self._frames = frames
        self._out = None

    def _force(self):
        # Materialize (and cache) the concatenated observation on demand,
        # releasing the frame list once the cache exists.
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        arr = self._force()
        return arr if dtype is None else arr.astype(dtype)

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]
|
| 41 |
+
|
| 42 |
+
class FireResetEnv(gym.Wrapper):
|
| 43 |
+
def __init__(self, env):
|
| 44 |
+
"""Take action on reset for environments that are fixed until firing."""
|
| 45 |
+
gym.Wrapper.__init__(self, env)
|
| 46 |
+
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
|
| 47 |
+
assert len(env.unwrapped.get_action_meanings()) >= 3
|
| 48 |
+
|
| 49 |
+
def reset(self, **kwargs):
|
| 50 |
+
self.env.reset(**kwargs)
|
| 51 |
+
obs, _, done, _ = self.env.step(1)
|
| 52 |
+
if done:
|
| 53 |
+
self.env.reset(**kwargs)
|
| 54 |
+
obs, _, done, _ = self.env.step(2)
|
| 55 |
+
if done:
|
| 56 |
+
self.env.reset(**kwargs)
|
| 57 |
+
return obs
|
| 58 |
+
|
| 59 |
+
def step(self, ac):
|
| 60 |
+
return self.env.step(ac)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class MaxAndSkipEnv(gym.Wrapper):
|
| 64 |
+
def __init__(self, env, skip=4):
|
| 65 |
+
"""Return only every `skip`-th frame"""
|
| 66 |
+
gym.Wrapper.__init__(self, env)
|
| 67 |
+
# most recent raw observations (for max pooling across time steps)
|
| 68 |
+
self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
|
| 69 |
+
self._skip = skip
|
| 70 |
+
|
| 71 |
+
def step(self, action):
|
| 72 |
+
"""Repeat action, sum reward, and max over last observations."""
|
| 73 |
+
total_reward = 0.0
|
| 74 |
+
done = None
|
| 75 |
+
for i in range(self._skip):
|
| 76 |
+
obs, reward, done, info = self.env.step(action)
|
| 77 |
+
if i == self._skip - 2: self._obs_buffer[0] = obs
|
| 78 |
+
if i == self._skip - 1: self._obs_buffer[1] = obs
|
| 79 |
+
total_reward += reward
|
| 80 |
+
if done:
|
| 81 |
+
break
|
| 82 |
+
# Note that the observation on the done=True frame
|
| 83 |
+
# doesn't matter
|
| 84 |
+
max_frame = self._obs_buffer.max(axis=0)
|
| 85 |
+
|
| 86 |
+
return max_frame, total_reward, done, info
|
| 87 |
+
|
| 88 |
+
def reset(self, **kwargs):
|
| 89 |
+
return self.env.reset(**kwargs)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class WarpFrame(gym.ObservationWrapper):
|
| 94 |
+
def __init__(self, env):
|
| 95 |
+
"""Warp frames to 84x84 as done in the Nature paper and later work."""
|
| 96 |
+
gym.ObservationWrapper.__init__(self, env)
|
| 97 |
+
self.width = 84
|
| 98 |
+
self.height = 84
|
| 99 |
+
self.observation_space = spaces.Box(low=0, high=255,
|
| 100 |
+
shape=(self.height, self.width, 1), dtype=np.uint8)
|
| 101 |
+
|
| 102 |
+
def observation(self, frame):
|
| 103 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
| 104 |
+
frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
|
| 105 |
+
return frame[:, :, None]
|
| 106 |
+
|
| 107 |
+
class WarpFrameNoResize(gym.ObservationWrapper):
|
| 108 |
+
def __init__(self, env):
|
| 109 |
+
"""Warp frames to 84x84 as done in the Nature paper and later work."""
|
| 110 |
+
gym.ObservationWrapper.__init__(self, env)
|
| 111 |
+
|
| 112 |
+
def observation(self, frame):
|
| 113 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
| 114 |
+
# frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
|
| 115 |
+
return frame[:, :, None]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class FrameStack(gym.Wrapper):
|
| 120 |
+
def __init__(self, env, k):
|
| 121 |
+
"""Stack k last frames.
|
| 122 |
+
Returns lazy array, which is much more memory efficient.
|
| 123 |
+
See Also
|
| 124 |
+
--------
|
| 125 |
+
baselines.common.atari_wrappers.LazyFrames
|
| 126 |
+
"""
|
| 127 |
+
gym.Wrapper.__init__(self, env)
|
| 128 |
+
self.k = k
|
| 129 |
+
self.frames = deque([], maxlen=k)
|
| 130 |
+
shp = env.observation_space.shape
|
| 131 |
+
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
|
| 132 |
+
|
| 133 |
+
def reset(self):
|
| 134 |
+
ob = self.env.reset()
|
| 135 |
+
for _ in range(self.k):
|
| 136 |
+
self.frames.append(ob)
|
| 137 |
+
return self._get_ob()
|
| 138 |
+
|
| 139 |
+
def step(self, action):
|
| 140 |
+
ob, reward, done, info = self.env.step(action)
|
| 141 |
+
self.frames.append(ob)
|
| 142 |
+
return self._get_ob(), reward, done, info
|
| 143 |
+
|
| 144 |
+
def _get_ob(self):
|
| 145 |
+
assert len(self.frames) == self.k
|
| 146 |
+
return LazyFrames(list(self.frames))
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class ImageToPyTorch(gym.ObservationWrapper):
|
| 150 |
+
def __init__(self, env):
|
| 151 |
+
super(ImageToPyTorch, self).__init__(env)
|
| 152 |
+
old_shape = self.observation_space.shape
|
| 153 |
+
self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)
|
| 154 |
+
|
| 155 |
+
def observation(self, observation):
|
| 156 |
+
return np.moveaxis(observation, 2, 0)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class ScaledFloatFrame(gym.ObservationWrapper):
    """Rescale byte observations to float32 values in [0, 1]."""

    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=env.observation_space.shape,
            dtype=np.float32,
        )

    def observation(self, observation):
        # Careful! Materializing the frames here undoes the LazyFrames
        # memory optimization — use only with smaller replay buffers.
        scaled = np.array(observation).astype(np.float32)
        return scaled / 255.0
|
| 168 |
+
|
| 169 |
+
class ClipRewardEnv(gym.RewardWrapper):
    """Reward wrapper that keeps only the sign of each reward."""

    def __init__(self, env):
        super().__init__(env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)
|
| 176 |
+
|
| 177 |
+
class TanRewardClipperEnv(gym.RewardWrapper):
    """Reward wrapper that squashes rewards smoothly into (-10, 10) with a scaled tanh."""

    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Map reward to 10 * tanh(reward / 30): near-linear around 0, saturating at +/-10."""
        return 10 * math.tanh(float(reward)/30.)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def make_lunar(render=False):
    """Create the LunarLander-v2 environment.

    The ``render`` flag is currently unused. The frame-based wrappers are
    left commented out; they are not applied to this environment.
    """
    print("Environment: Lunar Lander")
    environment = gym.make("LunarLander-v2")
    # env = TanRewardClipperEnv(env)
    # env = WarpFrameNoResize(env) ## Reshape image
    # env = ImageToPyTorch(env) ## Invert shape
    # env = FrameStack(env, 4) ## Stack last 4 frames
    return environment
|
src/procgen/agent.py
ADDED
|
@@ -0,0 +1,664 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import copy
|
| 6 |
+
import time, datetime
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from collections import deque
|
| 9 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DQNet(nn.Module):
    """mini cnn structure
    input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output

    Holds two copies of the network: ``online`` (trained) and ``target``
    (frozen; synced from online periodically by the agent).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        print("#################################")
        print("#################################")
        print(input_dim)
        print(output_dim)
        print("#################################")
        print("#################################")
        channels, _height, _width = input_dim

        # Trainable Q-network.
        self.online = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            # NOTE(review): 7168 hard-codes the flattened conv size for one
            # specific input resolution — confirm it matches the wrappers.
            nn.Linear(7168, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        )

        # Frozen copy used for TD targets: no gradients ever flow into it.
        self.target = copy.deepcopy(self.online)
        for param in self.target.parameters():
            param.requires_grad = False

    def forward(self, input, model):
        # Dispatch to whichever copy the caller asked for.
        if model == "target":
            return self.target(input)
        if model == "online":
            return self.online(input)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class MetricLogger:
    """Tracks per-episode training metrics and reports them to the console,
    a fixed-width text log, TensorBoard, and matplotlib plot images.

    Parameters
    ----------
    save_dir : pathlib.Path
        Directory receiving the ``log`` file, TensorBoard events and plots.
    """

    def __init__(self, save_dir):
        self.writer = SummaryWriter(log_dir=save_dir)
        self.save_log = save_dir / "log"
        # Write the fixed-width header once; record() appends matching rows.
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = save_dir / "reward_plot.jpg"
        self.ep_lengths_plot = save_dir / "length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / "q_plot.jpg"

        # History metrics: one entry per finished episode.
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages, added for every call to record()
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Current episode metric accumulators
        self.init_episode()

        # Timing
        self.record_time = time.time()

    def log_step(self, reward, loss, q):
        """Accumulate one environment step.

        ``loss``/``q`` may be ``None`` (agent skipped a training update);
        those steps do not contribute to the per-episode loss/Q averages.
        """
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        if loss:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self, episode_number):
        """Mark end of episode: archive accumulators and emit per-episode scalars."""
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        if self.curr_ep_loss_length == 0:
            # No training update happened this episode (e.g. still warming up).
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)
        self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
        self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
        self.writer.flush()
        self.init_episode()

    def init_episode(self):
        """Reset the accumulators for a new episode."""
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        """Report 100-episode moving averages to every sink.

        Writes to stdout, the text log, TensorBoard, and refreshes the
        four metric plot images.
        """
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)

        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )
        # BUG FIX: the mean-reward scalar was previously emitted twice per
        # call (duplicate add_scalar); each scalar is now written exactly once.
        self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
        self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
        self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
        self.writer.add_scalar("Epsilon value", epsilon, episode)
        self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
        self.writer.flush()
        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
            plt.plot(getattr(self, f"moving_avg_{metric}"))
            plt.savefig(getattr(self, f"{metric}_plot"))
            plt.clf()
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class DQNAgent:
    """Vanilla DQN agent: epsilon-greedy policy, uniform replay buffer, and a
    frozen target network synced from the online network at a fixed cadence.

    The wrapped ``DQNet`` carries both the online and target copies; this
    class owns exploration, replay memory, the optimizer and checkpointing.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        # state_dim is the (channels, height, width) tuple fed to the CNN.
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        # Uniform replay buffer; oldest transitions fall off when full.
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        # Epsilon-greedy schedule: multiplicative decay per act() call,
        # floored at exploration_rate_min.
        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts  # min. experiences before training

        self.training_frequency = training_frequency  # no. of experiences between updates to Q_online
        self.target_network_sync_frequency = target_network_sync_frequency  # no. of experiences between Q_target & Q_online sync

        self.save_every = save_frequency  # no. of experiences between saving the network
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.net = DQNet(self.state_dim, self.action_dim).float()
        if self.use_cuda:
            self.net = self.net.to(device='cuda')
        # Optionally restore weights / epsilon / replay memory before the
        # optimizer is built.
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        # Only the online copy has requires_grad=True (DQNet freezes target).
        self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()


    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update value of step.

        Inputs:
        state(LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
        action_idx (int): An integer representing which action the agent will perform
        """
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)

        # EXPLOIT
        else:
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            # Add the batch dimension the CNN expects.
            state = state.unsqueeze(0)
            action_values = self.net(state, model='online')
            action_idx = torch.argmax(action_values, axis=1).item()

        # decrease exploration_rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        # increment step
        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        """
        Store the experience to self.memory (replay buffer)

        Inputs:
        state (LazyFrame),
        next_state (LazyFrame),
        action (int),
        reward (float),
        done(bool))
        """
        # NOTE(review): when CUDA is available every cached transition lives
        # on the GPU, so the whole replay buffer occupies GPU memory —
        # confirm max_memory_size fits the device.
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        # NOTE(review): rewards are stored as float64 (DoubleTensor) while the
        # network runs in float32; td_target therefore promotes targets to
        # float64 — confirm SmoothL1Loss accepts the mixed dtypes here.
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])

        self.memory.append( (state, next_state, action, reward, done,) )


    def recall(self):
        """
        Retrieve a batch of experiences from memory
        """
        # Uniform sampling without replacement within the batch.
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()


    def td_estimate(self, states, actions):
        """Return Q_online(s, a) for each (state, action) pair, shape (batch, 1)."""
        actions = actions.reshape(-1, 1)
        predicted_qs = self.net(states, model='online')# Q_online(s,a)
        # Select the Q-value of the action actually taken.
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs


    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        """Return r + gamma * max_a Q_target(s', a), with terminal states zeroed."""
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        target_qs = self.net(next_states, model='target')
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        # No bootstrap value beyond a terminal transition.
        target_qs[dones] = 0.0
        return (rewards + (self.gamma * target_qs))

    def update_Q_online(self, td_estimate, td_target) :
        """One gradient step of the online network; returns the scalar loss."""
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()


    def sync_Q_target(self):
        """Copy the online network weights into the frozen target network."""
        self.net.target.load_state_dict(self.net.online.state_dict())


    def learn(self):
        """Run one learning iteration, if due.

        Returns (mean TD estimate, loss) or (None, None) when no update
        happened (warm-up period or off-cadence step). Sync/save checks run
        every call, keyed on multiples of curr_step.
        """
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        # Sample from memory
        state, next_state, action, reward, done = self.recall()

        # Get TD Estimate
        td_est = self.td_estimate(state, action)

        # Get TD Target
        td_tgt = self.td_target(reward, next_state, done)

        # Backpropagate loss through Q_online
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)


    def save(self):
        """Checkpoint weights, epsilon and the replay buffer to save_dir.

        NOTE(review): the filename still says "airstriker" although this is
        the procgen agent — presumably copy-pasted; kept for checkpoint
        compatibility.
        """
        save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )

        print(f"Airstriker model saved to {save_path} at step {self.curr_step}")


    def load(self, load_path, reset_exploration_rate, load_replay_buffer):
        """Restore weights and, optionally, replay memory and epsilon.

        Raises ValueError if the checkpoint path does not exist.
        """
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        # Map tensors onto whatever device this process actually has.
        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')


        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.net.load_state_dict(state_dict)

        if load_replay_buffer:
            replay_memory = ckp.get('replay_memory')
            print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
            self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            print(f"Setting exploration rate to {exploration_rate} not loaded.")
            self.exploration_rate = exploration_rate
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class DDQNAgent(DQNAgent):
    """Double-DQN variant: greedy action chosen by the online network,
    but evaluated by the frozen target network (reduces overestimation)."""

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        """Return r + gamma * Q_target(s', argmax_a Q_online(s', a)), terminal states zeroed."""
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)

        # Action selection via the online network...
        online_qs = self.net(next_states, model='online')
        best_actions = torch.argmax(online_qs, axis=1).reshape(-1, 1)

        # ...action evaluation via the target network.
        next_qs = self.net(next_states, model='target')
        next_qs = next_qs.gather(1, best_actions).reshape(-1, 1)
        next_qs[dones] = 0.0
        return rewards + (self.gamma * next_qs)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class DuelingDQNet(nn.Module):
    """Dueling DQN network (Wang et al., 2016).

    A shared conv trunk feeds two heads: a scalar state-value V(s) and
    per-action advantages A(s, a). They are combined as
    Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        print("#################################")
        print("#################################")
        print(input_dim)
        print(output_dim)
        print("#################################")
        print("#################################")
        c, h, w = input_dim

        # Shared convolutional feature extractor.
        self.conv_layer = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # State-value head: V(s), one scalar per sample.
        # NOTE(review): 7168 hard-codes the flattened conv size for one
        # specific input resolution — confirm it matches the wrappers.
        self.value_layer = nn.Sequential(
            nn.Linear(7168, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

        # Advantage head: A(s, a), one value per action.
        self.advantage_layer = nn.Sequential(
            nn.Linear(7168, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, state):
        conv_output = self.conv_layer(state)
        conv_output = conv_output.view(conv_output.size(0), -1)
        value = self.value_layer(conv_output)
        advantage = self.advantage_layer(conv_output)
        # BUG FIX: subtract each sample's mean advantage over the ACTION
        # dimension, per the dueling architecture. The previous
        # advantage.mean() averaged over the entire batch, coupling the
        # Q-values of unrelated samples within a batch.
        q_value = value + (advantage - advantage.mean(dim=1, keepdim=True))

        return q_value
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class DuelingDQNAgent:
    """DQN agent built on the dueling architecture.

    Unlike ``DQNAgent`` (whose ``DQNet`` bundles both copies), this agent
    holds two separate ``DuelingDQNet`` instances: a trainable online
    network and a frozen target network synced at a fixed cadence.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        # state_dim is the (channels, height, width) tuple fed to the CNN.
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        # Uniform replay buffer; oldest transitions fall off when full.
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts  # min. experiences before training

        self.training_frequency = training_frequency  # no. of experiences between updates to Q_online
        self.target_network_sync_frequency = target_network_sync_frequency  # no. of experiences between Q_target & Q_online sync

        self.save_every = save_frequency  # no. of experiences between saving the network
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()


        self.online_net = DuelingDQNet(self.state_dim, self.action_dim).float()
        self.target_net = copy.deepcopy(self.online_net)
        # Q_target parameters are frozen.
        for p in self.target_net.parameters():
            p.requires_grad = False

        if self.use_cuda:
            # BUG FIX: this previously read ``self.online_net(device='cuda')``,
            # which CALLS the module (forward()) with a bogus keyword and
            # raises a TypeError on any CUDA machine. ``.to`` is what moves
            # the parameters to the GPU.
            self.online_net = self.online_net.to(device='cuda')
            self.target_net = self.target_net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        # Only the online network is optimized; the target is read-only.
        self.optimizer = torch.optim.AdamW(self.online_net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()


    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update value of step.

        Inputs:
        state(LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
        action_idx (int): An integer representing which action the agent will perform
        """
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)

        # EXPLOIT
        else:
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            # Add the batch dimension the CNN expects.
            state = state.unsqueeze(0)
            action_values = self.online_net(state)
            action_idx = torch.argmax(action_values, axis=1).item()

        # decrease exploration_rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        # increment step
        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        """
        Store the experience to self.memory (replay buffer)

        Inputs:
        state (LazyFrame),
        next_state (LazyFrame),
        action (int),
        reward (float),
        done(bool))
        """
        # NOTE(review): with CUDA available, cached tensors live on the GPU,
        # so the whole replay buffer occupies GPU memory.
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])

        self.memory.append( (state, next_state, action, reward, done,) )


    def recall(self):
        """
        Retrieve a batch of experiences from memory
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()


    def td_estimate(self, states, actions):
        """Return Q_online(s, a) for the action taken in each sample, shape (batch, 1)."""
        actions = actions.reshape(-1, 1)
        predicted_qs = self.online_net(states)  # Q_online(s,a)
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs


    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        """Return r + gamma * max_a Q_target(s', a), with terminal states zeroed."""
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        target_qs = self.target_net.forward(next_states)
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        # No bootstrap value beyond a terminal transition.
        target_qs[dones] = 0.0
        return (rewards + (self.gamma * target_qs))

    def update_Q_online(self, td_estimate, td_target):
        """One gradient step on the online network; returns the scalar loss."""
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()


    def sync_Q_target(self):
        """Copy the online network weights into the frozen target network."""
        self.target_net.load_state_dict(self.online_net.state_dict())


    def learn(self):
        """Run one learning iteration, if due.

        Returns (mean TD estimate, loss), or (None, None) during warm-up or
        on off-cadence steps. Sync/save checks run on every call.
        """
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        # Sample from memory
        state, next_state, action, reward, done = self.recall()

        # Get TD Estimate
        td_est = self.td_estimate(state, action)

        # Get TD Target
        td_tgt = self.td_target(reward, next_state, done)

        # Backpropagate loss through Q_online
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)


    def save(self):
        """Checkpoint online weights, epsilon and the replay buffer.

        NOTE(review): the "airstriker" filename/label is a holdover from the
        airstriker agent; kept for checkpoint-path compatibility.
        """
        save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.online_net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )

        print(f"Airstriker model saved to {save_path} at step {self.curr_step}")


    def load(self, load_path, reset_exploration_rate, load_replay_buffer):
        """Restore weights and, optionally, replay memory and epsilon.

        Raises ValueError if the checkpoint path does not exist.
        """
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')


        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.online_net.load_state_dict(state_dict)
        # BUG FIX: previously the target net was rebuilt with
        # copy.deepcopy(self.online_net), which also copied
        # requires_grad=True onto the target parameters, silently
        # un-freezing them. Syncing the state dict keeps the frozen flags
        # set in __init__.
        self.sync_Q_target()

        if load_replay_buffer:
            replay_memory = ckp.get('replay_memory')
            print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
            self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            # BUG FIX: previous message was garbled ("Setting exploration
            # rate to ... not loaded.") while the rate WAS being restored.
            print(f"Restoring saved exploration rate {exploration_rate}.")
            self.exploration_rate = exploration_rate
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
class DuelingDDQNAgent(DuelingDQNAgent):
    """Dueling double-DQN: greedy action from the online network, value
    estimate from the frozen target network."""

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        """Return r + gamma * Q_target(s', argmax_a Q_online(s', a)), terminal states zeroed."""
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)

        # Select the greedy next action under the online network...
        online_qs = self.online_net.forward(next_states)
        best_actions = torch.argmax(online_qs, axis=1).reshape(-1, 1)

        # ...then score that action with the target network.
        next_qs = self.target_net.forward(next_states)
        next_qs = next_qs.gather(1, best_actions).reshape(-1, 1)
        next_qs[dones] = 0.0
        return rewards + (self.gamma * next_qs)
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
|
src/procgen/run-starpilot-ddqn.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train a Double DQN agent on procgen Starpilot."""
import torch
from pathlib import Path

from agent import DDQNAgent, MetricLogger
from wrappers import make_starpilot
from train import train, fill_memory


# Grayscale, channel-first, 4-frame-stacked Starpilot environment.
env = make_starpilot()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

# Set to a Path(...) to resume training from a saved checkpoint.
checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

save_dir = Path("checkpoints/procgen-starpilot-ddqn")
# Replaces the former os.path.exists()/os.makedirs() dance (which also
# required a duplicate `import os`).
save_dir.mkdir(parents=True, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training DDQN Agent!")
agent = DDQNAgent(
    state_dim=(1, 64, 64),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=256,
    checkpoint=checkpoint,
    exploration_rate_decay=0.999995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=200,
    max_memory_size=50000,
    learning_rate=0.0005,
)

# Pre-populate the replay buffer before learning starts.
fill_memory(agent, env, 300)
train(agent, env, logger)
|
src/procgen/run-starpilot-dqn.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train a vanilla DQN agent on procgen Starpilot."""
import torch
from pathlib import Path

from agent import DQNAgent, MetricLogger
from wrappers import make_starpilot
from train import train, fill_memory


# Grayscale, channel-first, 4-frame-stacked Starpilot environment.
env = make_starpilot()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

# Set to a Path(...) to resume training from a saved checkpoint.
checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

save_dir = Path("checkpoints/procgen-starpilot-dqn")
# Replaces the former os.path.exists()/os.makedirs() dance (which also
# required a duplicate `import os`).
save_dir.mkdir(parents=True, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training Vanilla DQN Agent!")
agent = DQNAgent(
    state_dim=(1, 64, 64),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=256,
    checkpoint=checkpoint,
    exploration_rate_decay=0.999995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=200,
    max_memory_size=50000,
    learning_rate=0.0005,
)

# Pre-populate the replay buffer before learning starts.
fill_memory(agent, env, 300)
train(agent, env, logger)
|
src/procgen/run-starpilot-dueling-ddqn.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train a Dueling Double DQN agent on procgen Starpilot."""
import torch
from pathlib import Path

from agent import DuelingDDQNAgent, MetricLogger
from wrappers import make_starpilot
from train import train, fill_memory


# Grayscale, channel-first, 4-frame-stacked Starpilot environment.
env = make_starpilot()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

# Set to a Path(...) to resume training from a saved checkpoint.
checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

save_dir = Path("checkpoints/procgen-starpilot-dueling-ddqn")
# Replaces the former os.path.exists()/os.makedirs() dance (which also
# required a duplicate `import os`).
save_dir.mkdir(parents=True, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training Dueling Double DQN Agent!")
agent = DuelingDDQNAgent(
    state_dim=(1, 64, 64),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=256,
    checkpoint=checkpoint,
    exploration_rate_decay=0.999995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=200,
    max_memory_size=50000,
    learning_rate=0.0005,
)

# Memory pre-fill deliberately disabled for this run (kept for reference).
# fill_memory(agent, env, 300)
train(agent, env, logger)
|
src/procgen/run-starpilot-dueling-dqn.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train a Dueling DQN agent on procgen Starpilot."""
import torch
from pathlib import Path

from agent import DuelingDQNAgent, MetricLogger
from wrappers import make_starpilot
from train import train, fill_memory


# Grayscale, channel-first, 4-frame-stacked Starpilot environment.
env = make_starpilot()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

# Set to a Path(...) to resume training from a saved checkpoint.
checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

save_dir = Path("checkpoints/procgen-starpilot-dueling-dqn")
# Replaces the former os.path.exists()/os.makedirs() dance (which also
# required a duplicate `import os`).
save_dir.mkdir(parents=True, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training Dueling DQN Agent!")
agent = DuelingDQNAgent(
    state_dim=(1, 64, 64),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=256,
    checkpoint=checkpoint,
    exploration_rate_decay=0.999995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=200,
    max_memory_size=50000,
    learning_rate=0.0005,
)

# Memory pre-fill deliberately disabled for this run (kept for reference).
# fill_memory(agent, env, 300)
train(agent, env, logger)
|
src/procgen/test-procgen.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gym

# Smoke test: run a single random-policy episode on Starpilot and print
# the info dict, reward and done flag for every step.
env = gym.make("procgen:procgen-starpilot-v0")

obs = env.reset()
step = 0
done = False
while not done:
    obs, rew, done, info = env.step(env.action_space.sample())
    print(info)
    print(f"step {step} reward {rew} done {done}")
    step += 1
|
src/procgen/train.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tqdm import trange
|
| 2 |
+
|
| 3 |
+
def fill_memory(agent, env, num_episodes=500):
    """Seed the agent's replay buffer by rolling out full episodes.

    The agent's own policy (`agent.act`) drives the environment and every
    transition is cached, but no learning step is performed.
    """
    print("Filling up memory....")
    for _ in trange(num_episodes):
        obs = env.reset()
        episode_over = False
        while not episode_over:
            chosen_action = agent.act(obs)
            next_obs, step_reward, episode_over, _ = env.step(chosen_action)
            agent.cache(obs, next_obs, chosen_action, step_reward, episode_over)
            obs = next_obs
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def train(agent, env, logger, episodes=5000, record_every=20):
    """Run the act/cache/learn loop for a number of episodes.

    Generalized from the original hard-coded constants: `episodes` (was a
    fixed 5000) and `record_every` (was a fixed 20) are now keyword
    parameters whose defaults preserve the old behaviour.

    Args:
        agent: must provide act(state), cache(...), learn(), and the
            attributes exploration_rate and curr_step.
        env: gym-style env with reset() and step(action).
        logger: must provide log_step(reward, loss, q), log_episode(e)
            and record(episode=..., epsilon=..., step=...).
        episodes: number of training episodes to run.
        record_every: record aggregate metrics every this many episodes.
    """
    for e in range(episodes):
        state = env.reset()
        # Play one full episode.
        while True:
            # Run agent on the state.
            action = agent.act(state)

            # Agent performs action.
            next_state, reward, done, info = env.step(action)

            # Remember the transition.
            agent.cache(state, next_state, action, reward, done)

            # Learn from a sampled batch.
            q, loss = agent.learn()

            # Per-step logging.
            logger.log_step(reward, loss, q)

            # Advance.
            state = next_state

            if done:
                break

        logger.log_episode(e)

        if e % record_every == 0:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
|
src/procgen/wrappers.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
from collections import deque
|
| 4 |
+
import gym
|
| 5 |
+
from gym import spaces
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
'''
|
| 10 |
+
Atari Wrapper copied from https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
|
| 11 |
+
'''
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LazyFrames(object):
    """Memory-saving view over a list of stacked frames.

    Frames shared between consecutive observations are stored once; the
    concatenated array (along axis 2) is only materialized on demand, when
    the object is converted to a numpy array or indexed.
    """

    def __init__(self, frames):
        self._frames = frames
        self._out = None

    def _force(self):
        # Materialize and cache the concatenation on first use, then drop
        # the frame list so the shared references can be released.
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        arr = self._force()
        return arr if dtype is None else arr.astype(dtype)

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]
|
| 41 |
+
|
| 42 |
+
class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        # NOTE(review): relies on Atari-style get_action_meanings(); kept from
        # the Atari wrapper copy — presumably unused for procgen envs. Confirm.
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs):
        # Press FIRE (action 1) then action 2 after every reset; if either
        # step immediately ends the episode, reset again before continuing.
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        # Pass-through: only reset behaviour is modified by this wrapper.
        return self.env.step(ac)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            # Keep only the last two raw frames; the pixel-wise max below
            # removes flicker from emulators that draw sprites on alternate frames.
            if i == self._skip - 2: self._obs_buffer[0] = obs
            if i == self._skip - 1: self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        # Skipping only affects step(); reset is a plain pass-through.
        return self.env.reset(**kwargs)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.height, self.width, 1),
            dtype=np.uint8,
        )

    def observation(self, frame):
        # Grayscale, then area-interpolated downscale to 84x84.
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(gray, (self.width, self.height), interpolation=cv2.INTER_AREA)
        # Restore the channel axis dropped by the grayscale conversion (HxWx1).
        return resized[:, :, None]
|
| 106 |
+
|
| 107 |
+
class WarpFrameNoResize(gym.ObservationWrapper):
    def __init__(self, env):
        """Convert frames to single-channel grayscale WITHOUT resizing.

        Unlike WarpFrame, the 84x84 resize is disabled — presumably because
        procgen frames are already 64x64 (see state_dim in the run scripts).
        """
        gym.ObservationWrapper.__init__(self, env)

    def observation(self, frame):
        # Grayscale conversion only; the Atari-era resize is kept commented
        # out for reference.
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.
        Returns lazy array, which is much more memory efficient.
        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        # Bounded deque: appending the (k+1)-th frame drops the oldest.
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        # NOTE(review): the space (and LazyFrames) stacks along axis 2 — this
        # assumes HxWxC observations; confirm ordering w.r.t. ImageToPyTorch.
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)

    def reset(self):
        # Duplicate the first frame k times so an observation is always full.
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        # Return a lazy view; conversion to a real array is deferred.
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class ImageToPyTorch(gym.ObservationWrapper):
    """Move the channel axis first (HWC -> CHW) for PyTorch convolutions."""
    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        # NOTE(review): the declared space is float32 in [0, 1], but
        # observation() only transposes axes — it neither rescales nor casts.
        # Confirm whether scaling is expected to happen downstream.
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)

    def observation(self, observation):
        # HWC -> CHW; pixel values pass through unchanged.
        return np.moveaxis(observation, 2, 0)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class ScaledFloatFrame(gym.ObservationWrapper):
    """Rescale uint8 pixel observations into float32 values in [0, 1]."""

    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=env.observation_space.shape,
            dtype=np.float32,
        )

    def observation(self, observation):
        # careful! This undoes the memory optimization, use
        # with smaller replay buffers only.
        as_float = np.asarray(observation).astype(np.float32)
        return as_float / 255.0
|
| 168 |
+
|
| 169 |
+
class ClipRewardEnv(gym.RewardWrapper):
    # Reward-clipping wrapper: keeps reward magnitudes comparable across games.
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def make_starpilot(render=False):
    """Build the procgen Starpilot env with the grayscale/stacking pipeline.

    Args:
        render: if True, create the env with human rendering enabled.

    Returns:
        The wrapped env: grayscale (no resize) -> channel-first -> 4-frame stack.
        NOTE(review): FrameStack is applied AFTER ImageToPyTorch, so its axis-2
        concatenation stacks frames along image width, not channels; the run
        scripts also pass state_dim=(1, 64, 64). Confirm the intended shape.
    """
    print("Environment: Starpilot")
    if render:
        env = gym.make("procgen:procgen-starpilot-v0", distribution_mode="easy", render_mode="human")
    else:
        env = gym.make("procgen:procgen-starpilot-v0", distribution_mode="easy")
    env = WarpFrameNoResize(env) ## Reshape image
    env = ImageToPyTorch(env) ## Invert shape
    env = FrameStack(env, 4) ## Stack last 4 frames
    return env
|
troubleshooting.md
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ml-reinforcement-learning
|
| 2 |
+
|
| 3 |
+
Python version: 3.7.3
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
Troubleshooting
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
- RuntimeError: Polyfit sanity test emitted a warning, most likely due to using a buggy Accelerate backend. If you compiled yourself, more information is available at https://numpy.org/doc/stable/user/building.html#accelerated-blas-lapack-libraries Otherwise report this to the vendor that provided NumPy.
|
| 10 |
+
RankWarning: Polyfit may be poorly conditioned
|
| 11 |
+
|
| 12 |
+
```
|
| 13 |
+
$ pip uninstall numpy
|
| 14 |
+
$ export OPENBLAS=$(brew --prefix openblas)
|
| 15 |
+
$ pip install --no-cache-dir numpy
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
During grpcio installation 👇
|
| 20 |
+
distutils.errors.CompileError: command 'clang' failed with exit status 1
|
| 21 |
+
```
|
| 22 |
+
CFLAGS="-I/Library/Developer/CommandLineTools/usr/include/c++/v1 -I/opt/homebrew/opt/openssl/include" LDFLAGS="-L/opt/homebrew/opt/openssl/lib" pip3 install grpcio
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
ModuleNotFoundError: No module named 'gym.envs.classic_control.rendering'
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Setup
|
| 30 |
+
|
| 31 |
+
```
|
| 32 |
+
conda install pytorch torchvision -c pytorch
|
| 33 |
+
pip install gym-retro
|
| 34 |
+
conda install numpy
|
| 35 |
+
pip install "gym[atari]==0.21.0"
|
| 36 |
+
pip install importlib-metadata==4.13.0
|
| 37 |
+
```
|