bkhan2000 committed on
Commit
cd47b9c
·
1 Parent(s): aebc38d

Upload . with huggingface_hub

Browse files
Files changed (6) hide show
  1. Policy_Gradient_PyTorch.ipynb +1395 -0
  2. README.md +27 -0
  3. hyperparameters.json +1 -0
  4. model.pt +3 -0
  5. replay.mp4 +0 -0
  6. results.json +1 -0
Policy_Gradient_PyTorch.ipynb ADDED
@@ -0,0 +1,1395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "text/plain": [
11
+ "<pyvirtualdisplay.display.Display at 0x7f6b781a3c70>"
12
+ ]
13
+ },
14
+ "execution_count": 1,
15
+ "metadata": {},
16
+ "output_type": "execute_result"
17
+ }
18
+ ],
19
+ "source": [
20
+ "# Virtual display\n",
21
+ "from pyvirtualdisplay import Display\n",
22
+ "\n",
23
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
24
+ "virtual_display.start()"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 2,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "import numpy as np\n",
34
+ "\n",
35
+ "from collections import deque\n",
36
+ "\n",
37
+ "import matplotlib.pyplot as plt\n",
38
+ "%matplotlib inline\n",
39
+ "\n",
40
+ "# PyTorch\n",
41
+ "import torch\n",
42
+ "import torch.nn as nn\n",
43
+ "import torch.nn.functional as F\n",
44
+ "import torch.optim as optim\n",
45
+ "from torch.distributions import Categorical\n",
46
+ "\n",
47
+ "# Gym\n",
48
+ "import gym\n",
49
+ "import gym_pygame\n",
50
+ "\n",
51
+ "# Hugging Face Hub\n",
52
+ "from huggingface_hub import notebook_login # To log to our Hugging Face account to be able to upload models to the Hub.\n",
53
+ "import imageio\n",
54
+ "# imageio: A library that will help us to generate a replay video"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 4,
60
+ "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "name": "stdout",
64
+ "output_type": "stream",
65
+ "text": [
66
+ "cuda:0\n"
67
+ ]
68
+ }
69
+ ],
70
+ "source": [
71
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
72
+ "print(device)"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "metadata": {},
78
+ "source": [
79
+ "### CartPole-v1"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": 6,
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "env_id = \"CartPole-v1\"\n",
89
+ "env = gym.make(env_id)\n",
90
+ "\n",
91
+ "# evaluation env\n",
92
+ "eval_env = gym.make(env_id)\n",
93
+ "\n",
94
+ "s_size = env.observation_space.shape[0]\n",
95
+ "a_size = env.action_space.n"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": 7,
101
+ "metadata": {},
102
+ "outputs": [
103
+ {
104
+ "name": "stdout",
105
+ "output_type": "stream",
106
+ "text": [
107
+ "_____OBSERVATION SPACE_____ \n",
108
+ "\n",
109
+ "The State Space is: 4\n",
110
+ "Sample observation [-2.6818509e+00 2.6710869e+38 -2.7456334e-01 4.6941264e+37]\n"
111
+ ]
112
+ }
113
+ ],
114
+ "source": [
115
+ "print(\"_____OBSERVATION SPACE_____ \\n\")\n",
116
+ "print(\"The State Space is: \", s_size)\n",
117
+ "print(\"Sample observation\", env.observation_space.sample()) # Get a random observation"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": 8,
123
+ "metadata": {},
124
+ "outputs": [
125
+ {
126
+ "name": "stdout",
127
+ "output_type": "stream",
128
+ "text": [
129
+ "\n",
130
+ " _____ACTION SPACE_____ \n",
131
+ "\n",
132
+ "The Action Space is: 2\n",
133
+ "Action Space Sample 0\n"
134
+ ]
135
+ }
136
+ ],
137
+ "source": [
138
+ "print(\"\\n _____ACTION SPACE_____ \\n\")\n",
139
+ "print(\"The Action Space is: \", a_size)\n",
140
+ "print(\"Action Space Sample\", env.action_space.sample()) # Take a random action"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "markdown",
145
+ "metadata": {},
146
+ "source": [
147
+ "### Reinforce Architecture"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 13,
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": [
156
+ "class Policy(nn.Module):\n",
157
+ "    \"\"\"Stochastic policy: one hidden ReLU layer followed by a softmax over actions.\"\"\"\n",
158
+ "    def __init__(self, s_size, a_size, h_size):\n",
159
+ "        super().__init__()\n",
160
+ "        self.fc1 = nn.Linear(s_size, h_size)\n",
161
+ "        self.fc2 = nn.Linear(h_size, a_size)\n",
162
+ "\n",
163
+ "    def forward(self, x):\n",
164
+ "        # Returns per-action probabilities; softmax runs over dim 1, so x needs a leading batch dim.\n",
165
+ "        hidden = F.relu(self.fc1(x))\n",
166
+ "        return F.softmax(self.fc2(hidden), dim=1)\n",
167
+ "\n",
168
+ "    def act(self, state):\n",
169
+ "        # Sample an action for one NumPy observation; returns (action index, log-probability tensor).\n",
170
+ "        batch = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
171
+ "        dist = Categorical(self.forward(batch).cpu())\n",
172
+ "        action = dist.sample()\n",
173
+ "        return action.item(), dist.log_prob(action)\n",
174
+ "    "
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": 14,
180
+ "metadata": {},
181
+ "outputs": [
182
+ {
183
+ "data": {
184
+ "text/plain": [
185
+ "(1, tensor([-0.7983], grad_fn=<SqueezeBackward1>))"
186
+ ]
187
+ },
188
+ "execution_count": 14,
189
+ "metadata": {},
190
+ "output_type": "execute_result"
191
+ }
192
+ ],
193
+ "source": [
194
+ "debug_policy = Policy(s_size, a_size, 64).to(device)\n",
195
+ "debug_policy.act(env.reset())"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "markdown",
200
+ "metadata": {},
201
+ "source": [
202
+ "<img src=https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit6/pg_pseudocode.png/>"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": 15,
208
+ "metadata": {},
209
+ "outputs": [],
210
+ "source": [
211
+ "def reinforce(policy, optimizer, n_training_episodes, max_t, gamma, print_every):\n",
212
+ "    \"\"\"Train `policy` with REINFORCE on the global `env`; return the list of per-episode total rewards.\"\"\"\n",
213
+ "    scores_deque = deque(maxlen=100)  # rolling window used for the printed average\n",
214
+ "    scores = []\n",
215
+ "    # Line 3 of pseudocode: one policy update per episode\n",
216
+ "    for i_episode in range(1, n_training_episodes+1):\n",
217
+ "        saved_log_probs = []\n",
218
+ "        rewards = []\n",
219
+ "        state = env.reset()\n",
220
+ "\n",
221
+ "        # Line 4 of pseudocode: roll out one episode, capped at max_t steps (not the episode budget)\n",
222
+ "        for t in range(max_t):\n",
223
+ "            action, log_prob = policy.act(state)\n",
224
+ "            saved_log_probs.append(log_prob)\n",
225
+ "            state, reward, done, _ = env.step(action)\n",
226
+ "            rewards.append(reward)\n",
227
+ "            if done:\n",
228
+ "                break\n",
229
+ "        scores_deque.append(sum(rewards))\n",
230
+ "        scores.append(sum(rewards))\n",
231
+ "\n",
232
+ "        # Line 6 of pseudocode: discounted return for every step, built from the last step backwards\n",
233
+ "        returns = deque(maxlen=max_t)\n",
234
+ "        n_steps = len(rewards)\n",
235
+ "\n",
236
+ "        for t in range(n_steps)[::-1]:\n",
237
+ "            disc_return_t = (returns[0] if len(returns)>0 else 0)\n",
238
+ "            returns.appendleft(gamma * disc_return_t + rewards[t])\n",
239
+ "\n",
240
+ "        eps = np.finfo(np.float32).eps.item()  # keeps the division below finite when std is 0\n",
241
+ "\n",
242
+ "        returns = torch.tensor(returns)\n",
243
+ "        returns = (returns - returns.mean()) / (returns.std() + eps)\n",
244
+ "\n",
245
+ "        # Line 7: loss is the sum over steps of -log_prob * standardized return\n",
246
+ "        policy_loss = []\n",
247
+ "        for log_prob, disc_return in zip(saved_log_probs, returns):\n",
248
+ "            policy_loss.append(-log_prob * disc_return)\n",
249
+ "        policy_loss = torch.cat(policy_loss).sum()\n",
250
+ "\n",
251
+ "        # Line 8: gradient step on the policy parameters\n",
252
+ "        optimizer.zero_grad()\n",
253
+ "        policy_loss.backward()\n",
254
+ "        optimizer.step()\n",
255
+ "        # Report progress on the episode counter, not the in-episode step counter\n",
256
+ "        if i_episode % print_every == 0:\n",
257
+ "            print(\"Episode {}\\tAverage Score: {:.2f}\".format(i_episode, np.mean(scores_deque)))\n",
258
+ "\n",
259
+ "    return scores"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": 16,
265
+ "metadata": {},
266
+ "outputs": [],
267
+ "source": [
268
+ "cartpole_hyperparameters = {\n",
269
+ " \"h_size\": 16,\n",
270
+ " \"n_training_episodes\": 1000,\n",
271
+ " \"n_evaluation_episodes\": 10,\n",
272
+ " \"max_t\": 1000,\n",
273
+ " \"gamma\": 1.0,\n",
274
+ " \"lr\": 1e-2,\n",
275
+ " \"env_id\": env_id,\n",
276
+ " \"state_space\": s_size,\n",
277
+ " \"action_space\": a_size,\n",
278
+ "}"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 17,
284
+ "metadata": {},
285
+ "outputs": [],
286
+ "source": [
287
+ "cartpole_policy = Policy(\n",
288
+ " cartpole_hyperparameters[\"state_space\"],\n",
289
+ " cartpole_hyperparameters[\"action_space\"],\n",
290
+ " cartpole_hyperparameters[\"h_size\"],\n",
291
+ ").to(device)\n",
292
+ "cartpole_optimizer = optim.Adam(cartpole_policy.parameters(), lr=cartpole_hyperparameters[\"lr\"])"
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "code",
297
+ "execution_count": 18,
298
+ "metadata": {},
299
+ "outputs": [
300
+ {
301
+ "name": "stdout",
302
+ "output_type": "stream",
303
+ "text": [
304
+ "Episode 500\tAverage Score: 116.93\n",
305
+ "Episode 500\tAverage Score: 134.13\n",
306
+ "Episode 500\tAverage Score: 138.92\n",
307
+ "Episode 500\tAverage Score: 143.73\n",
308
+ "Episode 500\tAverage Score: 150.68\n",
309
+ "Episode 500\tAverage Score: 154.91\n",
310
+ "Episode 500\tAverage Score: 159.05\n",
311
+ "Episode 500\tAverage Score: 163.41\n",
312
+ "Episode 500\tAverage Score: 167.91\n",
313
+ "Episode 500\tAverage Score: 172.49\n",
314
+ "Episode 500\tAverage Score: 176.90\n",
315
+ "Episode 500\tAverage Score: 181.63\n",
316
+ "Episode 500\tAverage Score: 185.66\n",
317
+ "Episode 500\tAverage Score: 190.18\n",
318
+ "Episode 500\tAverage Score: 194.90\n",
319
+ "Episode 500\tAverage Score: 199.15\n",
320
+ "Episode 500\tAverage Score: 203.89\n",
321
+ "Episode 500\tAverage Score: 208.33\n",
322
+ "Episode 500\tAverage Score: 212.64\n",
323
+ "Episode 500\tAverage Score: 217.48\n",
324
+ "Episode 500\tAverage Score: 221.51\n",
325
+ "Episode 500\tAverage Score: 226.20\n",
326
+ "Episode 500\tAverage Score: 230.63\n",
327
+ "Episode 500\tAverage Score: 235.21\n",
328
+ "Episode 500\tAverage Score: 243.17\n",
329
+ "Episode 500\tAverage Score: 250.87\n",
330
+ "Episode 500\tAverage Score: 254.48\n",
331
+ "Episode 500\tAverage Score: 258.01\n",
332
+ "Episode 500\tAverage Score: 262.76\n",
333
+ "Episode 500\tAverage Score: 267.27\n",
334
+ "Episode 500\tAverage Score: 271.85\n",
335
+ "Episode 500\tAverage Score: 275.57\n",
336
+ "Episode 500\tAverage Score: 281.62\n",
337
+ "Episode 500\tAverage Score: 284.87\n",
338
+ "Episode 500\tAverage Score: 289.12\n",
339
+ "Episode 500\tAverage Score: 295.51\n",
340
+ "Episode 500\tAverage Score: 299.59\n",
341
+ "Episode 500\tAverage Score: 303.39\n",
342
+ "Episode 500\tAverage Score: 310.17\n",
343
+ "Episode 500\tAverage Score: 313.95\n",
344
+ "Episode 500\tAverage Score: 317.26\n",
345
+ "Episode 500\tAverage Score: 318.30\n",
346
+ "Episode 500\tAverage Score: 322.61\n",
347
+ "Episode 500\tAverage Score: 327.74\n",
348
+ "Episode 500\tAverage Score: 331.85\n",
349
+ "Episode 500\tAverage Score: 335.04\n",
350
+ "Episode 500\tAverage Score: 339.34\n",
351
+ "Episode 500\tAverage Score: 343.40\n",
352
+ "Episode 500\tAverage Score: 345.81\n",
353
+ "Episode 500\tAverage Score: 348.98\n",
354
+ "Episode 500\tAverage Score: 352.50\n",
355
+ "Episode 500\tAverage Score: 356.47\n",
356
+ "Episode 500\tAverage Score: 360.60\n",
357
+ "Episode 500\tAverage Score: 364.78\n",
358
+ "Episode 500\tAverage Score: 368.87\n",
359
+ "Episode 500\tAverage Score: 372.04\n",
360
+ "Episode 500\tAverage Score: 374.21\n",
361
+ "Episode 500\tAverage Score: 376.52\n",
362
+ "Episode 500\tAverage Score: 379.97\n",
363
+ "Episode 500\tAverage Score: 382.65\n",
364
+ "Episode 500\tAverage Score: 384.00\n",
365
+ "Episode 500\tAverage Score: 386.29\n",
366
+ "Episode 500\tAverage Score: 391.30\n",
367
+ "Episode 500\tAverage Score: 394.40\n",
368
+ "Episode 500\tAverage Score: 398.01\n",
369
+ "Episode 500\tAverage Score: 400.75\n",
370
+ "Episode 500\tAverage Score: 404.74\n",
371
+ "Episode 500\tAverage Score: 408.86\n",
372
+ "Episode 500\tAverage Score: 412.89\n",
373
+ "Episode 500\tAverage Score: 417.54\n",
374
+ "Episode 500\tAverage Score: 421.40\n",
375
+ "Episode 500\tAverage Score: 425.71\n",
376
+ "Episode 500\tAverage Score: 425.96\n",
377
+ "Episode 500\tAverage Score: 430.19\n",
378
+ "Episode 500\tAverage Score: 434.20\n",
379
+ "Episode 500\tAverage Score: 434.40\n",
380
+ "Episode 500\tAverage Score: 438.51\n",
381
+ "Episode 500\tAverage Score: 441.44\n",
382
+ "Episode 500\tAverage Score: 445.65\n",
383
+ "Episode 500\tAverage Score: 448.57\n",
384
+ "Episode 500\tAverage Score: 451.66\n",
385
+ "Episode 500\tAverage Score: 455.92\n",
386
+ "Episode 500\tAverage Score: 458.06\n",
387
+ "Episode 500\tAverage Score: 460.77\n",
388
+ "Episode 500\tAverage Score: 460.77\n",
389
+ "Episode 500\tAverage Score: 462.53\n",
390
+ "Episode 500\tAverage Score: 463.35\n",
391
+ "Episode 500\tAverage Score: 465.71\n",
392
+ "Episode 500\tAverage Score: 467.43\n",
393
+ "Episode 500\tAverage Score: 471.61\n",
394
+ "Episode 500\tAverage Score: 471.61\n",
395
+ "Episode 500\tAverage Score: 471.61\n",
396
+ "Episode 500\tAverage Score: 471.61\n",
397
+ "Episode 500\tAverage Score: 474.02\n",
398
+ "Episode 500\tAverage Score: 474.02\n",
399
+ "Episode 500\tAverage Score: 474.02\n",
400
+ "Episode 500\tAverage Score: 474.02\n",
401
+ "Episode 500\tAverage Score: 474.02\n",
402
+ "Episode 500\tAverage Score: 474.02\n",
403
+ "Episode 500\tAverage Score: 474.02\n",
404
+ "Episode 500\tAverage Score: 474.02\n",
405
+ "Episode 500\tAverage Score: 474.02\n",
406
+ "Episode 500\tAverage Score: 474.02\n",
407
+ "Episode 500\tAverage Score: 474.02\n",
408
+ "Episode 500\tAverage Score: 474.02\n",
409
+ "Episode 500\tAverage Score: 474.02\n",
410
+ "Episode 500\tAverage Score: 474.02\n",
411
+ "Episode 500\tAverage Score: 474.02\n",
412
+ "Episode 500\tAverage Score: 474.02\n",
413
+ "Episode 500\tAverage Score: 474.02\n",
414
+ "Episode 500\tAverage Score: 474.02\n",
415
+ "Episode 500\tAverage Score: 469.92\n",
416
+ "Episode 500\tAverage Score: 466.39\n",
417
+ "Episode 500\tAverage Score: 470.74\n",
418
+ "Episode 500\tAverage Score: 470.74\n",
419
+ "Episode 500\tAverage Score: 472.07\n",
420
+ "Episode 500\tAverage Score: 472.07\n",
421
+ "Episode 500\tAverage Score: 476.40\n",
422
+ "Episode 500\tAverage Score: 476.40\n",
423
+ "Episode 500\tAverage Score: 476.40\n",
424
+ "Episode 500\tAverage Score: 476.40\n",
425
+ "Episode 500\tAverage Score: 476.40\n",
426
+ "Episode 500\tAverage Score: 476.40\n",
427
+ "Episode 500\tAverage Score: 476.40\n",
428
+ "Episode 500\tAverage Score: 479.20\n",
429
+ "Episode 500\tAverage Score: 475.32\n",
430
+ "Episode 500\tAverage Score: 472.31\n",
431
+ "Episode 500\tAverage Score: 472.31\n",
432
+ "Episode 500\tAverage Score: 472.31\n",
433
+ "Episode 500\tAverage Score: 470.49\n",
434
+ "Episode 500\tAverage Score: 470.49\n",
435
+ "Episode 500\tAverage Score: 470.49\n",
436
+ "Episode 500\tAverage Score: 470.49\n",
437
+ "Episode 500\tAverage Score: 466.40\n",
438
+ "Episode 500\tAverage Score: 468.61\n",
439
+ "Episode 500\tAverage Score: 468.61\n",
440
+ "Episode 500\tAverage Score: 468.61\n",
441
+ "Episode 500\tAverage Score: 468.61\n",
442
+ "Episode 500\tAverage Score: 468.61\n",
443
+ "Episode 500\tAverage Score: 468.61\n",
444
+ "Episode 500\tAverage Score: 468.61\n",
445
+ "Episode 500\tAverage Score: 468.61\n",
446
+ "Episode 500\tAverage Score: 468.61\n",
447
+ "Episode 500\tAverage Score: 472.51\n",
448
+ "Episode 500\tAverage Score: 472.51\n",
449
+ "Episode 500\tAverage Score: 472.51\n",
450
+ "Episode 500\tAverage Score: 467.72\n",
451
+ "Episode 500\tAverage Score: 467.72\n",
452
+ "Episode 500\tAverage Score: 462.94\n",
453
+ "Episode 500\tAverage Score: 462.94\n",
454
+ "Episode 500\tAverage Score: 462.94\n",
455
+ "Episode 500\tAverage Score: 462.94\n",
456
+ "Episode 500\tAverage Score: 462.94\n",
457
+ "Episode 500\tAverage Score: 465.15\n",
458
+ "Episode 500\tAverage Score: 465.15\n",
459
+ "Episode 500\tAverage Score: 465.15\n",
460
+ "Episode 500\tAverage Score: 465.15\n",
461
+ "Episode 500\tAverage Score: 465.15\n",
462
+ "Episode 500\tAverage Score: 465.15\n",
463
+ "Episode 500\tAverage Score: 465.15\n",
464
+ "Episode 500\tAverage Score: 465.15\n",
465
+ "Episode 500\tAverage Score: 465.15\n",
466
+ "Episode 500\tAverage Score: 465.15\n",
467
+ "Episode 500\tAverage Score: 465.15\n",
468
+ "Episode 500\tAverage Score: 465.15\n",
469
+ "Episode 500\tAverage Score: 465.15\n",
470
+ "Episode 500\tAverage Score: 465.15\n",
471
+ "Episode 500\tAverage Score: 465.15\n",
472
+ "Episode 500\tAverage Score: 465.15\n",
473
+ "Episode 500\tAverage Score: 465.15\n",
474
+ "Episode 500\tAverage Score: 465.15\n",
475
+ "Episode 500\tAverage Score: 465.15\n",
476
+ "Episode 500\tAverage Score: 465.15\n",
477
+ "Episode 500\tAverage Score: 465.15\n",
478
+ "Episode 500\tAverage Score: 465.15\n",
479
+ "Episode 500\tAverage Score: 465.15\n",
480
+ "Episode 500\tAverage Score: 465.15\n",
481
+ "Episode 500\tAverage Score: 465.15\n",
482
+ "Episode 500\tAverage Score: 465.15\n",
483
+ "Episode 500\tAverage Score: 465.15\n",
484
+ "Episode 500\tAverage Score: 465.15\n",
485
+ "Episode 500\tAverage Score: 465.15\n",
486
+ "Episode 500\tAverage Score: 465.15\n",
487
+ "Episode 500\tAverage Score: 465.15\n",
488
+ "Episode 500\tAverage Score: 465.15\n",
489
+ "Episode 500\tAverage Score: 465.15\n",
490
+ "Episode 500\tAverage Score: 465.15\n",
491
+ "Episode 500\tAverage Score: 465.15\n",
492
+ "Episode 500\tAverage Score: 465.15\n",
493
+ "Episode 500\tAverage Score: 465.15\n",
494
+ "Episode 500\tAverage Score: 465.15\n",
495
+ "Episode 500\tAverage Score: 465.15\n",
496
+ "Episode 500\tAverage Score: 465.15\n",
497
+ "Episode 500\tAverage Score: 465.15\n",
498
+ "Episode 500\tAverage Score: 465.15\n",
499
+ "Episode 500\tAverage Score: 465.15\n",
500
+ "Episode 500\tAverage Score: 465.15\n",
501
+ "Episode 500\tAverage Score: 465.15\n",
502
+ "Episode 500\tAverage Score: 465.15\n",
503
+ "Episode 500\tAverage Score: 465.15\n",
504
+ "Episode 500\tAverage Score: 465.15\n",
505
+ "Episode 500\tAverage Score: 465.15\n",
506
+ "Episode 500\tAverage Score: 465.15\n",
507
+ "Episode 500\tAverage Score: 469.25\n",
508
+ "Episode 500\tAverage Score: 469.25\n",
509
+ "Episode 500\tAverage Score: 473.97\n",
510
+ "Episode 500\tAverage Score: 473.97\n",
511
+ "Episode 500\tAverage Score: 473.97\n",
512
+ "Episode 500\tAverage Score: 473.97\n",
513
+ "Episode 500\tAverage Score: 473.97\n",
514
+ "Episode 500\tAverage Score: 473.97\n",
515
+ "Episode 500\tAverage Score: 473.97\n",
516
+ "Episode 500\tAverage Score: 473.97\n",
517
+ "Episode 500\tAverage Score: 473.97\n",
518
+ "Episode 500\tAverage Score: 473.97\n",
519
+ "Episode 500\tAverage Score: 473.97\n",
520
+ "Episode 500\tAverage Score: 473.97\n",
521
+ "Episode 500\tAverage Score: 473.97\n",
522
+ "Episode 500\tAverage Score: 473.97\n",
523
+ "Episode 500\tAverage Score: 477.85\n",
524
+ "Episode 500\tAverage Score: 477.85\n",
525
+ "Episode 500\tAverage Score: 482.59\n",
526
+ "Episode 500\tAverage Score: 482.59\n",
527
+ "Episode 500\tAverage Score: 482.59\n",
528
+ "Episode 500\tAverage Score: 482.59\n",
529
+ "Episode 500\tAverage Score: 486.34\n",
530
+ "Episode 500\tAverage Score: 486.34\n",
531
+ "Episode 500\tAverage Score: 486.34\n",
532
+ "Episode 500\tAverage Score: 486.34\n",
533
+ "Episode 500\tAverage Score: 486.34\n",
534
+ "Episode 500\tAverage Score: 490.43\n",
535
+ "Episode 500\tAverage Score: 490.43\n",
536
+ "Episode 500\tAverage Score: 490.43\n",
537
+ "Episode 500\tAverage Score: 490.43\n",
538
+ "Episode 500\tAverage Score: 490.43\n",
539
+ "Episode 500\tAverage Score: 490.43\n",
540
+ "Episode 500\tAverage Score: 490.43\n",
541
+ "Episode 500\tAverage Score: 490.43\n",
542
+ "Episode 500\tAverage Score: 490.43\n",
543
+ "Episode 500\tAverage Score: 490.43\n",
544
+ "Episode 500\tAverage Score: 490.43\n",
545
+ "Episode 500\tAverage Score: 490.43\n",
546
+ "Episode 500\tAverage Score: 490.43\n",
547
+ "Episode 500\tAverage Score: 490.43\n",
548
+ "Episode 500\tAverage Score: 495.22\n",
549
+ "Episode 500\tAverage Score: 495.22\n",
550
+ "Episode 500\tAverage Score: 495.22\n",
551
+ "Episode 500\tAverage Score: 500.00\n",
552
+ "Episode 500\tAverage Score: 497.42\n",
553
+ "Episode 500\tAverage Score: 497.42\n",
554
+ "Episode 500\tAverage Score: 497.42\n",
555
+ "Episode 500\tAverage Score: 497.42\n",
556
+ "Episode 500\tAverage Score: 497.42\n",
557
+ "Episode 500\tAverage Score: 497.42\n",
558
+ "Episode 500\tAverage Score: 497.42\n",
559
+ "Episode 500\tAverage Score: 497.42\n",
560
+ "Episode 500\tAverage Score: 497.42\n",
561
+ "Episode 500\tAverage Score: 497.42\n",
562
+ "Episode 500\tAverage Score: 497.42\n",
563
+ "Episode 500\tAverage Score: 497.42\n",
564
+ "Episode 500\tAverage Score: 497.42\n",
565
+ "Episode 500\tAverage Score: 497.42\n",
566
+ "Episode 500\tAverage Score: 497.42\n",
567
+ "Episode 500\tAverage Score: 497.42\n",
568
+ "Episode 500\tAverage Score: 497.42\n",
569
+ "Episode 500\tAverage Score: 497.42\n",
570
+ "Episode 500\tAverage Score: 497.42\n",
571
+ "Episode 500\tAverage Score: 497.42\n",
572
+ "Episode 500\tAverage Score: 497.42\n",
573
+ "Episode 500\tAverage Score: 497.42\n",
574
+ "Episode 500\tAverage Score: 497.42\n",
575
+ "Episode 500\tAverage Score: 497.42\n",
576
+ "Episode 500\tAverage Score: 497.42\n",
577
+ "Episode 500\tAverage Score: 497.42\n",
578
+ "Episode 500\tAverage Score: 497.42\n",
579
+ "Episode 500\tAverage Score: 497.42\n",
580
+ "Episode 500\tAverage Score: 497.42\n",
581
+ "Episode 500\tAverage Score: 497.42\n",
582
+ "Episode 500\tAverage Score: 497.42\n",
583
+ "Episode 500\tAverage Score: 497.42\n",
584
+ "Episode 500\tAverage Score: 497.42\n",
585
+ "Episode 500\tAverage Score: 497.28\n",
586
+ "Episode 500\tAverage Score: 497.28\n",
587
+ "Episode 500\tAverage Score: 497.28\n",
588
+ "Episode 500\tAverage Score: 497.28\n",
589
+ "Episode 500\tAverage Score: 497.28\n",
590
+ "Episode 500\tAverage Score: 497.28\n",
591
+ "Episode 500\tAverage Score: 493.12\n",
592
+ "Episode 500\tAverage Score: 493.12\n",
593
+ "Episode 500\tAverage Score: 493.12\n",
594
+ "Episode 500\tAverage Score: 493.12\n",
595
+ "Episode 500\tAverage Score: 488.95\n",
596
+ "Episode 500\tAverage Score: 488.95\n",
597
+ "Episode 500\tAverage Score: 488.95\n",
598
+ "Episode 500\tAverage Score: 488.95\n",
599
+ "Episode 500\tAverage Score: 488.95\n",
600
+ "Episode 500\tAverage Score: 488.95\n",
601
+ "Episode 500\tAverage Score: 484.67\n",
602
+ "Episode 500\tAverage Score: 484.67\n",
603
+ "Episode 500\tAverage Score: 484.67\n",
604
+ "Episode 500\tAverage Score: 484.67\n",
605
+ "Episode 500\tAverage Score: 480.52\n",
606
+ "Episode 500\tAverage Score: 480.52\n",
607
+ "Episode 500\tAverage Score: 480.52\n",
608
+ "Episode 500\tAverage Score: 480.52\n",
609
+ "Episode 500\tAverage Score: 480.52\n",
610
+ "Episode 500\tAverage Score: 480.52\n",
611
+ "Episode 500\tAverage Score: 480.52\n",
612
+ "Episode 500\tAverage Score: 480.52\n",
613
+ "Episode 500\tAverage Score: 480.52\n",
614
+ "Episode 500\tAverage Score: 480.52\n",
615
+ "Episode 500\tAverage Score: 480.52\n",
616
+ "Episode 500\tAverage Score: 480.52\n",
617
+ "Episode 500\tAverage Score: 480.52\n",
618
+ "Episode 500\tAverage Score: 476.39\n",
619
+ "Episode 500\tAverage Score: 476.39\n",
620
+ "Episode 500\tAverage Score: 476.39\n",
621
+ "Episode 500\tAverage Score: 476.39\n",
622
+ "Episode 500\tAverage Score: 476.39\n",
623
+ "Episode 500\tAverage Score: 476.39\n",
624
+ "Episode 500\tAverage Score: 476.39\n",
625
+ "Episode 500\tAverage Score: 476.39\n",
626
+ "Episode 500\tAverage Score: 476.39\n",
627
+ "Episode 500\tAverage Score: 476.39\n",
628
+ "Episode 500\tAverage Score: 476.39\n",
629
+ "Episode 500\tAverage Score: 476.39\n",
630
+ "Episode 500\tAverage Score: 476.39\n",
631
+ "Episode 500\tAverage Score: 476.39\n",
632
+ "Episode 500\tAverage Score: 476.39\n",
633
+ "Episode 500\tAverage Score: 476.39\n",
634
+ "Episode 500\tAverage Score: 476.39\n",
635
+ "Episode 500\tAverage Score: 476.39\n",
636
+ "Episode 500\tAverage Score: 476.39\n",
637
+ "Episode 500\tAverage Score: 476.39\n",
638
+ "Episode 500\tAverage Score: 476.39\n",
639
+ "Episode 500\tAverage Score: 476.39\n",
640
+ "Episode 500\tAverage Score: 476.39\n",
641
+ "Episode 500\tAverage Score: 476.39\n",
642
+ "Episode 500\tAverage Score: 476.39\n",
643
+ "Episode 500\tAverage Score: 476.39\n",
644
+ "Episode 500\tAverage Score: 476.39\n",
645
+ "Episode 500\tAverage Score: 478.97\n",
646
+ "Episode 500\tAverage Score: 478.97\n",
647
+ "Episode 500\tAverage Score: 478.97\n",
648
+ "Episode 500\tAverage Score: 478.97\n",
649
+ "Episode 500\tAverage Score: 478.97\n",
650
+ "Episode 500\tAverage Score: 478.97\n",
651
+ "Episode 500\tAverage Score: 478.97\n",
652
+ "Episode 500\tAverage Score: 478.97\n",
653
+ "Episode 500\tAverage Score: 478.97\n",
654
+ "Episode 500\tAverage Score: 478.97\n",
655
+ "Episode 500\tAverage Score: 478.97\n",
656
+ "Episode 500\tAverage Score: 478.97\n",
657
+ "Episode 500\tAverage Score: 478.97\n",
658
+ "Episode 500\tAverage Score: 478.97\n",
659
+ "Episode 500\tAverage Score: 478.97\n",
660
+ "Episode 500\tAverage Score: 478.97\n",
661
+ "Episode 500\tAverage Score: 478.97\n",
662
+ "Episode 500\tAverage Score: 478.97\n",
663
+ "Episode 500\tAverage Score: 478.97\n",
664
+ "Episode 500\tAverage Score: 478.97\n",
665
+ "Episode 500\tAverage Score: 478.97\n",
666
+ "Episode 500\tAverage Score: 478.97\n",
667
+ "Episode 500\tAverage Score: 478.97\n",
668
+ "Episode 500\tAverage Score: 478.97\n",
669
+ "Episode 500\tAverage Score: 478.97\n",
670
+ "Episode 500\tAverage Score: 478.97\n",
671
+ "Episode 500\tAverage Score: 478.97\n",
672
+ "Episode 500\tAverage Score: 478.97\n",
673
+ "Episode 500\tAverage Score: 478.97\n",
674
+ "Episode 500\tAverage Score: 478.97\n",
675
+ "Episode 500\tAverage Score: 478.97\n",
676
+ "Episode 500\tAverage Score: 478.97\n",
677
+ "Episode 500\tAverage Score: 478.97\n",
678
+ "Episode 500\tAverage Score: 478.97\n",
679
+ "Episode 500\tAverage Score: 479.11\n",
680
+ "Episode 500\tAverage Score: 479.11\n",
681
+ "Episode 500\tAverage Score: 479.11\n",
682
+ "Episode 500\tAverage Score: 479.11\n",
683
+ "Episode 500\tAverage Score: 479.11\n",
684
+ "Episode 500\tAverage Score: 479.11\n",
685
+ "Episode 500\tAverage Score: 479.11\n",
686
+ "Episode 500\tAverage Score: 483.27\n",
687
+ "Episode 500\tAverage Score: 483.27\n",
688
+ "Episode 500\tAverage Score: 483.27\n",
689
+ "Episode 500\tAverage Score: 483.27\n",
690
+ "Episode 500\tAverage Score: 483.27\n",
691
+ "Episode 500\tAverage Score: 487.44\n",
692
+ "Episode 500\tAverage Score: 487.44\n",
693
+ "Episode 500\tAverage Score: 487.44\n",
694
+ "Episode 500\tAverage Score: 487.44\n",
695
+ "Episode 500\tAverage Score: 487.44\n",
696
+ "Episode 300\tAverage Score: 485.44\n",
697
+ "Episode 500\tAverage Score: 485.44\n",
698
+ "Episode 500\tAverage Score: 489.72\n",
699
+ "Episode 500\tAverage Score: 489.72\n",
700
+ "Episode 500\tAverage Score: 489.72\n",
701
+ "Episode 500\tAverage Score: 489.72\n",
702
+ "Episode 500\tAverage Score: 489.72\n",
703
+ "Episode 500\tAverage Score: 493.87\n",
704
+ "Episode 500\tAverage Score: 493.87\n",
705
+ "Episode 500\tAverage Score: 493.87\n",
706
+ "Episode 500\tAverage Score: 493.87\n",
707
+ "Episode 500\tAverage Score: 493.87\n",
708
+ "Episode 500\tAverage Score: 493.87\n",
709
+ "Episode 500\tAverage Score: 493.87\n",
710
+ "Episode 500\tAverage Score: 493.87\n",
711
+ "Episode 500\tAverage Score: 493.87\n",
712
+ "Episode 500\tAverage Score: 493.87\n",
713
+ "Episode 500\tAverage Score: 493.87\n",
714
+ "Episode 500\tAverage Score: 493.87\n",
715
+ "Episode 500\tAverage Score: 493.87\n",
716
+ "Episode 500\tAverage Score: 493.87\n",
717
+ "Episode 500\tAverage Score: 498.00\n",
718
+ "Episode 500\tAverage Score: 498.00\n",
719
+ "Episode 500\tAverage Score: 498.00\n",
720
+ "Episode 500\tAverage Score: 498.00\n",
721
+ "Episode 500\tAverage Score: 498.00\n",
722
+ "Episode 500\tAverage Score: 498.00\n",
723
+ "Episode 500\tAverage Score: 498.00\n",
724
+ "Episode 500\tAverage Score: 498.00\n",
725
+ "Episode 500\tAverage Score: 498.00\n",
726
+ "Episode 500\tAverage Score: 498.00\n",
727
+ "Episode 500\tAverage Score: 498.00\n",
728
+ "Episode 500\tAverage Score: 498.00\n",
729
+ "Episode 500\tAverage Score: 498.00\n",
730
+ "Episode 500\tAverage Score: 498.00\n",
731
+ "Episode 500\tAverage Score: 498.00\n",
732
+ "Episode 500\tAverage Score: 498.00\n",
733
+ "Episode 500\tAverage Score: 498.00\n",
734
+ "Episode 500\tAverage Score: 498.00\n",
735
+ "Episode 500\tAverage Score: 498.00\n",
736
+ "Episode 500\tAverage Score: 498.00\n",
737
+ "Episode 500\tAverage Score: 498.00\n",
738
+ "Episode 500\tAverage Score: 498.00\n",
739
+ "Episode 500\tAverage Score: 498.00\n",
740
+ "Episode 500\tAverage Score: 498.00\n",
741
+ "Episode 500\tAverage Score: 498.00\n",
742
+ "Episode 500\tAverage Score: 498.00\n",
743
+ "Episode 500\tAverage Score: 498.00\n",
744
+ "Episode 500\tAverage Score: 498.00\n",
745
+ "Episode 500\tAverage Score: 498.00\n",
746
+ "Episode 500\tAverage Score: 498.00\n",
747
+ "Episode 500\tAverage Score: 498.00\n",
748
+ "Episode 500\tAverage Score: 498.00\n",
749
+ "Episode 500\tAverage Score: 498.00\n",
750
+ "Episode 500\tAverage Score: 498.00\n",
751
+ "Episode 500\tAverage Score: 498.00\n",
752
+ "Episode 500\tAverage Score: 498.00\n",
753
+ "Episode 500\tAverage Score: 498.00\n",
754
+ "Episode 500\tAverage Score: 498.00\n",
755
+ "Episode 500\tAverage Score: 498.00\n",
756
+ "Episode 500\tAverage Score: 498.00\n",
757
+ "Episode 500\tAverage Score: 498.00\n",
758
+ "Episode 500\tAverage Score: 498.00\n",
759
+ "Episode 500\tAverage Score: 498.00\n",
760
+ "Episode 500\tAverage Score: 498.00\n",
761
+ "Episode 500\tAverage Score: 498.00\n",
762
+ "Episode 500\tAverage Score: 498.00\n",
763
+ "Episode 500\tAverage Score: 498.00\n",
764
+ "Episode 500\tAverage Score: 493.24\n",
765
+ "Episode 500\tAverage Score: 493.24\n",
766
+ "Episode 500\tAverage Score: 493.24\n",
767
+ "Episode 500\tAverage Score: 493.24\n",
768
+ "Episode 500\tAverage Score: 493.24\n",
769
+ "Episode 500\tAverage Score: 493.24\n",
770
+ "Episode 500\tAverage Score: 493.24\n",
771
+ "Episode 500\tAverage Score: 493.24\n",
772
+ "Episode 500\tAverage Score: 493.24\n",
773
+ "Episode 500\tAverage Score: 493.24\n",
774
+ "Episode 500\tAverage Score: 493.24\n",
775
+ "Episode 500\tAverage Score: 493.24\n",
776
+ "Episode 500\tAverage Score: 493.24\n",
777
+ "Episode 500\tAverage Score: 493.24\n",
778
+ "Episode 500\tAverage Score: 493.24\n",
779
+ "Episode 500\tAverage Score: 493.24\n",
780
+ "Episode 500\tAverage Score: 493.24\n",
781
+ "Episode 500\tAverage Score: 493.24\n",
782
+ "Episode 500\tAverage Score: 493.24\n",
783
+ "Episode 500\tAverage Score: 493.24\n",
784
+ "Episode 500\tAverage Score: 488.47\n",
785
+ "Episode 500\tAverage Score: 483.65\n",
786
+ "Episode 500\tAverage Score: 483.65\n",
787
+ "Episode 500\tAverage Score: 483.65\n",
788
+ "Episode 500\tAverage Score: 483.65\n",
789
+ "Episode 500\tAverage Score: 466.97\n",
790
+ "Episode 500\tAverage Score: 460.99\n",
791
+ "Episode 500\tAverage Score: 460.99\n",
792
+ "Episode 500\tAverage Score: 460.99\n",
793
+ "Episode 500\tAverage Score: 456.25\n",
794
+ "Episode 500\tAverage Score: 456.25\n",
795
+ "Episode 500\tAverage Score: 451.43\n",
796
+ "Episode 500\tAverage Score: 451.43\n",
797
+ "Episode 500\tAverage Score: 451.43\n",
798
+ "Episode 500\tAverage Score: 451.43\n",
799
+ "Episode 500\tAverage Score: 451.43\n",
800
+ "Episode 500\tAverage Score: 451.43\n",
801
+ "Episode 500\tAverage Score: 451.43\n",
802
+ "Episode 500\tAverage Score: 451.43\n",
803
+ "Episode 500\tAverage Score: 451.43\n",
804
+ "Episode 500\tAverage Score: 451.43\n",
805
+ "Episode 500\tAverage Score: 451.43\n",
806
+ "Episode 500\tAverage Score: 451.43\n",
807
+ "Episode 500\tAverage Score: 451.43\n",
808
+ "Episode 500\tAverage Score: 451.43\n",
809
+ "Episode 500\tAverage Score: 451.43\n",
810
+ "Episode 500\tAverage Score: 451.43\n",
811
+ "Episode 500\tAverage Score: 451.43\n",
812
+ "Episode 500\tAverage Score: 451.43\n",
813
+ "Episode 500\tAverage Score: 451.43\n",
814
+ "Episode 500\tAverage Score: 451.43\n",
815
+ "Episode 500\tAverage Score: 451.43\n",
816
+ "Episode 500\tAverage Score: 451.43\n",
817
+ "Episode 500\tAverage Score: 451.43\n",
818
+ "Episode 500\tAverage Score: 451.43\n",
819
+ "Episode 500\tAverage Score: 451.43\n",
820
+ "Episode 500\tAverage Score: 451.43\n",
821
+ "Episode 500\tAverage Score: 451.43\n",
822
+ "Episode 500\tAverage Score: 451.43\n",
823
+ "Episode 500\tAverage Score: 451.43\n",
824
+ "Episode 500\tAverage Score: 451.43\n",
825
+ "Episode 500\tAverage Score: 451.43\n",
826
+ "Episode 500\tAverage Score: 451.43\n",
827
+ "Episode 500\tAverage Score: 451.43\n",
828
+ "Episode 500\tAverage Score: 451.43\n",
829
+ "Episode 500\tAverage Score: 451.43\n",
830
+ "Episode 500\tAverage Score: 451.43\n",
831
+ "Episode 500\tAverage Score: 451.43\n",
832
+ "Episode 500\tAverage Score: 451.43\n",
833
+ "Episode 500\tAverage Score: 451.43\n",
834
+ "Episode 500\tAverage Score: 451.43\n",
835
+ "Episode 200\tAverage Score: 148.79\n",
836
+ "Episode 200\tAverage Score: 157.96\n",
837
+ "Episode 500\tAverage Score: 190.64\n",
838
+ "Episode 500\tAverage Score: 194.26\n",
839
+ "Episode 500\tAverage Score: 197.86\n",
840
+ "Episode 500\tAverage Score: 201.48\n",
841
+ "Episode 500\tAverage Score: 205.15\n",
842
+ "Episode 500\tAverage Score: 208.76\n",
843
+ "Episode 500\tAverage Score: 212.41\n",
844
+ "Episode 500\tAverage Score: 216.13\n",
845
+ "Episode 500\tAverage Score: 219.72\n",
846
+ "Episode 500\tAverage Score: 223.56\n",
847
+ "Episode 500\tAverage Score: 227.23\n",
848
+ "Episode 500\tAverage Score: 230.90\n",
849
+ "Episode 500\tAverage Score: 234.61\n",
850
+ "Episode 500\tAverage Score: 238.32\n",
851
+ "Episode 500\tAverage Score: 241.99\n",
852
+ "Episode 500\tAverage Score: 245.78\n",
853
+ "Episode 500\tAverage Score: 249.43\n",
854
+ "Episode 500\tAverage Score: 253.18\n",
855
+ "Episode 500\tAverage Score: 256.85\n",
856
+ "Episode 500\tAverage Score: 260.43\n",
857
+ "Episode 500\tAverage Score: 263.94\n",
858
+ "Episode 500\tAverage Score: 267.68\n",
859
+ "Episode 500\tAverage Score: 271.27\n",
860
+ "Episode 500\tAverage Score: 274.87\n",
861
+ "Episode 500\tAverage Score: 278.51\n",
862
+ "Episode 500\tAverage Score: 282.18\n",
863
+ "Episode 500\tAverage Score: 285.67\n",
864
+ "Episode 500\tAverage Score: 289.04\n",
865
+ "Episode 500\tAverage Score: 292.48\n",
866
+ "Episode 500\tAverage Score: 295.88\n",
867
+ "Episode 500\tAverage Score: 299.61\n",
868
+ "Episode 500\tAverage Score: 302.84\n",
869
+ "Episode 500\tAverage Score: 305.97\n",
870
+ "Episode 500\tAverage Score: 309.13\n",
871
+ "Episode 500\tAverage Score: 312.46\n",
872
+ "Episode 500\tAverage Score: 315.80\n",
873
+ "Episode 500\tAverage Score: 319.12\n",
874
+ "Episode 500\tAverage Score: 321.31\n",
875
+ "Episode 500\tAverage Score: 324.54\n",
876
+ "Episode 500\tAverage Score: 327.67\n",
877
+ "Episode 500\tAverage Score: 330.83\n",
878
+ "Episode 500\tAverage Score: 333.27\n",
879
+ "Episode 500\tAverage Score: 336.25\n",
880
+ "Episode 500\tAverage Score: 339.31\n",
881
+ "Episode 500\tAverage Score: 342.54\n"
882
+ ]
883
+ }
884
+ ],
885
+ "source": [
886
+ "scores = reinforce(\n",
887
+ " cartpole_policy,\n",
888
+ " cartpole_optimizer,\n",
889
+ " cartpole_hyperparameters[\"n_training_episodes\"],\n",
890
+ " cartpole_hyperparameters[\"max_t\"],\n",
891
+ " cartpole_hyperparameters[\"gamma\"],\n",
892
+ " 100,\n",
893
+ ")"
894
+ ]
895
+ },
896
+ {
897
+ "cell_type": "code",
898
+ "execution_count": 19,
899
+ "metadata": {},
900
+ "outputs": [],
901
+ "source": [
902
+ "def evaluate_agent(env, max_steps, n_eval_episodes, policy):\n",
903
+ " \"\"\"\n",
904
+ "    Evaluate the agent for ``n_eval_episodes`` episodes and return the mean and standard deviation of the episode rewards.\n",
905
+ " :param env: The evaluation environment\n",
906
+ "    :param n_eval_episodes: Number of episodes to evaluate the agent\n",
907
+ " :param policy: The Reinforce agent\n",
908
+ " \"\"\"\n",
909
+ " episode_rewards = []\n",
910
+ " for episode in range(n_eval_episodes):\n",
911
+ " state = env.reset()\n",
912
+ " step = 0\n",
913
+ " done = False\n",
914
+ " total_rewards_ep = 0\n",
915
+ "\n",
916
+ " for step in range(max_steps):\n",
917
+ " action, _ = policy.act(state)\n",
918
+ " new_state, reward, done, info = env.step(action)\n",
919
+ " total_rewards_ep += reward\n",
920
+ "\n",
921
+ " if done:\n",
922
+ " break\n",
923
+ " state = new_state\n",
924
+ " episode_rewards.append(total_rewards_ep)\n",
925
+ " mean_reward = np.mean(episode_rewards)\n",
926
+ " std_reward = np.std(episode_rewards)\n",
927
+ "\n",
928
+ " return mean_reward, std_reward"
929
+ ]
930
+ },
931
+ {
932
+ "cell_type": "code",
933
+ "execution_count": 35,
934
+ "metadata": {},
935
+ "outputs": [
936
+ {
937
+ "data": {
938
+ "text/plain": [
939
+ "(448.7, 65.16141496315132)"
940
+ ]
941
+ },
942
+ "execution_count": 35,
943
+ "metadata": {},
944
+ "output_type": "execute_result"
945
+ }
946
+ ],
947
+ "source": [
948
+ "evaluate_agent(\n",
949
+ " eval_env, cartpole_hyperparameters[\"max_t\"], cartpole_hyperparameters[\"n_evaluation_episodes\"], cartpole_policy\n",
950
+ ")"
951
+ ]
952
+ },
953
+ {
954
+ "cell_type": "code",
955
+ "execution_count": 21,
956
+ "metadata": {},
957
+ "outputs": [],
958
+ "source": [
959
+ "from huggingface_hub import HfApi, snapshot_download\n",
960
+ "from huggingface_hub.repocard import metadata_eval_result, metadata_save\n",
961
+ "\n",
962
+ "from pathlib import Path\n",
963
+ "import datetime\n",
964
+ "import json\n",
965
+ "import imageio\n",
966
+ "\n",
967
+ "import tempfile\n",
968
+ "\n",
969
+ "import os"
970
+ ]
971
+ },
972
+ {
973
+ "cell_type": "code",
974
+ "execution_count": 22,
975
+ "metadata": {},
976
+ "outputs": [],
977
+ "source": [
978
+ "def record_video(env, policy, out_directory, fps=30):\n",
979
+ " \"\"\"\n",
980
+ " Generate a replay video of the agent\n",
981
+ " :param env\n",
982
+ "    :param policy: the trained policy; actions are obtained from policy.act(state)\n",
983
+ " :param out_directory\n",
984
+ "    :param fps: frames per second of the output video (with taxi-v3 and frozenlake-v1 we use 1)\n",
985
+ " \"\"\"\n",
986
+ " images = []\n",
987
+ " done = False\n",
988
+ " state = env.reset()\n",
989
+ " img = env.render(mode=\"rgb_array\")\n",
990
+ " images.append(img)\n",
991
+ " while not done:\n",
992
+ "        # Ask the policy for an action given the current state\n",
993
+ " action, _ = policy.act(state)\n",
994
+ " state, reward, done, info = env.step(action) # We directly put next_state = state for recording logic\n",
995
+ " img = env.render(mode=\"rgb_array\")\n",
996
+ " images.append(img)\n",
997
+ " imageio.mimsave(out_directory, [np.array(img) for i, img in enumerate(images)], fps=fps)"
998
+ ]
999
+ },
1000
+ {
1001
+ "cell_type": "code",
1002
+ "execution_count": 23,
1003
+ "metadata": {},
1004
+ "outputs": [],
1005
+ "source": [
1006
+ "from huggingface_hub import HfApi, snapshot_download\n",
1007
+ "from huggingface_hub.repocard import metadata_eval_result, metadata_save\n",
1008
+ "\n",
1009
+ "from pathlib import Path\n",
1010
+ "import datetime\n",
1011
+ "import json\n",
1012
+ "import imageio\n",
1013
+ "\n",
1014
+ "import tempfile\n",
1015
+ "\n",
1016
+ "import os"
1017
+ ]
1018
+ },
1019
+ {
1020
+ "cell_type": "code",
1021
+ "execution_count": 29,
1022
+ "metadata": {},
1023
+ "outputs": [],
1024
+ "source": [
1025
+ "def push_to_hub(repo_id,\n",
1026
+ " model,\n",
1027
+ " hyperparameters,\n",
1028
+ " eval_env,\n",
1029
+ " video_fps=30\n",
1030
+ " ):\n",
1031
+ " \"\"\"\n",
1032
+ " Evaluate, Generate a video and Upload a model to Hugging Face Hub.\n",
1033
+ " This method does the complete pipeline:\n",
1034
+ " - It evaluates the model\n",
1035
+ " - It generates the model card\n",
1036
+ " - It generates a replay video of the agent\n",
1037
+ " - It pushes everything to the Hub\n",
1038
+ "\n",
1039
+ "    :param repo_id: id of the model repository on the Hugging Face Hub, in the form username/repo_name\n",
1040
+ " :param model: the pytorch model we want to save\n",
1041
+ " :param hyperparameters: training hyperparameters\n",
1042
+ " :param eval_env: evaluation environment\n",
1043
+ "    :param video_fps: frames per second used when recording the replay video\n",
1044
+ " \"\"\"\n",
1045
+ "\n",
1046
+ " _, repo_name = repo_id.split(\"/\")\n",
1047
+ " api = HfApi()\n",
1048
+ "\n",
1049
+ " # Step 1: Create the repo\n",
1050
+ " repo_url = api.create_repo(\n",
1051
+ " repo_id=repo_id,\n",
1052
+ " exist_ok=True,\n",
1053
+ " )\n",
1054
+ "\n",
1055
+ " with tempfile.TemporaryDirectory() as tmpdirname:\n",
1056
+ "        local_directory = Path(tmpdirname)\n",
1057
+ "\n",
1058
+ " # Step 2: Save the model\n",
1059
+ " torch.save(model, local_directory / \"model.pt\")\n",
1060
+ "\n",
1061
+ " # Step 3: Save the hyperparameters to JSON\n",
1062
+ " with open(local_directory / \"hyperparameters.json\", \"w\") as outfile:\n",
1063
+ " json.dump(hyperparameters, outfile)\n",
1064
+ "\n",
1065
+ " # Step 4: Evaluate the model and build JSON\n",
1066
+ " mean_reward, std_reward = evaluate_agent(eval_env,\n",
1067
+ " hyperparameters[\"max_t\"],\n",
1068
+ " hyperparameters[\"n_evaluation_episodes\"],\n",
1069
+ " model)\n",
1070
+ " # Get datetime\n",
1071
+ " eval_datetime = datetime.datetime.now()\n",
1072
+ " eval_form_datetime = eval_datetime.isoformat()\n",
1073
+ "\n",
1074
+ " evaluate_data = {\n",
1075
+ " \"env_id\": hyperparameters[\"env_id\"],\n",
1076
+ " \"mean_reward\": mean_reward,\n",
1077
+ " \"n_evaluation_episodes\": hyperparameters[\"n_evaluation_episodes\"],\n",
1078
+ " \"eval_datetime\": eval_form_datetime,\n",
1079
+ " }\n",
1080
+ "\n",
1081
+ " # Write a JSON file\n",
1082
+ " with open(local_directory / \"results.json\", \"w\") as outfile:\n",
1083
+ " json.dump(evaluate_data, outfile)\n",
1084
+ "\n",
1085
+ " # Step 5: Create the model card\n",
1086
+ " env_name = hyperparameters[\"env_id\"]\n",
1087
+ "\n",
1088
+ " metadata = {}\n",
1089
+ " metadata[\"tags\"] = [\n",
1090
+ " env_name,\n",
1091
+ " \"reinforce\",\n",
1092
+ " \"reinforcement-learning\",\n",
1093
+ " \"custom-implementation\",\n",
1094
+ " \"deep-rl-class\"\n",
1095
+ " ]\n",
1096
+ "\n",
1097
+ " # Add metrics\n",
1098
+ " eval = metadata_eval_result(\n",
1099
+ " model_pretty_name=repo_name,\n",
1100
+ " task_pretty_name=\"reinforcement-learning\",\n",
1101
+ " task_id=\"reinforcement-learning\",\n",
1102
+ " metrics_pretty_name=\"mean_reward\",\n",
1103
+ " metrics_id=\"mean_reward\",\n",
1104
+ " metrics_value=f\"{mean_reward:.2f} +/- {std_reward:.2f}\",\n",
1105
+ " dataset_pretty_name=env_name,\n",
1106
+ " dataset_id=env_name,\n",
1107
+ " )\n",
1108
+ "\n",
1109
+ " # Merges both dictionaries\n",
1110
+ " metadata = {**metadata, **eval}\n",
1111
+ "\n",
1112
+ " model_card = f\"\"\"\n",
1113
+ "  # **Reinforce** Agent playing **{env_name}**\n",
1113
+ "  This is a trained model of a **Reinforce** agent playing **{env_name}** .\n",
1115
+ " To learn to use this model and train yours check Unit 4 of the Deep Reinforcement Learning Course: https://huggingface.co/deep-rl-course/unit4/introduction\n",
1116
+ " \"\"\"\n",
1117
+ "\n",
1118
+ " readme_path = local_directory / \"README.md\"\n",
1119
+ " readme = \"\"\n",
1120
+ " if readme_path.exists():\n",
1121
+ " with readme_path.open(\"r\", encoding=\"utf8\") as f:\n",
1122
+ " readme = f.read()\n",
1123
+ " else:\n",
1124
+ " readme = model_card\n",
1125
+ "\n",
1126
+ " with readme_path.open(\"w\", encoding=\"utf-8\") as f:\n",
1127
+ " f.write(readme)\n",
1128
+ "\n",
1129
+ " # Save our metrics to Readme metadata\n",
1130
+ " metadata_save(readme_path, metadata)\n",
1131
+ "\n",
1132
+ " # Step 6: Record a video\n",
1133
+ " video_path = local_directory / \"replay.mp4\"\n",
1134
+ "        record_video(eval_env, model, video_path, video_fps)\n",
1135
+ "\n",
1136
+ " # Step 7. Push everything to the Hub\n",
1137
+ " api.upload_folder(\n",
1138
+ " repo_id=repo_id,\n",
1139
+ " folder_path=local_directory,\n",
1140
+ " path_in_repo=\".\",\n",
1141
+ " )\n",
1142
+ "\n",
1143
+ " print(f\"Your model is pushed to the Hub. You can view your model here: {repo_url}\")"
1144
+ ]
1145
+ },
1146
+ {
1147
+ "cell_type": "code",
1148
+ "execution_count": 31,
1149
+ "metadata": {},
1150
+ "outputs": [
1151
+ {
1152
+ "name": "stdout",
1153
+ "output_type": "stream",
1154
+ "text": [
1155
+ "Token is valid.\n",
1156
+ "Your token has been saved in your configured git credential helpers (store).\n",
1157
+ "Your token has been saved to /home/hanbk/.cache/huggingface/token\n",
1158
+ "Login successful\n"
1159
+ ]
1160
+ }
1161
+ ],
1162
+ "source": [
1163
+ "notebook_login()"
1164
+ ]
1165
+ },
1166
+ {
1167
+ "cell_type": "code",
1168
+ "execution_count": 36,
1169
+ "metadata": {},
1170
+ "outputs": [
1171
+ {
1172
+ "name": "stderr",
1173
+ "output_type": "stream",
1174
+ "text": [
1175
+ "IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (600, 400) to (608, 400) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).\n",
1176
+ "[swscaler @ 0x7313080] Warning: data is not aligned! This can lead to a speed loss\n"
1177
+ ]
1178
+ },
1179
+ {
1180
+ "name": "stdout",
1181
+ "output_type": "stream",
1182
+ "text": [
1183
+ "Your model is pushed to the Hub. You can view your model here: https://huggingface.co/bkhan2000/Reinforce-CartPole-v1\n"
1184
+ ]
1185
+ }
1186
+ ],
1187
+ "source": [
1188
+ "repo_id = f\"bkhan2000/Reinforce-{env_id}\" # TODO Define your repo id {username/Reinforce-{model-id}}\n",
1189
+ "push_to_hub(\n",
1190
+ " repo_id,\n",
1191
+ " cartpole_policy, # The model we want to save\n",
1192
+ " cartpole_hyperparameters, # Hyperparameters\n",
1193
+ " eval_env, # Evaluation environment\n",
1194
+ " video_fps=30\n",
1195
+ ")"
1196
+ ]
1197
+ },
1198
+ {
1199
+ "cell_type": "markdown",
1200
+ "metadata": {},
1201
+ "source": [
1202
+ "### PixelCopter"
1203
+ ]
1204
+ },
1205
+ {
1206
+ "cell_type": "code",
1207
+ "execution_count": 37,
1208
+ "metadata": {},
1209
+ "outputs": [
1210
+ {
1211
+ "name": "stdout",
1212
+ "output_type": "stream",
1213
+ "text": [
1214
+ "pygame 2.1.3 (SDL 2.0.22, Python 3.8.10)\n",
1215
+ "Hello from the pygame community. https://www.pygame.org/contribute.html\n",
1216
+ "couldn't import doomish\n",
1217
+ "Couldn't import doom\n"
1218
+ ]
1219
+ }
1220
+ ],
1221
+ "source": [
1222
+ "env_id = \"Pixelcopter-PLE-v0\"\n",
1223
+ "env = gym.make(env_id)\n",
1224
+ "eval_env = gym.make(env_id)\n",
1225
+ "s_size = env.observation_space.shape[0]\n",
1226
+ "a_size = env.action_space.n"
1227
+ ]
1228
+ },
1229
+ {
1230
+ "cell_type": "code",
1231
+ "execution_count": 38,
1232
+ "metadata": {},
1233
+ "outputs": [
1234
+ {
1235
+ "name": "stdout",
1236
+ "output_type": "stream",
1237
+ "text": [
1238
+ "_____OBSERVATION SPACE_____ \n",
1239
+ "\n",
1240
+ "The State Space is: 7\n",
1241
+ "Sample observation [ 0.9645765 -1.6262507 0.25693664 0.18892749 2.2655454 0.37077877\n",
1242
+ " 1.3749579 ]\n"
1243
+ ]
1244
+ }
1245
+ ],
1246
+ "source": [
1247
+ "print(\"_____OBSERVATION SPACE_____ \\n\")\n",
1248
+ "print(\"The State Space is: \", s_size)\n",
1249
+ "print(\"Sample observation\", env.observation_space.sample()) # Get a random observation"
1250
+ ]
1251
+ },
1252
+ {
1253
+ "cell_type": "code",
1254
+ "execution_count": 39,
1255
+ "metadata": {},
1256
+ "outputs": [
1257
+ {
1258
+ "name": "stdout",
1259
+ "output_type": "stream",
1260
+ "text": [
1261
+ "\n",
1262
+ " _____ACTION SPACE_____ \n",
1263
+ "\n",
1264
+ "The Action Space is: 2\n",
1265
+ "Action Space Sample 0\n"
1266
+ ]
1267
+ }
1268
+ ],
1269
+ "source": [
1270
+ "print(\"\\n _____ACTION SPACE_____ \\n\")\n",
1271
+ "print(\"The Action Space is: \", a_size)\n",
1272
+ "print(\"Action Space Sample\", env.action_space.sample()) # Take a random action"
1273
+ ]
1274
+ },
1275
+ {
1276
+ "cell_type": "code",
1277
+ "execution_count": 40,
1278
+ "metadata": {},
1279
+ "outputs": [],
1280
+ "source": [
1281
+ "class Policy(nn.Module):\n",
1282
+ " def __init__(self, s_size, a_size, h_size):\n",
1283
+ " super(Policy, self).__init__()\n",
1284
+ " self.fc1 = nn.Linear(s_size, h_size)\n",
1285
+ " self.fc2 = nn.Linear(h_size, h_size * 2)\n",
1286
+ " self.fc3 = nn.Linear(h_size * 2, a_size)\n",
1287
+ "\n",
1288
+ " def forward(self, x):\n",
1289
+ " x = F.relu(self.fc1(x))\n",
1290
+ " x = F.relu(self.fc2(x))\n",
1291
+ " x = self.fc3(x)\n",
1292
+ " return F.softmax(x, dim=1)\n",
1293
+ "\n",
1294
+ " def act(self, state):\n",
1295
+ " state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n",
1296
+ " probs = self.forward(state).cpu()\n",
1297
+ " m = Categorical(probs)\n",
1298
+ " action = m.sample()\n",
1299
+ " return action.item(), m.log_prob(action)"
1300
+ ]
1301
+ },
1302
+ {
1303
+ "cell_type": "code",
1304
+ "execution_count": 41,
1305
+ "metadata": {},
1306
+ "outputs": [],
1307
+ "source": [
1308
+ "pixelcopter_hyperparameters = {\n",
1309
+ " \"h_size\": 64,\n",
1310
+ " \"n_training_episodes\": 50000,\n",
1311
+ " \"n_evaluation_episodes\": 10,\n",
1312
+ " \"max_t\": 10000,\n",
1313
+ " \"gamma\": 0.99,\n",
1314
+ " \"lr\": 1e-4,\n",
1315
+ " \"env_id\": env_id,\n",
1316
+ " \"state_space\": s_size,\n",
1317
+ " \"action_space\": a_size,\n",
1318
+ "}"
1319
+ ]
1320
+ },
1321
+ {
1322
+ "cell_type": "code",
1323
+ "execution_count": 42,
1324
+ "metadata": {},
1325
+ "outputs": [],
1326
+ "source": [
1327
+ "pixelcopter_policy = Policy(\n",
1328
+ " pixelcopter_hyperparameters[\"state_space\"],\n",
1329
+ " pixelcopter_hyperparameters[\"action_space\"],\n",
1330
+ " pixelcopter_hyperparameters[\"h_size\"],\n",
1331
+ ").to(device)\n",
1332
+ "pixelcopter_optimizer = optim.Adam(pixelcopter_policy.parameters(), lr=pixelcopter_hyperparameters[\"lr\"])"
1333
+ ]
1334
+ },
1335
+ {
1336
+ "cell_type": "code",
1337
+ "execution_count": 43,
1338
+ "metadata": {},
1339
+ "outputs": [],
1340
+ "source": [
1341
+ "scores = reinforce(\n",
1342
+ " pixelcopter_policy,\n",
1343
+ " pixelcopter_optimizer,\n",
1344
+ " pixelcopter_hyperparameters[\"n_training_episodes\"],\n",
1345
+ " pixelcopter_hyperparameters[\"max_t\"],\n",
1346
+ " pixelcopter_hyperparameters[\"gamma\"],\n",
1347
+ " 1000,\n",
1348
+ ")"
1349
+ ]
1350
+ },
1351
+ {
1352
+ "cell_type": "code",
1353
+ "execution_count": null,
1354
+ "metadata": {},
1355
+ "outputs": [],
1356
+ "source": [
1357
+ "repo_id = f\"bkhan2000/Reinforce-{env_id}\" # TODO Define your repo id {username/Reinforce-{model-id}}\n",
1358
+ "push_to_hub(\n",
1359
+ " repo_id,\n",
1360
+ " pixelcopter_policy, # The model we want to save\n",
1361
+ " pixelcopter_hyperparameters, # Hyperparameters\n",
1362
+ " eval_env, # Evaluation environment\n",
1363
+ " video_fps=30\n",
1364
+ ")"
1365
+ ]
1366
+ }
1367
+ ],
1368
+ "metadata": {
1369
+ "kernelspec": {
1370
+ "display_name": "Python 3.8.10 ('torch_venv')",
1371
+ "language": "python",
1372
+ "name": "python3"
1373
+ },
1374
+ "language_info": {
1375
+ "codemirror_mode": {
1376
+ "name": "ipython",
1377
+ "version": 3
1378
+ },
1379
+ "file_extension": ".py",
1380
+ "mimetype": "text/x-python",
1381
+ "name": "python",
1382
+ "nbconvert_exporter": "python",
1383
+ "pygments_lexer": "ipython3",
1384
+ "version": "3.8.10"
1385
+ },
1386
+ "orig_nbformat": 4,
1387
+ "vscode": {
1388
+ "interpreter": {
1389
+ "hash": "745a3b3e3fb7ac09f0ebb6d5eb47d006584e16db5d9df6f9a8b654baa561b29f"
1390
+ }
1391
+ }
1392
+ },
1393
+ "nbformat": 4,
1394
+ "nbformat_minor": 2
1395
+ }
README.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - Pixelcopter-PLE-v0
4
+ - reinforce
5
+ - reinforcement-learning
6
+ - custom-implementation
7
+ - deep-rl-class
8
+ model-index:
9
+ - name: Reinforce-Pixelcopter-PLE-v0
10
+ results:
11
+ - task:
12
+ type: reinforcement-learning
13
+ name: reinforcement-learning
14
+ dataset:
15
+ name: Pixelcopter-PLE-v0
16
+ type: Pixelcopter-PLE-v0
17
+ metrics:
18
+ - type: mean_reward
19
+ value: 105.50 +/- 80.81
20
+ name: mean_reward
21
+ verified: false
22
+ ---
23
+
24
+   # **Reinforce** Agent playing **Pixelcopter-PLE-v0**
25
+   This is a trained model of a **Reinforce** agent playing **Pixelcopter-PLE-v0** .
26
+ To learn to use this model and train yours check Unit 4 of the Deep Reinforcement Learning Course: https://huggingface.co/deep-rl-course/unit4/introduction
27
+
hyperparameters.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"h_size": 64, "n_training_episodes": 50000, "n_evaluation_episodes": 10, "max_t": 10000, "gamma": 0.99, "lr": 0.0001, "env_id": "Pixelcopter-PLE-v0", "state_space": 7, "action_space": 2}
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b840e9fd1147ee00e9fa15ea9e02b251f127c981ba5a16ed3f54d332f7146666
3
+ size 38999
replay.mp4 ADDED
Binary file (29.3 kB). View file
 
results.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"env_id": "Pixelcopter-PLE-v0", "mean_reward": 105.5, "n_evaluation_episodes": 10, "eval_datetime": "2023-03-06T14:58:27.599581"}