{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Virtual display\n", "from pyvirtualdisplay import Display\n", "\n", "virtual_display = Display(visible=0, size=(1400, 900))\n", "virtual_display.start()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from collections import deque\n", "\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "# PyTorch\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.distributions import Categorical\n", "\n", "# Gym\n", "import gym\n", "import gym_pygame\n", "\n", "# Hugging Face Hub\n", "from huggingface_hub import notebook_login # To log to our Hugging Face account to be able to upload models to the Hub.\n", "import imageio\n", "# imageio: A library that will help us to generate a replay video" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda:0\n" ] } ], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cartpole-v1" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "env_id = \"CartPole-v1\"\n", "env = gym.make(env_id)\n", "\n", "# evaluation env\n", "eval_env = gym.make(env_id)\n", "\n", "s_size = env.observation_space.shape[0]\n", "a_size = env.action_space.n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_____OBSERVATION SPACE_____ \n", "\n", "The State Space is: 4\n", "Sample observation [-2.6818509e+00 2.6710869e+38 -2.7456334e-01 4.6941264e+37]\n" ] } ], "source": [ "print(\"_____OBSERVATION SPACE_____ \\n\")\n", "print(\"The State Space is: \", s_size)\n", "print(\"Sample observation\", env.observation_space.sample()) # Get a random observation" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " _____ACTION SPACE_____ \n", "\n", "The Action Space is: 2\n", "Action Space Sample 0\n" ] } ], "source": [ "print(\"\\n _____ACTION SPACE_____ \\n\")\n", "print(\"The Action Space is: \", a_size)\n", "print(\"Action Space Sample\", env.action_space.sample()) # Take a random action" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reinforce Archtecture" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "class Policy(nn.Module):\n", " def __init__(self, s_size, a_size, h_size):\n", " super(Policy, self).__init__()\n", " self.fc1 = nn.Linear(s_size, h_size)\n", " self.fc2 = nn.Linear(h_size, a_size)\n", " \n", " def forward(self, x):\n", " x = F.relu(self.fc1(x))\n", " x = self.fc2(x)\n", " return F.softmax(x, dim=1)\n", "\n", " def act(self, state):\n", " state = torch.from_numpy(state).float().unsqueeze(0).to(device)\n", " probs = self.forward(state).cpu()\n", " m = Categorical(probs)\n", " # action = np.argmax(m)\n", " action = m.sample()\n", " return action.item(), m.log_prob(action)\n", " " ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1, tensor([-0.7983], grad_fn=))" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "debug_policy = Policy(s_size, a_size, 64).to(device)\n", "debug_policy.act(env.reset())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def reinforce(policy, optimizer, n_training_episodes, max_t, gamma, print_every):\n", " scores_deque = deque(maxlen=100)\n", " scores = []\n", "\n", " # Line 3 of pseudocode\n", " for i_episodes in range(1, n_training_episodes+1):\n", " saved_log_probs = []\n", " rewards = []\n", " state = env.reset()\n", "\n", " # Line 4 of pseudocode\n", " for i_episode in range(1, n_training_episodes):\n", " action, log_prob = policy.act(state)\n", " saved_log_probs.append(log_prob)\n", " state, reward, done, _ = env.step(action)\n", " rewards.append(reward)\n", " if done:\n", " break\n", " scores_deque.append(sum(rewards))\n", " scores.append(sum(rewards))\n", "\n", " # Line 6 of pseudocode\n", " returns = deque(maxlen=max_t)\n", " n_steps = len(rewards)\n", "\n", " for t in range(n_steps)[::-1]:\n", " disc_return_t = (returns[0] if len(returns)>0 else 0)\n", " returns.appendleft(gamma * disc_return_t + rewards[t])\n", "\n", " eps = np.finfo(np.float32).eps.item()\n", "\n", " returns = torch.tensor(returns)\n", " returns = (returns - returns.mean()) / (returns.std() + eps)\n", "\n", " # Line 7\n", " policy_loss = []\n", " for log_prob, disc_return in zip(saved_log_probs, returns):\n", " policy_loss.append(-log_prob * disc_return)\n", " policy_loss = torch.cat(policy_loss).sum()\n", "\n", " # Line 8\n", " optimizer.zero_grad()\n", " policy_loss.backward()\n", " optimizer.step()\n", "\n", " if i_episode % print_every == 0:\n", " print(\"Episode {}\\tAverage Score: {:.2f}\".format(i_episode, np.mean(scores_deque)))\n", "\n", " return scores" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "cartpole_hyperparameters = {\n", " \"h_size\": 16,\n", " \"n_training_episodes\": 1000,\n", " \"n_evaluation_episodes\": 10,\n", " \"max_t\": 1000,\n", " \"gamma\": 1.0,\n", " \"lr\": 1e-2,\n", " \"env_id\": env_id,\n", " \"state_space\": s_size,\n", " \"action_space\": a_size,\n", "}" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "cartpole_policy = Policy(\n", " cartpole_hyperparameters[\"state_space\"],\n", " cartpole_hyperparameters[\"action_space\"],\n", " cartpole_hyperparameters[\"h_size\"],\n", ").to(device)\n", "cartpole_optimizer = optim.Adam(cartpole_policy.parameters(), lr=cartpole_hyperparameters[\"lr\"])" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Episode 500\tAverage Score: 116.93\n", "Episode 500\tAverage Score: 134.13\n", "Episode 500\tAverage Score: 138.92\n", "Episode 500\tAverage Score: 143.73\n", "Episode 500\tAverage Score: 150.68\n", "Episode 500\tAverage Score: 154.91\n", "Episode 500\tAverage Score: 159.05\n", "Episode 500\tAverage Score: 163.41\n", "Episode 500\tAverage Score: 167.91\n", "Episode 500\tAverage Score: 172.49\n", "Episode 500\tAverage Score: 176.90\n", "Episode 500\tAverage Score: 181.63\n", "Episode 500\tAverage Score: 185.66\n", "Episode 500\tAverage Score: 190.18\n", "Episode 500\tAverage Score: 194.90\n", "Episode 500\tAverage Score: 199.15\n", "Episode 500\tAverage Score: 203.89\n", "Episode 500\tAverage Score: 208.33\n", "Episode 500\tAverage Score: 212.64\n", "Episode 500\tAverage Score: 217.48\n", "Episode 500\tAverage Score: 221.51\n", "Episode 500\tAverage Score: 226.20\n", "Episode 500\tAverage Score: 230.63\n", "Episode 500\tAverage Score: 235.21\n", "Episode 500\tAverage Score: 243.17\n", "Episode 500\tAverage Score: 250.87\n", "Episode 500\tAverage Score: 254.48\n", "Episode 500\tAverage Score: 258.01\n", "Episode 500\tAverage Score: 262.76\n", "Episode 500\tAverage Score: 267.27\n", "Episode 500\tAverage Score: 271.85\n", "Episode 500\tAverage Score: 275.57\n", "Episode 500\tAverage Score: 281.62\n", "Episode 500\tAverage Score: 284.87\n", "Episode 500\tAverage Score: 289.12\n", "Episode 500\tAverage Score: 295.51\n", "Episode 500\tAverage Score: 299.59\n", "Episode 500\tAverage Score: 303.39\n", "Episode 500\tAverage Score: 310.17\n", "Episode 500\tAverage Score: 313.95\n", "Episode 500\tAverage Score: 317.26\n", "Episode 500\tAverage Score: 318.30\n", "Episode 500\tAverage Score: 322.61\n", "Episode 500\tAverage Score: 327.74\n", "Episode 500\tAverage Score: 331.85\n", "Episode 500\tAverage Score: 335.04\n", "Episode 500\tAverage Score: 339.34\n", "Episode 500\tAverage Score: 343.40\n", "Episode 500\tAverage Score: 345.81\n", "Episode 500\tAverage Score: 348.98\n", "Episode 500\tAverage Score: 352.50\n", "Episode 500\tAverage Score: 356.47\n", "Episode 500\tAverage Score: 360.60\n", "Episode 500\tAverage Score: 364.78\n", "Episode 500\tAverage Score: 368.87\n", "Episode 500\tAverage Score: 372.04\n", "Episode 500\tAverage Score: 374.21\n", "Episode 500\tAverage Score: 376.52\n", "Episode 500\tAverage Score: 379.97\n", "Episode 500\tAverage Score: 382.65\n", "Episode 500\tAverage Score: 384.00\n", "Episode 500\tAverage Score: 386.29\n", "Episode 500\tAverage Score: 391.30\n", "Episode 500\tAverage Score: 394.40\n", "Episode 500\tAverage Score: 398.01\n", "Episode 500\tAverage Score: 400.75\n", "Episode 500\tAverage Score: 404.74\n", "Episode 500\tAverage Score: 408.86\n", "Episode 500\tAverage Score: 412.89\n", "Episode 500\tAverage Score: 417.54\n", "Episode 500\tAverage Score: 421.40\n", "Episode 500\tAverage Score: 425.71\n", "Episode 500\tAverage Score: 425.96\n", "Episode 500\tAverage Score: 430.19\n", "Episode 500\tAverage Score: 434.20\n", "Episode 500\tAverage Score: 434.40\n", "Episode 500\tAverage Score: 438.51\n", "Episode 500\tAverage Score: 441.44\n", "Episode 500\tAverage Score: 445.65\n", "Episode 500\tAverage Score: 448.57\n", "Episode 500\tAverage Score: 451.66\n", "Episode 500\tAverage Score: 455.92\n", "Episode 500\tAverage Score: 458.06\n", "Episode 500\tAverage Score: 460.77\n", "Episode 500\tAverage Score: 460.77\n", "Episode 500\tAverage Score: 462.53\n", "Episode 500\tAverage Score: 463.35\n", "Episode 500\tAverage Score: 465.71\n", "Episode 500\tAverage Score: 467.43\n", "Episode 500\tAverage Score: 471.61\n", "Episode 500\tAverage Score: 471.61\n", "Episode 500\tAverage Score: 471.61\n", "Episode 500\tAverage Score: 471.61\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 474.02\n", "Episode 500\tAverage Score: 469.92\n", "Episode 500\tAverage Score: 466.39\n", "Episode 500\tAverage Score: 470.74\n", "Episode 500\tAverage Score: 470.74\n", "Episode 500\tAverage Score: 472.07\n", "Episode 500\tAverage Score: 472.07\n", "Episode 500\tAverage Score: 476.40\n", "Episode 500\tAverage Score: 476.40\n", "Episode 500\tAverage Score: 476.40\n", "Episode 500\tAverage Score: 476.40\n", "Episode 500\tAverage Score: 476.40\n", "Episode 500\tAverage Score: 476.40\n", "Episode 500\tAverage Score: 476.40\n", "Episode 500\tAverage Score: 479.20\n", "Episode 500\tAverage Score: 475.32\n", "Episode 500\tAverage Score: 472.31\n", "Episode 500\tAverage Score: 472.31\n", "Episode 500\tAverage Score: 472.31\n", "Episode 500\tAverage Score: 470.49\n", "Episode 500\tAverage Score: 470.49\n", "Episode 500\tAverage Score: 470.49\n", "Episode 500\tAverage Score: 470.49\n", "Episode 500\tAverage Score: 466.40\n", "Episode 500\tAverage Score: 468.61\n", "Episode 500\tAverage Score: 468.61\n", "Episode 500\tAverage Score: 468.61\n", "Episode 500\tAverage Score: 468.61\n", "Episode 500\tAverage Score: 468.61\n", "Episode 500\tAverage Score: 468.61\n", "Episode 500\tAverage Score: 468.61\n", "Episode 500\tAverage Score: 468.61\n", "Episode 500\tAverage Score: 468.61\n", "Episode 500\tAverage Score: 472.51\n", "Episode 500\tAverage Score: 472.51\n", "Episode 500\tAverage Score: 472.51\n", "Episode 500\tAverage Score: 467.72\n", "Episode 500\tAverage Score: 467.72\n", "Episode 500\tAverage Score: 462.94\n", "Episode 500\tAverage Score: 462.94\n", "Episode 500\tAverage Score: 462.94\n", "Episode 500\tAverage Score: 462.94\n", "Episode 500\tAverage Score: 462.94\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 465.15\n", "Episode 500\tAverage Score: 469.25\n", "Episode 500\tAverage Score: 469.25\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 473.97\n", "Episode 500\tAverage Score: 477.85\n", "Episode 500\tAverage Score: 477.85\n", "Episode 500\tAverage Score: 482.59\n", "Episode 500\tAverage Score: 482.59\n", "Episode 500\tAverage Score: 482.59\n", "Episode 500\tAverage Score: 482.59\n", "Episode 500\tAverage Score: 486.34\n", "Episode 500\tAverage Score: 486.34\n", "Episode 500\tAverage Score: 486.34\n", "Episode 500\tAverage Score: 486.34\n", "Episode 500\tAverage Score: 486.34\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 490.43\n", "Episode 500\tAverage Score: 495.22\n", "Episode 500\tAverage Score: 495.22\n", "Episode 500\tAverage Score: 495.22\n", "Episode 500\tAverage Score: 500.00\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.42\n", "Episode 500\tAverage Score: 497.28\n", "Episode 500\tAverage Score: 497.28\n", "Episode 500\tAverage Score: 497.28\n", "Episode 500\tAverage Score: 497.28\n", "Episode 500\tAverage Score: 497.28\n", "Episode 500\tAverage Score: 497.28\n", "Episode 500\tAverage Score: 493.12\n", "Episode 500\tAverage Score: 493.12\n", "Episode 500\tAverage Score: 493.12\n", "Episode 500\tAverage Score: 493.12\n", "Episode 500\tAverage Score: 488.95\n", "Episode 500\tAverage Score: 488.95\n", "Episode 500\tAverage Score: 488.95\n", "Episode 500\tAverage Score: 488.95\n", "Episode 500\tAverage Score: 488.95\n", "Episode 500\tAverage Score: 488.95\n", "Episode 500\tAverage Score: 484.67\n", "Episode 500\tAverage Score: 484.67\n", "Episode 500\tAverage Score: 484.67\n", "Episode 500\tAverage Score: 484.67\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 480.52\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 476.39\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 478.97\n", "Episode 500\tAverage Score: 479.11\n", "Episode 500\tAverage Score: 479.11\n", "Episode 500\tAverage Score: 479.11\n", "Episode 500\tAverage Score: 479.11\n", "Episode 500\tAverage Score: 479.11\n", "Episode 500\tAverage Score: 479.11\n", "Episode 500\tAverage Score: 479.11\n", "Episode 500\tAverage Score: 483.27\n", "Episode 500\tAverage Score: 483.27\n", "Episode 500\tAverage Score: 483.27\n", "Episode 500\tAverage Score: 483.27\n", "Episode 500\tAverage Score: 483.27\n", "Episode 500\tAverage Score: 487.44\n", "Episode 500\tAverage Score: 487.44\n", "Episode 500\tAverage Score: 487.44\n", "Episode 500\tAverage Score: 487.44\n", "Episode 500\tAverage Score: 487.44\n", "Episode 300\tAverage Score: 485.44\n", "Episode 500\tAverage Score: 485.44\n", "Episode 500\tAverage Score: 489.72\n", "Episode 500\tAverage Score: 489.72\n", "Episode 500\tAverage Score: 489.72\n", "Episode 500\tAverage Score: 489.72\n", "Episode 500\tAverage Score: 489.72\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 493.87\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 498.00\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 493.24\n", "Episode 500\tAverage Score: 488.47\n", "Episode 500\tAverage Score: 483.65\n", "Episode 500\tAverage Score: 483.65\n", "Episode 500\tAverage Score: 483.65\n", "Episode 500\tAverage Score: 483.65\n", "Episode 500\tAverage Score: 466.97\n", "Episode 500\tAverage Score: 460.99\n", "Episode 500\tAverage Score: 460.99\n", "Episode 500\tAverage Score: 460.99\n", "Episode 500\tAverage Score: 456.25\n", "Episode 500\tAverage Score: 456.25\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 500\tAverage Score: 451.43\n", "Episode 200\tAverage Score: 148.79\n", "Episode 200\tAverage Score: 157.96\n", "Episode 500\tAverage Score: 190.64\n", "Episode 500\tAverage Score: 194.26\n", "Episode 500\tAverage Score: 197.86\n", "Episode 500\tAverage Score: 201.48\n", "Episode 500\tAverage Score: 205.15\n", "Episode 500\tAverage Score: 208.76\n", "Episode 500\tAverage Score: 212.41\n", "Episode 500\tAverage Score: 216.13\n", "Episode 500\tAverage Score: 219.72\n", "Episode 500\tAverage Score: 223.56\n", "Episode 500\tAverage Score: 227.23\n", "Episode 500\tAverage Score: 230.90\n", "Episode 500\tAverage Score: 234.61\n", "Episode 500\tAverage Score: 238.32\n", "Episode 500\tAverage Score: 241.99\n", "Episode 500\tAverage Score: 245.78\n", "Episode 500\tAverage Score: 249.43\n", "Episode 500\tAverage Score: 253.18\n", "Episode 500\tAverage Score: 256.85\n", "Episode 500\tAverage Score: 260.43\n", "Episode 500\tAverage Score: 263.94\n", "Episode 500\tAverage Score: 267.68\n", "Episode 500\tAverage Score: 271.27\n", "Episode 500\tAverage Score: 274.87\n", "Episode 500\tAverage Score: 278.51\n", "Episode 500\tAverage Score: 282.18\n", "Episode 500\tAverage Score: 285.67\n", "Episode 500\tAverage Score: 289.04\n", "Episode 500\tAverage Score: 292.48\n", "Episode 500\tAverage Score: 295.88\n", "Episode 500\tAverage Score: 299.61\n", "Episode 500\tAverage Score: 302.84\n", "Episode 500\tAverage Score: 305.97\n", "Episode 500\tAverage Score: 309.13\n", "Episode 500\tAverage Score: 312.46\n", "Episode 500\tAverage Score: 315.80\n", "Episode 500\tAverage Score: 319.12\n", "Episode 500\tAverage Score: 321.31\n", "Episode 500\tAverage Score: 324.54\n", "Episode 500\tAverage Score: 327.67\n", "Episode 500\tAverage Score: 330.83\n", "Episode 500\tAverage Score: 333.27\n", "Episode 500\tAverage Score: 336.25\n", "Episode 500\tAverage Score: 339.31\n", "Episode 500\tAverage Score: 342.54\n" ] } ], "source": [ "scores = reinforce(\n", " cartpole_policy,\n", " cartpole_optimizer,\n", " cartpole_hyperparameters[\"n_training_episodes\"],\n", " cartpole_hyperparameters[\"max_t\"],\n", " cartpole_hyperparameters[\"gamma\"],\n", " 100,\n", ")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "def evaluate_agent(env, max_steps, n_eval_episodes, policy):\n", " \"\"\"\n", " Evaluate the agent for ``n_eval_episodes`` episodes and returns average reward and std of reward.\n", " :param env: The evaluation environment\n", " :param n_eval_episodes: Number of episode to evaluate the agent\n", " :param policy: The Reinforce agent\n", " \"\"\"\n", " episode_rewards = []\n", " for episode in range(n_eval_episodes):\n", " state = env.reset()\n", " step = 0\n", " done = False\n", " total_rewards_ep = 0\n", "\n", " for step in range(max_steps):\n", " action, _ = policy.act(state)\n", " new_state, reward, done, info = env.step(action)\n", " total_rewards_ep += reward\n", "\n", " if done:\n", " break\n", " state = new_state\n", " episode_rewards.append(total_rewards_ep)\n", " mean_reward = np.mean(episode_rewards)\n", " std_reward = np.std(episode_rewards)\n", "\n", " return mean_reward, std_reward" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(439.8, 102.12325885908655)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "evaluate_agent(\n", " eval_env, cartpole_hyperparameters[\"max_t\"], cartpole_hyperparameters[\"n_evaluation_episodes\"], cartpole_policy\n", ")" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import HfApi, snapshot_download\n", "from huggingface_hub.repocard import metadata_eval_result, metadata_save\n", "\n", "from pathlib import Path\n", "import datetime\n", "import json\n", "import imageio\n", "\n", "import tempfile\n", "\n", "import os" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "def record_video(env, policy, out_directory, fps=30):\n", " \"\"\"\n", " Generate a replay video of the agent\n", " :param env\n", " :param Qtable: Qtable of our agent\n", " :param out_directory\n", " :param fps: how many frame per seconds (with taxi-v3 and frozenlake-v1 we use 1)\n", " \"\"\"\n", " images = []\n", " done = False\n", " state = env.reset()\n", " img = env.render(mode=\"rgb_array\")\n", " images.append(img)\n", " while not done:\n", " # Take the action (index) that have the maximum expected future reward given that state\n", " action, _ = policy.act(state)\n", " state, reward, done, info = env.step(action) # We directly put next_state = state for recording logic\n", " img = env.render(mode=\"rgb_array\")\n", " images.append(img)\n", " imageio.mimsave(out_directory, [np.array(img) for i, img in enumerate(images)], fps=fps)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import HfApi, snapshot_download\n", "from huggingface_hub.repocard import metadata_eval_result, metadata_save\n", "\n", "from pathlib import Path\n", "import datetime\n", "import json\n", "import imageio\n", "\n", "import tempfile\n", "\n", "import os" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ".\n" ] } ], "source": [ "def push_to_hub(repo_id,\n", " model,\n", " hyperparameters,\n", " eval_env,\n", " video_fps=30\n", " ):\n", " \"\"\"\n", " Evaluate, Generate a video and Upload a model to Hugging Face Hub.\n", " This method does the complete pipeline:\n", " - It evaluates the model\n", " - It generates the model card\n", " - It generates a replay video of the agent\n", " - It pushes everything to the Hub\n", "\n", " :param repo_id: repo_id: id of the model repository from the Hugging Face Hub\n", " :param model: the pytorch model we want to save\n", " :param hyperparameters: training hyperparameters\n", " :param eval_env: evaluation environment\n", " :param video_fps: how many frame per seconds to record our video replay\n", " \"\"\"\n", "\n", " _, repo_name = repo_id.split(\"/\")\n", " api = HfApi()\n", "\n", " # Step 1: Create the repo\n", " repo_url = api.create_repo(\n", " repo_id=repo_id,\n", " exist_ok=True,\n", " )\n", "\n", " with tempfile.TemporaryDirectory() as tmpdirname:\n", " local_directory = Path(\"./\")\n", "\n", " # Step 2: Save the model\n", " torch.save(model, local_directory / \"model.pt\")\n", "\n", " # Step 3: Save the hyperparameters to JSON\n", " with open(local_directory / \"hyperparameters.json\", \"w\") as outfile:\n", " json.dump(hyperparameters, outfile)\n", "\n", " # Step 4: Evaluate the model and build JSON\n", " mean_reward, std_reward = evaluate_agent(eval_env,\n", " hyperparameters[\"max_t\"],\n", " hyperparameters[\"n_evaluation_episodes\"],\n", " model)\n", " # Get datetime\n", " eval_datetime = datetime.datetime.now()\n", " eval_form_datetime = eval_datetime.isoformat()\n", "\n", " evaluate_data = {\n", " \"env_id\": hyperparameters[\"env_id\"],\n", " \"mean_reward\": mean_reward,\n", " \"n_evaluation_episodes\": hyperparameters[\"n_evaluation_episodes\"],\n", " \"eval_datetime\": eval_form_datetime,\n", " }\n", "\n", " # Write a JSON file\n", " with open(local_directory / \"results.json\", \"w\") as outfile:\n", " json.dump(evaluate_data, outfile)\n", "\n", " # Step 5: Create the model card\n", " env_name = hyperparameters[\"env_id\"]\n", "\n", " metadata = {}\n", " metadata[\"tags\"] = [\n", " env_name,\n", " \"reinforce\",\n", " \"reinforcement-learning\",\n", " \"custom-implementation\",\n", " \"deep-rl-class\"\n", " ]\n", "\n", " # Add metrics\n", " eval = metadata_eval_result(\n", " model_pretty_name=repo_name,\n", " task_pretty_name=\"reinforcement-learning\",\n", " task_id=\"reinforcement-learning\",\n", " metrics_pretty_name=\"mean_reward\",\n", " metrics_id=\"mean_reward\",\n", " metrics_value=f\"{mean_reward:.2f} +/- {std_reward:.2f}\",\n", " dataset_pretty_name=env_name,\n", " dataset_id=env_name,\n", " )\n", "\n", " # Merges both dictionaries\n", " metadata = {**metadata, **eval}\n", "\n", " model_card = f\"\"\"\n", " # **Reinforce** Agent playing **{env_id}**\n", " This is a trained model of a **Reinforce** agent playing **{env_id}** .\n", " To learn to use this model and train yours check Unit 4 of the Deep Reinforcement Learning Course: https://huggingface.co/deep-rl-course/unit4/introduction\n", " \"\"\"\n", "\n", " readme_path = local_directory / \"README.md\"\n", " readme = \"\"\n", " if readme_path.exists():\n", " with readme_path.open(\"r\", encoding=\"utf8\") as f:\n", " readme = f.read()\n", " else:\n", " readme = model_card\n", "\n", " with readme_path.open(\"w\", encoding=\"utf-8\") as f:\n", " f.write(readme)\n", "\n", " # Save our metrics to Readme metadata\n", " metadata_save(readme_path, metadata)\n", "\n", " # Step 6: Record a video\n", " video_path = local_directory / \"replay.mp4\"\n", " record_video(env, model, video_path, video_fps)\n", "\n", " # Step 7. Push everything to the Hub\n", " api.upload_folder(\n", " repo_id=repo_id,\n", " folder_path=local_directory,\n", " path_in_repo=\".\",\n", " )\n", "\n", " print(f\"Your model is pushed to the Hub. You can view your model here: {repo_url}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.10 ('torch_venv')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "745a3b3e3fb7ac09f0ebb6d5eb47d006584e16db5d9df6f9a8b654baa561b29f" } } }, "nbformat": 4, "nbformat_minor": 2 }