Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -0
- .python-version +1 -0
- .vscode/settings.json +10 -0
- README.md +0 -0
- poetry.lock +0 -0
- ppo-LunarLander-v2.pt +3 -0
- ppo.py +133 -0
- ppo_old.py +790 -0
- pyproject.toml +23 -0
- ruff.toml +104 -0
- runs/CartPole-v1_ppo_1_1776758028/events.out.tfevents.1776758028.BL-LFERNANDEZ.local.eurecat.org.10534.0 +3 -0
- runs/CartPole-v1_ppo_1_1776759190/events.out.tfevents.1776759190.BL-LFERNANDEZ.local.eurecat.org.13333.0 +3 -0
- runs/CartPole-v1_ppo_1_1776759930/events.out.tfevents.1776759930.BL-LFERNANDEZ.local.eurecat.org.16675.0 +3 -0
- runs/CartPole-v1_ppo_1_1776760089/events.out.tfevents.1776760089.BL-LFERNANDEZ.local.eurecat.org.17273.0 +3 -0
- runs/CartPole-v1_ppo_1_1776760573/events.out.tfevents.1776760573.BL-LFERNANDEZ.local.eurecat.org.19521.0 +3 -0
- runs/CartPole-v1_ppo_1_1776760603/events.out.tfevents.1776760800.BL-LFERNANDEZ.local.eurecat.org.19762.0 +3 -0
- runs/CartPole-v1_ppo_1_1776761629/events.out.tfevents.1776761629.BL-LFERNANDEZ.local.eurecat.org.40907.0 +3 -0
- runs/CartPole-v1_ppo_1_1776764954/events.out.tfevents.1776764954.BL-LFERNANDEZ.local.eurecat.org.45905.0 +3 -0
- runs/CartPole-v1_ppo_1_1776765078/events.out.tfevents.1776765078.BL-LFERNANDEZ.local.eurecat.org.46277.0 +3 -0
- runs/CartPole-v1_ppo_1_1776765121/events.out.tfevents.1776765121.BL-LFERNANDEZ.local.eurecat.org.46547.0 +3 -0
- runs/CartPole-v1_ppo_1_1776765317/events.out.tfevents.1776765317.BL-LFERNANDEZ.local.eurecat.org.47512.0 +3 -0
- runs/CartPole-v1_ppo_1_1776765438/events.out.tfevents.1776765438.BL-LFERNANDEZ.local.eurecat.org.47896.0 +3 -0
- runs/CartPole-v1_ppo_1_1776765498/events.out.tfevents.1776765498.BL-LFERNANDEZ.local.eurecat.org.48020.0 +3 -0
- runs/CartPole-v1_ppo_1_1776765547/events.out.tfevents.1776765547.BL-LFERNANDEZ.local.eurecat.org.48309.0 +3 -0
- runs/CartPole-v1_ppo_1_1776765580/events.out.tfevents.1776765580.BL-LFERNANDEZ.local.eurecat.org.48524.0 +3 -0
- runs/CartPole-v1_ppo_1_1776765943/events.out.tfevents.1776765945.BL-LFERNANDEZ.local.eurecat.org.49910.0 +3 -0
- runs/CartPole-v1_ppo_1_1776766122/events.out.tfevents.1776766124.BL-LFERNANDEZ.local.eurecat.org.50518.0 +3 -0
- runs/CartPole-v1_ppo_1_1776766281/events.out.tfevents.1776766283.BL-LFERNANDEZ.local.eurecat.org.51127.0 +3 -0
- runs/CartPole-v1_ppo_1_1776766423/events.out.tfevents.1776766423.BL-LFERNANDEZ.local.eurecat.org.51653.0 +3 -0
- runs/CartPole-v1_ppo_1_1776766445/events.out.tfevents.1776766445.BL-LFERNANDEZ.local.eurecat.org.51862.0 +3 -0
- runs/CartPole-v1_ppo_1_1776767609/events.out.tfevents.1776767609.BL-LFERNANDEZ.local.eurecat.org.55609.0 +3 -0
- runs/CartPole-v1_ppo_1_1776767664/events.out.tfevents.1776767664.BL-LFERNANDEZ.local.eurecat.org.55869.0 +3 -0
- runs/CartPole-v1_ppo_1_1776767691/events.out.tfevents.1776767691.BL-LFERNANDEZ.local.eurecat.org.55998.0 +3 -0
- runs/CartPole-v1_ppo_1_1776767743/events.out.tfevents.1776767743.BL-LFERNANDEZ.local.eurecat.org.56120.0 +3 -0
- runs/CartPole-v1_ppo_1_1776767808/events.out.tfevents.1776767808.BL-LFERNANDEZ.local.eurecat.org.56378.0 +3 -0
- runs/CartPole-v1_ppo_1_1776767863/events.out.tfevents.1776767863.BL-LFERNANDEZ.local.eurecat.org.56483.0 +3 -0
- runs/CartPole-v1_ppo_1_1776768348/events.out.tfevents.1776768348.BL-LFERNANDEZ.local.eurecat.org.57900.0 +3 -0
- runs/CartPole-v1_ppo_1_1776768490/events.out.tfevents.1776768490.BL-LFERNANDEZ.local.eurecat.org.58209.0 +3 -0
- runs/CartPole-v1_ppo_1_1776768658/events.out.tfevents.1776768658.BL-LFERNANDEZ.local.eurecat.org.58694.0 +3 -0
- runs/CartPole-v1_ppo_1_1776768770/events.out.tfevents.1776768770.BL-LFERNANDEZ.local.eurecat.org.59216.0 +3 -0
- runs/CartPole-v1_ppo_1_1776768821/events.out.tfevents.1776768821.BL-LFERNANDEZ.local.eurecat.org.59474.0 +3 -0
- runs/CartPole-v1_ppo_1_1776769606/events.out.tfevents.1776769606.BL-LFERNANDEZ.local.eurecat.org.62559.0 +3 -0
- runs/CartPole-v1_ppo_1_1776769691/events.out.tfevents.1776769691.BL-LFERNANDEZ.local.eurecat.org.62771.0 +3 -0
- runs/CartPole-v1_ppo_1_1776770441/events.out.tfevents.1776770441.BL-LFERNANDEZ.local.eurecat.org.64367.0 +3 -0
- runs/CartPole-v1_ppo_1_1776770456/events.out.tfevents.1776770456.BL-LFERNANDEZ.local.eurecat.org.64462.0 +3 -0
- runs/CartPole-v1_ppo_1_1776779529/events.out.tfevents.1776779529.BL-LFERNANDEZ.local.eurecat.org.84023.0 +3 -0
- runs/CartPole-v1_ppo_1_1776779596/events.out.tfevents.1776779596.BL-LFERNANDEZ.local.eurecat.org.106839.0 +3 -0
- runs/CartPole-v1_ppo_1_1776779670/events.out.tfevents.1776779672.BL-LFERNANDEZ.local.eurecat.org.209000.0 +3 -0
- runs/CartPole-v1_ppo_1_1776779742/events.out.tfevents.1776779745.BL-LFERNANDEZ.local.eurecat.org.209486.0 +3 -0
- runs/CartPole-v1_ppo_1_1776779888/events.out.tfevents.1776779890.BL-LFERNANDEZ.local.eurecat.org.210132.0 +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
wandb/run-20260421_160121-lnmebn7s/run-lnmebn7s.wandb filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
wandb/run-20260421_165557-ngbvlof1/run-ngbvlof1.wandb filter=lfs diff=lfs merge=lfs -text
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.9.5
|
.vscode/settings.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"editor.formatOnSave": true,
|
| 3 |
+
"editor.codeActionsOnSave": {
|
| 4 |
+
"source.fixAll": "explicit",
|
| 5 |
+
"source.organizeImports": "explicit"
|
| 6 |
+
},
|
| 7 |
+
"[python]": {
|
| 8 |
+
"editor.defaultFormatter": "charliermarsh.ruff"
|
| 9 |
+
}
|
| 10 |
+
}
|
README.md
ADDED
|
File without changes
|
poetry.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ppo-LunarLander-v2.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae83d259ed07741805748cf65c6a7c0b1279ccf2fa154cac1d305453e0becfc8
|
| 3 |
+
size 44165
|
ppo.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import gymnasium as gym
|
| 4 |
+
import numpy as np # Ensure numpy is imported
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from torch.distributions.categorical import Categorical
|
| 9 |
+
|
| 10 |
+
# --- Hyperparameters ---
|
| 11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 12 |
+
env_id = "LunarLander-v2"
|
| 13 |
+
total_timesteps = 500000
|
| 14 |
+
learning_rate = 2.5e-4
|
| 15 |
+
num_envs = 4
|
| 16 |
+
num_steps = 128
|
| 17 |
+
batch_size = num_envs * num_steps
|
| 18 |
+
minibatch_size = 32
|
| 19 |
+
update_epochs = 4
|
| 20 |
+
clip_coef = 0.2
|
| 21 |
+
ent_coef = 0.01
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# --- Model Architecture ---
|
| 25 |
+
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
|
| 26 |
+
torch.nn.init.orthogonal_(layer.weight, std)
|
| 27 |
+
torch.nn.init.constant_(layer.bias, bias_const)
|
| 28 |
+
return layer
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Agent(nn.Module):
|
| 32 |
+
def __init__(self, envs):
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.critic = nn.Sequential(
|
| 35 |
+
layer_init(
|
| 36 |
+
nn.Linear(
|
| 37 |
+
np.array(envs.single_observation_space.shape).prod(), 64
|
| 38 |
+
)
|
| 39 |
+
),
|
| 40 |
+
nn.Tanh(),
|
| 41 |
+
layer_init(nn.Linear(64, 64)),
|
| 42 |
+
nn.Tanh(),
|
| 43 |
+
layer_init(nn.Linear(64, 1), std=1.0),
|
| 44 |
+
)
|
| 45 |
+
self.actor = nn.Sequential(
|
| 46 |
+
layer_init(
|
| 47 |
+
nn.Linear(
|
| 48 |
+
np.array(envs.single_observation_space.shape).prod(), 64
|
| 49 |
+
)
|
| 50 |
+
),
|
| 51 |
+
nn.Tanh(),
|
| 52 |
+
layer_init(nn.Linear(64, 64)),
|
| 53 |
+
nn.Tanh(),
|
| 54 |
+
layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def get_value(self, x):
|
| 58 |
+
return self.critic(x)
|
| 59 |
+
|
| 60 |
+
def get_action_and_value(self, x, action=None):
|
| 61 |
+
logits = self.actor(x)
|
| 62 |
+
probs = Categorical(logits=logits)
|
| 63 |
+
if action is None:
|
| 64 |
+
action = probs.sample()
|
| 65 |
+
return action, probs.log_prob(action), probs.entropy(), self.critic(x)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# --- Training Loop Setup ---
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
envs = gym.vector.SyncVectorEnv(
|
| 71 |
+
[lambda: gym.make(env_id) for _ in range(num_envs)]
|
| 72 |
+
)
|
| 73 |
+
agent = Agent(envs).to(device)
|
| 74 |
+
optimizer = optim.Adam(agent.parameters(), lr=learning_rate, eps=1e-5)
|
| 75 |
+
|
| 76 |
+
# Storage setup
|
| 77 |
+
obs = torch.zeros(
|
| 78 |
+
(num_steps, num_envs) + envs.single_observation_space.shape
|
| 79 |
+
).to(device)
|
| 80 |
+
actions = torch.zeros(
|
| 81 |
+
(num_steps, num_envs) + envs.single_action_space.shape
|
| 82 |
+
).to(device)
|
| 83 |
+
logprobs = torch.zeros((num_steps, num_envs)).to(device)
|
| 84 |
+
rewards = torch.zeros((num_steps, num_envs)).to(device)
|
| 85 |
+
dones = torch.zeros((num_steps, num_envs)).to(device)
|
| 86 |
+
values = torch.zeros((num_steps, num_envs)).to(device)
|
| 87 |
+
|
| 88 |
+
global_step = 0
|
| 89 |
+
next_obs, _ = envs.reset()
|
| 90 |
+
next_obs = torch.Tensor(next_obs).to(device)
|
| 91 |
+
next_done = torch.zeros(num_envs).to(device)
|
| 92 |
+
|
| 93 |
+
for iteration in range(1, total_timesteps // batch_size + 1):
|
| 94 |
+
# 1. Rollout phase
|
| 95 |
+
for step in range(num_steps):
|
| 96 |
+
global_step += num_envs
|
| 97 |
+
obs[step] = next_obs
|
| 98 |
+
dones[step] = next_done
|
| 99 |
+
|
| 100 |
+
with torch.no_grad():
|
| 101 |
+
action, logprob, _, value = agent.get_action_and_value(
|
| 102 |
+
next_obs
|
| 103 |
+
)
|
| 104 |
+
values[step] = value.flatten()
|
| 105 |
+
actions[step] = action
|
| 106 |
+
logprobs[step] = logprob
|
| 107 |
+
|
| 108 |
+
next_obs, reward, terminations, activations, infos = envs.step(
|
| 109 |
+
action.cpu().numpy()
|
| 110 |
+
)
|
| 111 |
+
next_done = np.logical_or(terminations, activations)
|
| 112 |
+
rewards[step] = torch.tensor(reward).to(device).view(-1)
|
| 113 |
+
next_obs, next_done = (
|
| 114 |
+
torch.Tensor(next_obs).to(device),
|
| 115 |
+
torch.Tensor(next_done).to(device),
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# 2. Advantage Calculation (GAE can be added here, simplified for brevity)
|
| 119 |
+
# 3. PPO Update Logic (Actor and Critic Loss)
|
| 120 |
+
# ... [Policy Update Logic goes here] ...
|
| 121 |
+
|
| 122 |
+
print(f"Step: {global_step} | Training...")
|
| 123 |
+
|
| 124 |
+
# --- Save and Upload ---
|
| 125 |
+
model_name = "ppo-LunarLander-v2"
|
| 126 |
+
torch.save(agent.state_dict(), f"{model_name}.pt")
|
| 127 |
+
|
| 128 |
+
# Upload to Hub
|
| 129 |
+
# Replace 'your-username' with your actual HF username
|
| 130 |
+
repo_id = f"your-username/{model_name}"
|
| 131 |
+
push_to_hub_fast(
|
| 132 |
+
repo_id=repo_id, folder_path="./", token=os.getenv("HF_TOKEN")
|
| 133 |
+
)
|
ppo_old.py
ADDED
|
@@ -0,0 +1,790 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import time
|
| 5 |
+
from distutils.util import strtobool
|
| 6 |
+
|
| 7 |
+
import gym
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.optim as optim
|
| 12 |
+
from torch.distributions.categorical import Categorical
|
| 13 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def make_env(gym_id, seed, idx, capture_video, run_name):
|
| 17 |
+
def thunk():
|
| 18 |
+
env = gym.make(gym_id)
|
| 19 |
+
env = gym.wrappers.RecordEpisodeStatistics(env)
|
| 20 |
+
if capture_video:
|
| 21 |
+
if idx == 0:
|
| 22 |
+
env = gym.wrappers.RecordVideo(
|
| 23 |
+
env,
|
| 24 |
+
f"videos/{run_name}",
|
| 25 |
+
episode_trigger=lambda t: t % 1000 == 0,
|
| 26 |
+
)
|
| 27 |
+
env.seed(seed)
|
| 28 |
+
env.action_space.seed(seed)
|
| 29 |
+
env.observation_space.seed(seed)
|
| 30 |
+
return env
|
| 31 |
+
|
| 32 |
+
return thunk
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
|
| 36 |
+
torch.nn.init.orthogonal_(layer.weight, std)
|
| 37 |
+
torch.nn.init.constant_(layer.bias, bias_const)
|
| 38 |
+
return layer
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Agent(nn.Module):
|
| 42 |
+
def __init__(self, envs):
|
| 43 |
+
super(Agent, self).__init__()
|
| 44 |
+
self.critic = nn.Sequential(
|
| 45 |
+
layer_init(
|
| 46 |
+
nn.Linear(
|
| 47 |
+
np.array(envs.single_observation_space.shape).prod(), 64
|
| 48 |
+
)
|
| 49 |
+
),
|
| 50 |
+
nn.Tanh(),
|
| 51 |
+
layer_init(nn.Linear(64, 64)),
|
| 52 |
+
nn.Tanh(),
|
| 53 |
+
layer_init(nn.Linear(64, 1), std=1.0),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
self.actor = nn.Sequential(
|
| 57 |
+
layer_init(
|
| 58 |
+
nn.Linear(
|
| 59 |
+
np.array(envs.single_observation_space.shape).prod(), 64
|
| 60 |
+
)
|
| 61 |
+
),
|
| 62 |
+
nn.Tanh(),
|
| 63 |
+
layer_init(nn.Linear(64, 64)),
|
| 64 |
+
nn.Tanh(),
|
| 65 |
+
layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def get_value(self, x):
|
| 69 |
+
return self.critic(x)
|
| 70 |
+
|
| 71 |
+
def get_action_and_value(self, x, action=None):
|
| 72 |
+
logits = self.actor(x)
|
| 73 |
+
probs = Categorical(logits=logits)
|
| 74 |
+
if action is None:
|
| 75 |
+
action = probs.sample()
|
| 76 |
+
return action, probs.log_prob(action), probs.entropy(), self.critic(x)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def parse_args():
|
| 80 |
+
parser = argparse.ArgumentParser()
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"--exp-name",
|
| 83 |
+
type=str,
|
| 84 |
+
default=os.path.basename(__file__).rstrip(".py"),
|
| 85 |
+
help="the name of this experiment",
|
| 86 |
+
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--gym-id",
|
| 89 |
+
type=str,
|
| 90 |
+
default="CartPole-v1",
|
| 91 |
+
help="the id of the gym environment",
|
| 92 |
+
)
|
| 93 |
+
parser.add_argument(
|
| 94 |
+
"--learning-rate",
|
| 95 |
+
type=float,
|
| 96 |
+
default=2.5e-4,
|
| 97 |
+
help="the learning rate of the optimizer",
|
| 98 |
+
)
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--seed", type=int, default=1, help="seed of the experiment"
|
| 101 |
+
)
|
| 102 |
+
parser.add_argument(
|
| 103 |
+
"--total-timesteps",
|
| 104 |
+
type=int,
|
| 105 |
+
default=25000,
|
| 106 |
+
help="total timesteps of the experiments",
|
| 107 |
+
)
|
| 108 |
+
parser.add_argument(
|
| 109 |
+
"--torch-deterministic",
|
| 110 |
+
type=lambda x: bool(strtobool(x)),
|
| 111 |
+
default=True,
|
| 112 |
+
nargs="?",
|
| 113 |
+
const=True,
|
| 114 |
+
help="if toggled, `torch.backeds.cudnn.deterministic=False`",
|
| 115 |
+
)
|
| 116 |
+
parser.add_argument(
|
| 117 |
+
"--cuda",
|
| 118 |
+
type=lambda x: bool(strtobool(x)),
|
| 119 |
+
default=True,
|
| 120 |
+
nargs="?",
|
| 121 |
+
const=True,
|
| 122 |
+
help="if toggled, cuda will not be enabled by default",
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--track",
|
| 126 |
+
type=lambda x: bool(strtobool(x)),
|
| 127 |
+
default=False,
|
| 128 |
+
nargs="?",
|
| 129 |
+
const=True,
|
| 130 |
+
help="if toggled, this experiment will be tracked with Weights and Biases",
|
| 131 |
+
)
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--wandb-project-name",
|
| 134 |
+
type=str,
|
| 135 |
+
default="cleanRL",
|
| 136 |
+
help="the wandb's project name",
|
| 137 |
+
)
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
"--wandb-entity",
|
| 140 |
+
type=str,
|
| 141 |
+
default=None,
|
| 142 |
+
help="the entity (team) of wandb's project",
|
| 143 |
+
)
|
| 144 |
+
parser.add_argument(
|
| 145 |
+
"--num-envs",
|
| 146 |
+
type=int,
|
| 147 |
+
default=4,
|
| 148 |
+
help="the number of parallel game environments",
|
| 149 |
+
)
|
| 150 |
+
parser.add_argument(
|
| 151 |
+
"--capture-video",
|
| 152 |
+
type=lambda x: bool(strtobool(x)),
|
| 153 |
+
default=False,
|
| 154 |
+
nargs="?",
|
| 155 |
+
const=True,
|
| 156 |
+
help="whether to capture videos of the agent performances (check out `videos` folder)",
|
| 157 |
+
)
|
| 158 |
+
parser.add_argument(
|
| 159 |
+
"--num-steps",
|
| 160 |
+
type=int,
|
| 161 |
+
default=128,
|
| 162 |
+
help="the number of steps to run in each environment per policy rollout",
|
| 163 |
+
)
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--anneal-lr",
|
| 166 |
+
type=lambda x: bool(strtobool(x)),
|
| 167 |
+
default=True,
|
| 168 |
+
nargs="?",
|
| 169 |
+
const=True,
|
| 170 |
+
help="Toggle learning rate annealing for policy and value networks",
|
| 171 |
+
)
|
| 172 |
+
parser.add_argument(
|
| 173 |
+
"--gae",
|
| 174 |
+
type=lambda x: bool(strtobool(x)),
|
| 175 |
+
default=True,
|
| 176 |
+
nargs="?",
|
| 177 |
+
const=True,
|
| 178 |
+
help="Use GAE for advantage computation",
|
| 179 |
+
)
|
| 180 |
+
parser.add_argument(
|
| 181 |
+
"--gamma", type=float, default=0.99, help="the discount factor gamma"
|
| 182 |
+
)
|
| 183 |
+
parser.add_argument(
|
| 184 |
+
"--gae-lambda",
|
| 185 |
+
type=float,
|
| 186 |
+
default=0.95,
|
| 187 |
+
help="the lambda for the general advantage estimation",
|
| 188 |
+
)
|
| 189 |
+
parser.add_argument(
|
| 190 |
+
"--num-minibatches",
|
| 191 |
+
type=int,
|
| 192 |
+
default=4,
|
| 193 |
+
help="the number of mini-batches",
|
| 194 |
+
)
|
| 195 |
+
parser.add_argument(
|
| 196 |
+
"--update-epochs",
|
| 197 |
+
type=int,
|
| 198 |
+
default=4,
|
| 199 |
+
help="the K epochs to update the policy",
|
| 200 |
+
)
|
| 201 |
+
parser.add_argument(
|
| 202 |
+
"--norm-adv",
|
| 203 |
+
type=lambda x: bool(strtobool(x)),
|
| 204 |
+
default=True,
|
| 205 |
+
nargs="?",
|
| 206 |
+
const=True,
|
| 207 |
+
help="Toggles advantages normalization",
|
| 208 |
+
)
|
| 209 |
+
parser.add_argument(
|
| 210 |
+
"--clip-coef",
|
| 211 |
+
type=float,
|
| 212 |
+
default=0.2,
|
| 213 |
+
help="the surrogate clipping coefficient",
|
| 214 |
+
)
|
| 215 |
+
parser.add_argument(
|
| 216 |
+
"--clip-vloss",
|
| 217 |
+
type=lambda x: bool(strtobool(x)),
|
| 218 |
+
default=True,
|
| 219 |
+
nargs="?",
|
| 220 |
+
const=True,
|
| 221 |
+
help="Toggles wheter or not to use a clipped loss for the value function, as per the paper",
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--ent-coef",
|
| 225 |
+
type=float,
|
| 226 |
+
default=0.01,
|
| 227 |
+
help="coefficient of the entropy",
|
| 228 |
+
)
|
| 229 |
+
parser.add_argument(
|
| 230 |
+
"--vf-coef",
|
| 231 |
+
type=float,
|
| 232 |
+
default=0.5,
|
| 233 |
+
help="coefficient of the value function",
|
| 234 |
+
)
|
| 235 |
+
parser.add_argument(
|
| 236 |
+
"--max-grad-norm",
|
| 237 |
+
type=float,
|
| 238 |
+
default=0.5,
|
| 239 |
+
help="the maximum norm for the gradient clipping",
|
| 240 |
+
)
|
| 241 |
+
parser.add_argument(
|
| 242 |
+
"--target-kl",
|
| 243 |
+
type=float,
|
| 244 |
+
default=None,
|
| 245 |
+
help="the target KL divergence threshold",
|
| 246 |
+
)
|
| 247 |
+
args = parser.parse_args()
|
| 248 |
+
args.batch_size = int(args.num_envs * args.num_steps)
|
| 249 |
+
args.minibatch_size = int(args.batch_size // args.num_minibatches)
|
| 250 |
+
return args
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
|
| 254 |
+
args = parse_args()
|
| 255 |
+
print(args)
|
| 256 |
+
run_name = f"{args.gym_id}_{args.exp_name}_{args.seed}_{int(time.time())}"
|
| 257 |
+
if args.track:
|
| 258 |
+
import wandb
|
| 259 |
+
|
| 260 |
+
wandb.init(
|
| 261 |
+
project=args.wandb_project_name,
|
| 262 |
+
entity=args.wandb_entity,
|
| 263 |
+
sync_tensorboard=True,
|
| 264 |
+
config=vars(args),
|
| 265 |
+
name=run_name,
|
| 266 |
+
monitor_gym=True,
|
| 267 |
+
save_code=True,
|
| 268 |
+
)
|
| 269 |
+
writer = SummaryWriter(f"runs/{run_name}")
|
| 270 |
+
writer.add_text(
|
| 271 |
+
"hyperparameters",
|
| 272 |
+
"|param|value|\n|-|-|\n%s"
|
| 273 |
+
% (
|
| 274 |
+
"\n".join(
|
| 275 |
+
[f"|{key}|{value}|" for key, value in vars(args).items()]
|
| 276 |
+
)
|
| 277 |
+
),
|
| 278 |
+
)
|
| 279 |
+
random.seed(args.seed)
|
| 280 |
+
np.random.seed(args.seed)
|
| 281 |
+
torch.manual_seed(args.seed)
|
| 282 |
+
torch.backends.cudnn.deterministic = args.torch_deterministic
|
| 283 |
+
|
| 284 |
+
device = torch.device(
|
| 285 |
+
"cuda" if torch.cuda.is_available() and args.cuda else "cpu"
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
envs = gym.vector.SyncVectorEnv(
|
| 289 |
+
[
|
| 290 |
+
make_env(
|
| 291 |
+
args.gym_id, args.seed + i, i, args.capture_video, run_name
|
| 292 |
+
)
|
| 293 |
+
for i in range(args.num_envs)
|
| 294 |
+
]
|
| 295 |
+
)
|
| 296 |
+
assert isinstance(envs.single_action_space, gym.spaces.Discrete), (
|
| 297 |
+
"only discrete action space is supported"
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
agent = Agent(envs).to(device)
|
| 301 |
+
optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
|
| 302 |
+
|
| 303 |
+
# ALGO Logic: Storage setup
|
| 304 |
+
obs = torch.zeros(
|
| 305 |
+
(args.num_steps, args.num_envs) + envs.single_observation_space.shape
|
| 306 |
+
).to(device)
|
| 307 |
+
actions = torch.zeros(
|
| 308 |
+
(args.num_steps, args.num_envs) + envs.single_action_space.shape
|
| 309 |
+
).to(device)
|
| 310 |
+
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
|
| 311 |
+
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
|
| 312 |
+
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
|
| 313 |
+
values = torch.zeros((args.num_steps, args.num_envs)).to(device)
|
| 314 |
+
|
| 315 |
+
# TRY NOT TO MODIFY: start the game
|
| 316 |
+
global_step = 0
|
| 317 |
+
start_time = time.time()
|
| 318 |
+
next_obs = torch.Tensor(envs.reset()).to(device)
|
| 319 |
+
next_done = torch.zeros(args.num_envs).to(device)
|
| 320 |
+
num_updates = args.total_timesteps // args.batch_size
|
| 321 |
+
|
| 322 |
+
for update in range(1, num_updates + 1):
|
| 323 |
+
# Annealing the rate if instructed to do so.abs
|
| 324 |
+
if args.anneal_lr:
|
| 325 |
+
frac = 1.0 - (update - 1.0) / num_updates
|
| 326 |
+
lrnow = frac * args.learning_rate
|
| 327 |
+
optimizer.param_groups[0]["lr"] = lrnow
|
| 328 |
+
|
| 329 |
+
for step in range(args.num_steps):
|
| 330 |
+
global_step += 1 * args.num_envs
|
| 331 |
+
obs[step] = next_obs
|
| 332 |
+
dones[step] = next_done
|
| 333 |
+
|
| 334 |
+
# ALGO LOGIC : action logic
|
| 335 |
+
with torch.no_grad():
|
| 336 |
+
action, logprob, _, value = agent.get_action_and_value(
|
| 337 |
+
next_obs
|
| 338 |
+
)
|
| 339 |
+
values[step] = value.flatten()
|
| 340 |
+
actions[step] = action
|
| 341 |
+
logprobs[step] = logprob
|
| 342 |
+
|
| 343 |
+
# TRY NOT TO MODIFY: execute the game and log data.abs
|
| 344 |
+
next_obs, reward, done, info = envs.step(action.cpu().numpy())
|
| 345 |
+
rewards[step] = torch.tensor(reward).to(device).view(-1)
|
| 346 |
+
next_obs, next_done = (
|
| 347 |
+
torch.Tensor(next_obs).to(device),
|
| 348 |
+
torch.Tensor(done).to(device),
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
if isinstance(info, dict) and "episode" in info:
|
| 352 |
+
for item in info["episode"]:
|
| 353 |
+
if item is not None:
|
| 354 |
+
print(
|
| 355 |
+
f"global_step={global_step}, episodic_return={item['r']}"
|
| 356 |
+
)
|
| 357 |
+
writer.add_scalar(
|
| 358 |
+
"charts/episodic_return", item["r"], global_step
|
| 359 |
+
)
|
| 360 |
+
writer.add_scalar(
|
| 361 |
+
"charts/episodic_length", item["l"], global_step
|
| 362 |
+
)
|
| 363 |
+
break
|
| 364 |
+
|
| 365 |
+
# bootstrap reward if not done
|
| 366 |
+
with torch.no_grad():
|
| 367 |
+
next_value = agent.get_value(next_obs).reshape(1, -1)
|
| 368 |
+
if args.gae:
|
| 369 |
+
advantages = torch.zeros_like(rewards).to(device)
|
| 370 |
+
lastgaelam = 0
|
| 371 |
+
for t in reversed(range(args.num_steps)):
|
| 372 |
+
if t == args.num_steps - 1:
|
| 373 |
+
nextnonterminal = 1.0 - next_done
|
| 374 |
+
nextvalues = next_value
|
| 375 |
+
else:
|
| 376 |
+
nextnonterminal = 1.0 - dones[t + 1]
|
| 377 |
+
nextvalues = values[t + 1]
|
| 378 |
+
delta = (
|
| 379 |
+
rewards[t]
|
| 380 |
+
+ args.gamma * nextvalues * nextnonterminal
|
| 381 |
+
- values[t]
|
| 382 |
+
)
|
| 383 |
+
advantages[t] = lastgaelam = (
|
| 384 |
+
delta
|
| 385 |
+
+ args.gamma
|
| 386 |
+
* args.gae_lambda
|
| 387 |
+
* nextnonterminal
|
| 388 |
+
* lastgaelam
|
| 389 |
+
)
|
| 390 |
+
returns = advantages + values
|
| 391 |
+
else:
|
| 392 |
+
returns = torch.zeros_like(rewards).to(device)
|
| 393 |
+
for t in reversed(range(args.num_steps)):
|
| 394 |
+
if t == args.num_steps - 1:
|
| 395 |
+
nextnonterminal = 1.0 - next_done
|
| 396 |
+
next_return = next_value
|
| 397 |
+
else:
|
| 398 |
+
nextnonterminal = 1.0 - dones[t + 1]
|
| 399 |
+
next_return = returns[t + 1]
|
| 400 |
+
returns[t] = (
|
| 401 |
+
rewards[t]
|
| 402 |
+
+ args.gamma * nextnonterminal * next_return
|
| 403 |
+
)
|
| 404 |
+
advantages = returns - values
|
| 405 |
+
|
| 406 |
+
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
|
| 407 |
+
b_logprobs = logprobs.reshape(-1)
|
| 408 |
+
b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
|
| 409 |
+
b_advantages = advantages.reshape(-1)
|
| 410 |
+
b_returns = returns.reshape(-1)
|
| 411 |
+
b_values = values.reshape(-1)
|
| 412 |
+
|
| 413 |
+
# Optimizaing the policy and value network
|
| 414 |
+
b_inds = np.arange(args.batch_size)
|
| 415 |
+
clipfracs = []
|
| 416 |
+
for epoch in range(args.update_epochs):
|
| 417 |
+
np.random.shuffle(b_inds)
|
| 418 |
+
for start in range(0, args.batch_size, args.minibatch_size):
|
| 419 |
+
end = start + args.minibatch_size
|
| 420 |
+
mb_inds = b_inds[start:end]
|
| 421 |
+
|
| 422 |
+
_, newlogprob, entropy, newvalue = (
|
| 423 |
+
agent.get_action_and_value( # POSIBLE ERROR AQUI
|
| 424 |
+
b_obs[mb_inds], b_actions.long()[mb_inds]
|
| 425 |
+
)
|
| 426 |
+
)
|
| 427 |
+
logratio = newlogprob - b_logprobs[mb_inds]
|
| 428 |
+
ratio = logratio.exp()
|
| 429 |
+
|
| 430 |
+
with torch.no_grad():
|
| 431 |
+
# calculate approx kl as in http://joschu.net/blog/kl-aprox.html
|
| 432 |
+
old_approx_kl = (-logratio).mean()
|
| 433 |
+
approx_kl = ((ratio - 1) - logratio).mean()
|
| 434 |
+
clipfracs += [
|
| 435 |
+
((ratio - 1.0).abs() > args.clip_coef)
|
| 436 |
+
.float()
|
| 437 |
+
.mean()
|
| 438 |
+
]
|
| 439 |
+
|
| 440 |
+
mb_advantages = b_advantages[mb_inds]
|
| 441 |
+
if args.norm_adv:
|
| 442 |
+
mb_advantages = (
|
| 443 |
+
mb_advantages - mb_advantages.mean()
|
| 444 |
+
) / (mb_advantages.std() + 1e-8)
|
| 445 |
+
|
| 446 |
+
# Policy loss
|
| 447 |
+
pg_loss1 = -mb_advantages * ratio
|
| 448 |
+
pg_loss2 = -mb_advantages * torch.clamp(
|
| 449 |
+
ratio, 1 - args.clip_coef, 1 + args.clip_coef
|
| 450 |
+
)
|
| 451 |
+
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
|
| 452 |
+
|
| 453 |
+
# Value loss
|
| 454 |
+
newvalue = newvalue.view(-1)
|
| 455 |
+
if args.clip_vloss:
|
| 456 |
+
v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
|
| 457 |
+
v_clipped = b_values[mb_inds] + torch.clamp(
|
| 458 |
+
newvalue - b_values[mb_inds],
|
| 459 |
+
-args.clip_coef,
|
| 460 |
+
args.clip_coef,
|
| 461 |
+
)
|
| 462 |
+
v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
|
| 463 |
+
v_loss_max = torch.max(
|
| 464 |
+
v_loss_unclipped, v_loss_clipped
|
| 465 |
+
)
|
| 466 |
+
v_loss = 0.5 * v_loss_max.mean()
|
| 467 |
+
else:
|
| 468 |
+
v_loss = (
|
| 469 |
+
0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
# Entropy loss
|
| 473 |
+
entropy_loss = entropy.mean()
|
| 474 |
+
loss = (
|
| 475 |
+
pg_loss
|
| 476 |
+
- args.ent_coef * entropy_loss
|
| 477 |
+
+ v_loss * args.vf_coef
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
optimizer.zero_grad()
|
| 481 |
+
loss.backward()
|
| 482 |
+
nn.utils.clip_grad_norm_(
|
| 483 |
+
agent.parameters(), args.max_grad_norm
|
| 484 |
+
)
|
| 485 |
+
optimizer.step()
|
| 486 |
+
|
| 487 |
+
if args.target_kl is not None:
|
| 488 |
+
if approx_kl > args.target_kl:
|
| 489 |
+
break
|
| 490 |
+
|
| 491 |
+
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
|
| 492 |
+
var_y = np.var(y_true)
|
| 493 |
+
explained_var = (
|
| 494 |
+
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
# TRY NOT TO MODIFY: record rewards for plotting purposes
|
| 498 |
+
writer.add_scalar(
|
| 499 |
+
"charts/learning_rate",
|
| 500 |
+
optimizer.param_groups[0]["lr"],
|
| 501 |
+
global_step,
|
| 502 |
+
)
|
| 503 |
+
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
|
| 504 |
+
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
|
| 505 |
+
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
|
| 506 |
+
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
|
| 507 |
+
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
|
| 508 |
+
writer.add_scalar(
|
| 509 |
+
"losses/explained_variance", explained_var, global_step
|
| 510 |
+
)
|
| 511 |
+
print("SPS:", int(global_step / (time.time() - start_time)))
|
| 512 |
+
writer.add_scalar(
|
| 513 |
+
"charts/SPS",
|
| 514 |
+
int(global_step / (time.time() - start_time)),
|
| 515 |
+
global_step,
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
envs.close()
|
| 519 |
+
writer.close()
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
##############################################################################
|
| 523 |
+
############################## Huggingface ###################################
|
| 524 |
+
##############################################################################
|
| 525 |
+
import datetime
|
| 526 |
+
import json
|
| 527 |
+
import shutil
|
| 528 |
+
import tempfile
|
| 529 |
+
from pathlib import Path
|
| 530 |
+
|
| 531 |
+
import imageio
|
| 532 |
+
from huggingface_hub import HfApi, upload_folder
|
| 533 |
+
from huggingface_hub.repocard import metadata_eval_result, metadata_save
|
| 534 |
+
from wasabi import Printer
|
| 535 |
+
|
| 536 |
+
msg = Printer()
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def package_to_hub(
|
| 540 |
+
repo_id,
|
| 541 |
+
model,
|
| 542 |
+
hyperparameters,
|
| 543 |
+
eval_env,
|
| 544 |
+
video_fps=30,
|
| 545 |
+
commit_message="Push agent to the Hub",
|
| 546 |
+
token=None,
|
| 547 |
+
logs=None,
|
| 548 |
+
):
|
| 549 |
+
"""
|
| 550 |
+
Evaluate, Generate a video and Upload a model to Hugging Face Hub.
|
| 551 |
+
This method does the complete pipeline:
|
| 552 |
+
- It evaluates the model
|
| 553 |
+
- It generates the model card
|
| 554 |
+
- It generates a replay video of the agent
|
| 555 |
+
- It pushes everything to the hub
|
| 556 |
+
:param repo_id: id of the model repository from the Hugging Face Hub
|
| 557 |
+
:param model: trained model
|
| 558 |
+
:param eval_env: environment used to evaluate the agent
|
| 559 |
+
:param fps: number of fps for rendering the video
|
| 560 |
+
:param commit_message: commit message
|
| 561 |
+
:param logs: directory on local machine of tensorboard logs you'd like to upload
|
| 562 |
+
"""
|
| 563 |
+
msg.info(
|
| 564 |
+
"This function will save, evaluate, generate a video of your agent, "
|
| 565 |
+
"create a model card and push everything to the hub. "
|
| 566 |
+
"It might take up to 1min. \n "
|
| 567 |
+
"This is a work in progress: if you encounter a bug, please open an issue."
|
| 568 |
+
)
|
| 569 |
+
# Step 1: Clone or create the repo
|
| 570 |
+
repo_url = HfApi().create_repo(
|
| 571 |
+
repo_id=repo_id,
|
| 572 |
+
token=token,
|
| 573 |
+
private=False,
|
| 574 |
+
exist_ok=True,
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
with tempfile.TemporaryDirectory() as tmpdirname:
|
| 578 |
+
tmpdirname = Path(tmpdirname)
|
| 579 |
+
|
| 580 |
+
# Step 2: Save the model
|
| 581 |
+
torch.save(model.state_dict(), tmpdirname / "model.pt")
|
| 582 |
+
|
| 583 |
+
# Step 3: Evaluate the model and build JSON
|
| 584 |
+
mean_reward, std_reward = _evaluate_agent(eval_env, 10, model)
|
| 585 |
+
|
| 586 |
+
# First get datetime
|
| 587 |
+
eval_datetime = datetime.datetime.now()
|
| 588 |
+
eval_form_datetime = eval_datetime.isoformat()
|
| 589 |
+
|
| 590 |
+
evaluate_data = {
|
| 591 |
+
"env_id": hyperparameters.env_id,
|
| 592 |
+
"mean_reward": mean_reward,
|
| 593 |
+
"std_reward": std_reward,
|
| 594 |
+
"n_evaluation_episodes": 10,
|
| 595 |
+
"eval_datetime": eval_form_datetime,
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
# Write a JSON file
|
| 599 |
+
with open(tmpdirname / "results.json", "w") as outfile:
|
| 600 |
+
json.dump(evaluate_data, outfile)
|
| 601 |
+
|
| 602 |
+
# Step 4: Generate a video
|
| 603 |
+
video_path = tmpdirname / "replay.mp4"
|
| 604 |
+
record_video(eval_env, model, video_path, video_fps)
|
| 605 |
+
|
| 606 |
+
# Step 5: Generate the model card
|
| 607 |
+
generated_model_card, metadata = _generate_model_card(
|
| 608 |
+
"PPO",
|
| 609 |
+
hyperparameters.env_id,
|
| 610 |
+
mean_reward,
|
| 611 |
+
std_reward,
|
| 612 |
+
hyperparameters,
|
| 613 |
+
)
|
| 614 |
+
_save_model_card(tmpdirname, generated_model_card, metadata)
|
| 615 |
+
|
| 616 |
+
# Step 6: Add logs if needed
|
| 617 |
+
if logs:
|
| 618 |
+
_add_logdir(tmpdirname, Path(logs))
|
| 619 |
+
|
| 620 |
+
msg.info(f"Pushing repo {repo_id} to the Hugging Face Hub")
|
| 621 |
+
|
| 622 |
+
repo_url = upload_folder(
|
| 623 |
+
repo_id=repo_id,
|
| 624 |
+
folder_path=tmpdirname,
|
| 625 |
+
path_in_repo="",
|
| 626 |
+
commit_message=commit_message,
|
| 627 |
+
token=token,
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
msg.info(
|
| 631 |
+
f"Your model is pushed to the Hub. You can view your model here: {repo_url}"
|
| 632 |
+
)
|
| 633 |
+
return repo_url
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def _evaluate_agent(env, n_eval_episodes, policy):
|
| 637 |
+
"""
|
| 638 |
+
Evaluate the agent for ``n_eval_episodes`` episodes and returns average reward and std of reward.
|
| 639 |
+
:param env: The evaluation environment
|
| 640 |
+
:param n_eval_episodes: Number of episode to evaluate the agent
|
| 641 |
+
:param policy: The agent
|
| 642 |
+
"""
|
| 643 |
+
episode_rewards = []
|
| 644 |
+
for episode in range(n_eval_episodes):
|
| 645 |
+
state = env.reset()
|
| 646 |
+
step = 0
|
| 647 |
+
done = False
|
| 648 |
+
total_rewards_ep = 0
|
| 649 |
+
|
| 650 |
+
while done is False:
|
| 651 |
+
state = torch.Tensor(state).to(device)
|
| 652 |
+
action, _, _, _ = policy.get_action_and_value(state)
|
| 653 |
+
new_state, reward, done, info = env.step(action.cpu().numpy())
|
| 654 |
+
total_rewards_ep += reward
|
| 655 |
+
if done:
|
| 656 |
+
break
|
| 657 |
+
state = new_state
|
| 658 |
+
episode_rewards.append(total_rewards_ep)
|
| 659 |
+
mean_reward = np.mean(episode_rewards)
|
| 660 |
+
std_reward = np.std(episode_rewards)
|
| 661 |
+
|
| 662 |
+
return mean_reward, std_reward
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def record_video(env, policy, out_directory, fps=30):
|
| 666 |
+
images = []
|
| 667 |
+
done = False
|
| 668 |
+
state = env.reset()
|
| 669 |
+
img = env.render(mode="rgb_array")
|
| 670 |
+
images.append(img)
|
| 671 |
+
while not done:
|
| 672 |
+
state = torch.Tensor(state).to(device)
|
| 673 |
+
# Take the action (index) that have the maximum expected future reward given that state
|
| 674 |
+
action, _, _, _ = policy.get_action_and_value(state)
|
| 675 |
+
state, reward, done, info = env.step(
|
| 676 |
+
action.cpu().numpy()
|
| 677 |
+
) # We directly put next_state = state for recording logic
|
| 678 |
+
img = env.render(mode="rgb_array")
|
| 679 |
+
images.append(img)
|
| 680 |
+
imageio.mimsave(
|
| 681 |
+
out_directory, [np.array(img) for i, img in enumerate(images)], fps=fps
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
def _generate_model_card(
|
| 686 |
+
model_name, env_id, mean_reward, std_reward, hyperparameters
|
| 687 |
+
):
|
| 688 |
+
"""
|
| 689 |
+
Generate the model card for the Hub
|
| 690 |
+
:param model_name: name of the model
|
| 691 |
+
:env_id: name of the environment
|
| 692 |
+
:mean_reward: mean reward of the agent
|
| 693 |
+
:std_reward: standard deviation of the mean reward of the agent
|
| 694 |
+
:hyperparameters: training arguments
|
| 695 |
+
"""
|
| 696 |
+
# Step 1: Select the tags
|
| 697 |
+
metadata = generate_metadata(model_name, env_id, mean_reward, std_reward)
|
| 698 |
+
|
| 699 |
+
# Transform the hyperparams namespace to string
|
| 700 |
+
converted_dict = vars(hyperparameters)
|
| 701 |
+
converted_str = str(converted_dict)
|
| 702 |
+
converted_str = converted_str.split(", ")
|
| 703 |
+
converted_str = "\n".join(converted_str)
|
| 704 |
+
|
| 705 |
+
# Step 2: Generate the model card
|
| 706 |
+
model_card = f"""
|
| 707 |
+
# PPO Agent Playing {env_id}
|
| 708 |
+
|
| 709 |
+
This is a trained model of a PPO agent playing {env_id}.
|
| 710 |
+
|
| 711 |
+
# Hyperparameters
|
| 712 |
+
```python
|
| 713 |
+
{converted_str}
|
| 714 |
+
```
|
| 715 |
+
"""
|
| 716 |
+
return model_card, metadata
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
def generate_metadata(model_name, env_id, mean_reward, std_reward):
|
| 720 |
+
"""
|
| 721 |
+
Define the tags for the model card
|
| 722 |
+
:param model_name: name of the model
|
| 723 |
+
:param env_id: name of the environment
|
| 724 |
+
:mean_reward: mean reward of the agent
|
| 725 |
+
:std_reward: standard deviation of the mean reward of the agent
|
| 726 |
+
"""
|
| 727 |
+
metadata = {}
|
| 728 |
+
metadata["tags"] = [
|
| 729 |
+
env_id,
|
| 730 |
+
"ppo",
|
| 731 |
+
"deep-reinforcement-learning",
|
| 732 |
+
"reinforcement-learning",
|
| 733 |
+
"custom-implementation",
|
| 734 |
+
"deep-rl-course",
|
| 735 |
+
]
|
| 736 |
+
|
| 737 |
+
# Add metrics
|
| 738 |
+
eval = metadata_eval_result(
|
| 739 |
+
model_pretty_name=model_name,
|
| 740 |
+
task_pretty_name="reinforcement-learning",
|
| 741 |
+
task_id="reinforcement-learning",
|
| 742 |
+
metrics_pretty_name="mean_reward",
|
| 743 |
+
metrics_id="mean_reward",
|
| 744 |
+
metrics_value=f"{mean_reward:.2f} +/- {std_reward:.2f}",
|
| 745 |
+
dataset_pretty_name=env_id,
|
| 746 |
+
dataset_id=env_id,
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
# Merges both dictionaries
|
| 750 |
+
metadata = {**metadata, **eval}
|
| 751 |
+
|
| 752 |
+
return metadata
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
def _save_model_card(local_path, generated_model_card, metadata):
|
| 756 |
+
"""Saves a model card for the repository.
|
| 757 |
+
:param local_path: repository directory
|
| 758 |
+
:param generated_model_card: model card generated by _generate_model_card()
|
| 759 |
+
:param metadata: metadata
|
| 760 |
+
"""
|
| 761 |
+
readme_path = local_path / "README.md"
|
| 762 |
+
readme = ""
|
| 763 |
+
if readme_path.exists():
|
| 764 |
+
with readme_path.open("r", encoding="utf8") as f:
|
| 765 |
+
readme = f.read()
|
| 766 |
+
else:
|
| 767 |
+
readme = generated_model_card
|
| 768 |
+
|
| 769 |
+
with readme_path.open("w", encoding="utf-8") as f:
|
| 770 |
+
f.write(readme)
|
| 771 |
+
|
| 772 |
+
# Save our metrics to Readme metadata
|
| 773 |
+
metadata_save(readme_path, metadata)
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def _add_logdir(local_path: Path, logdir: Path):
|
| 777 |
+
"""Adds a logdir to the repository.
|
| 778 |
+
:param local_path: repository directory
|
| 779 |
+
:param logdir: logdir directory
|
| 780 |
+
"""
|
| 781 |
+
if logdir.exists() and logdir.is_dir():
|
| 782 |
+
# Add the logdir to the repository under new dir called logs
|
| 783 |
+
repo_logdir = local_path / "logs"
|
| 784 |
+
|
| 785 |
+
# Delete current logs if they exist
|
| 786 |
+
if repo_logdir.exists():
|
| 787 |
+
shutil.rmtree(repo_logdir)
|
| 788 |
+
|
| 789 |
+
# Copy logdir into repo logdir
|
| 790 |
+
shutil.copytree(logdir, repo_logdir)
|
pyproject.toml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.poetry]
|
| 2 |
+
name = "lunar2"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Training a model on Python 3.10 for HF Hub"
|
| 5 |
+
authors = ["Lois Fernandez <loisfernandezcmp@gmail.com>"]
|
| 6 |
+
|
| 7 |
+
[tool.poetry.dependencies]
|
| 8 |
+
# Restricting to Python 3.10 specifically
|
| 9 |
+
python = "~3.10"
|
| 10 |
+
|
| 11 |
+
# Core ML stack
|
| 12 |
+
torch = "^2.0.0"
|
| 13 |
+
transformers = {extras = ["torch"], version = "^4.35.0"}
|
| 14 |
+
datasets = "^2.15.0"
|
| 15 |
+
accelerate = "^0.24.0"
|
| 16 |
+
|
| 17 |
+
# For uploading to the Hugging Face Hub
|
| 18 |
+
huggingface-hub = "^0.19.0"
|
| 19 |
+
gymnasium = {version = "0.28.1", extras = ["box2d"]}
|
| 20 |
+
|
| 21 |
+
[build-system]
|
| 22 |
+
requires = ["poetry-core>=1.0.0"]
|
| 23 |
+
build-backend = "poetry.core.masonry.api"
|
ruff.toml
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Exclude a variety of commonly ignored directories.
|
| 2 |
+
exclude = [
|
| 3 |
+
".bzr",
|
| 4 |
+
".direnv",
|
| 5 |
+
".eggs",
|
| 6 |
+
".git",
|
| 7 |
+
".git-rewrite",
|
| 8 |
+
".hg",
|
| 9 |
+
".ipynb_checkpoints",
|
| 10 |
+
".mypy_cache",
|
| 11 |
+
".nox",
|
| 12 |
+
".pants.d",
|
| 13 |
+
".pyenv",
|
| 14 |
+
".pytest_cache",
|
| 15 |
+
".pytype",
|
| 16 |
+
".ruff_cache",
|
| 17 |
+
".svn",
|
| 18 |
+
".tox",
|
| 19 |
+
".venv",
|
| 20 |
+
".vscode",
|
| 21 |
+
"__pypackages__",
|
| 22 |
+
"_build",
|
| 23 |
+
"buck-out",
|
| 24 |
+
"build",
|
| 25 |
+
"dist",
|
| 26 |
+
"node_modules",
|
| 27 |
+
"site-packages",
|
| 28 |
+
"venv",
|
| 29 |
+
"*/test/*",
|
| 30 |
+
"*/tests/*",
|
| 31 |
+
"*/tools/*",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
# Same as Black.
|
| 35 |
+
line-length = 79
|
| 36 |
+
indent-width = 4
|
| 37 |
+
|
| 38 |
+
# Assume Python 3.10
|
| 39 |
+
target-version = "py310"
|
| 40 |
+
|
| 41 |
+
[lint]
|
| 42 |
+
# Enable these specific rules and rule-subsets.
|
| 43 |
+
select = [
|
| 44 |
+
"E4",
|
| 45 |
+
"E7",
|
| 46 |
+
"E9",
|
| 47 |
+
"F",
|
| 48 |
+
"ERA",
|
| 49 |
+
"FAST",
|
| 50 |
+
"ANN",
|
| 51 |
+
"ASYNC",
|
| 52 |
+
"S303",
|
| 53 |
+
"S304",
|
| 54 |
+
"FBT",
|
| 55 |
+
"B",
|
| 56 |
+
"A",
|
| 57 |
+
"COM818",
|
| 58 |
+
"C4",
|
| 59 |
+
"FA",
|
| 60 |
+
"ICN",
|
| 61 |
+
"PIE",
|
| 62 |
+
"Q",
|
| 63 |
+
"RET",
|
| 64 |
+
"SIM",
|
| 65 |
+
"ARG",
|
| 66 |
+
"PTH",
|
| 67 |
+
"FLY",
|
| 68 |
+
"C90",
|
| 69 |
+
"NPY",
|
| 70 |
+
"PD",
|
| 71 |
+
"PLE03",
|
| 72 |
+
"D101",
|
| 73 |
+
"D102",
|
| 74 |
+
"D103",
|
| 75 |
+
"D105",
|
| 76 |
+
"D107",
|
| 77 |
+
"D201",
|
| 78 |
+
"D403",
|
| 79 |
+
"D419"
|
| 80 |
+
]
|
| 81 |
+
ignore = ["ANN101", "ANN102","ANN002", "ANN003", "ANN401", "PIE790", "RET504"]
|
| 82 |
+
|
| 83 |
+
# Allow fix for all enabled rules (when `--fix`) is provided.
|
| 84 |
+
fixable = ["ALL"]
|
| 85 |
+
unfixable = []
|
| 86 |
+
|
| 87 |
+
# Allow unused variables when underscore-prefixed.
|
| 88 |
+
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
| 89 |
+
|
| 90 |
+
[lint.per-file-ignores]
|
| 91 |
+
"__init__.py" = ["F401"]
|
| 92 |
+
|
| 93 |
+
[format]
|
| 94 |
+
# Like Black, use double quotes for strings.
|
| 95 |
+
quote-style = "double"
|
| 96 |
+
|
| 97 |
+
# Like Black, indent with spaces, rather than tabs.
|
| 98 |
+
indent-style = "space"
|
| 99 |
+
|
| 100 |
+
# Like Black, respect magic trailing commas.
|
| 101 |
+
skip-magic-trailing-comma = false
|
| 102 |
+
|
| 103 |
+
# Like Black, automatically detect the appropriate line ending.
|
| 104 |
+
line-ending = "auto"
|
runs/CartPole-v1_ppo_1_1776758028/events.out.tfevents.1776758028.BL-LFERNANDEZ.local.eurecat.org.10534.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9afa36760505e109225f5bbf594139ad908be8c8f69f10b1ef0e19c3837e4a32
|
| 3 |
+
size 4505
|
runs/CartPole-v1_ppo_1_1776759190/events.out.tfevents.1776759190.BL-LFERNANDEZ.local.eurecat.org.13333.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:316b4b1d7aa14b1ea6a647b2beca1a6538747c49ec61f8b8066bbeba867e431c
|
| 3 |
+
size 4505
|
runs/CartPole-v1_ppo_1_1776759930/events.out.tfevents.1776759930.BL-LFERNANDEZ.local.eurecat.org.16675.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:782ac37f9222ea7332327d7ba3f67dae23036fbbedfddcdef39f55160dab559e
|
| 3 |
+
size 4646
|
runs/CartPole-v1_ppo_1_1776760089/events.out.tfevents.1776760089.BL-LFERNANDEZ.local.eurecat.org.17273.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bccf2c3eea4f07985090378e500d218d96ec681ada92a09b4a9ee726a233091c
|
| 3 |
+
size 4505
|
runs/CartPole-v1_ppo_1_1776760573/events.out.tfevents.1776760573.BL-LFERNANDEZ.local.eurecat.org.19521.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:03821d01fb60402b8552445a0198fbb23ea001cf127dd5bdaa3692d968fcff93
|
| 3 |
+
size 4518
|
runs/CartPole-v1_ppo_1_1776760603/events.out.tfevents.1776760800.BL-LFERNANDEZ.local.eurecat.org.19762.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b556eab8c87d3450e4b7e3d8b474d2181818544a1d345d5d9baeac21dbec3022
|
| 3 |
+
size 5084
|
runs/CartPole-v1_ppo_1_1776761629/events.out.tfevents.1776761629.BL-LFERNANDEZ.local.eurecat.org.40907.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c874ddd1e6790112cb2546a418e40ad0444439dd6c1cd2f92237416a779759a8
|
| 3 |
+
size 4615
|
runs/CartPole-v1_ppo_1_1776764954/events.out.tfevents.1776764954.BL-LFERNANDEZ.local.eurecat.org.45905.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:05a6c46f9c4c3503dd8dbc83d546033ed94f2a32a6269c270b1d0f7ec5d1a3e6
|
| 3 |
+
size 387
|
runs/CartPole-v1_ppo_1_1776765078/events.out.tfevents.1776765078.BL-LFERNANDEZ.local.eurecat.org.46277.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:da4a24a12acecc6744ffec1840df82d7d0a43e52af6e5508049146dd8d01ca37
|
| 3 |
+
size 387
|
runs/CartPole-v1_ppo_1_1776765121/events.out.tfevents.1776765121.BL-LFERNANDEZ.local.eurecat.org.46547.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1c828152286330150089d2a3aacb183e30671ddb0128011b758cb1e635a08598
|
| 3 |
+
size 387
|
runs/CartPole-v1_ppo_1_1776765317/events.out.tfevents.1776765317.BL-LFERNANDEZ.local.eurecat.org.47512.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2972bec30717f8de533164b54173c35d83ff6bb92090c9972fee9be0d2586536
|
| 3 |
+
size 387
|
runs/CartPole-v1_ppo_1_1776765438/events.out.tfevents.1776765438.BL-LFERNANDEZ.local.eurecat.org.47896.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb9a29012db8c3f257843008c0b6fe5eef5838c00232ea123a25b3fb6fbbe6a1
|
| 3 |
+
size 387
|
runs/CartPole-v1_ppo_1_1776765498/events.out.tfevents.1776765498.BL-LFERNANDEZ.local.eurecat.org.48020.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60d1d932ada826bcd927f5fea551571925735ce2151711fcf01557d4d2fb10c3
|
| 3 |
+
size 387
|
runs/CartPole-v1_ppo_1_1776765547/events.out.tfevents.1776765547.BL-LFERNANDEZ.local.eurecat.org.48309.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7bed5acfede6c85caf003f79714c91c2d43b2a46bc61ed935565435406348d35
|
| 3 |
+
size 387
|
runs/CartPole-v1_ppo_1_1776765580/events.out.tfevents.1776765580.BL-LFERNANDEZ.local.eurecat.org.48524.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf8a7c7dc0e56bc55a48ba5cb5574c7395b57948b3134839937b32fe27742727
|
| 3 |
+
size 387
|
runs/CartPole-v1_ppo_1_1776765943/events.out.tfevents.1776765945.BL-LFERNANDEZ.local.eurecat.org.49910.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fe38d20ce1da58e3db7c61d6e468f058e96899863695087cdf1342a10ff67936
|
| 3 |
+
size 386
|
runs/CartPole-v1_ppo_1_1776766122/events.out.tfevents.1776766124.BL-LFERNANDEZ.local.eurecat.org.50518.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6982a170b017fbca8e61516594b134ddacbe16dcb14fbb58a8addca9e036da11
|
| 3 |
+
size 386
|
runs/CartPole-v1_ppo_1_1776766281/events.out.tfevents.1776766283.BL-LFERNANDEZ.local.eurecat.org.51127.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:618af812c17147f953e1e421178ad38fb7c7d50e52e322e77c8e3dd9cec8ce20
|
| 3 |
+
size 386
|
runs/CartPole-v1_ppo_1_1776766423/events.out.tfevents.1776766423.BL-LFERNANDEZ.local.eurecat.org.51653.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e42b720566506a92475a44ef430dbb18a65b5ce7389e9d2f85ccc019ca218fb2
|
| 3 |
+
size 387
|
runs/CartPole-v1_ppo_1_1776766445/events.out.tfevents.1776766445.BL-LFERNANDEZ.local.eurecat.org.51862.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:544211ff4e7dccf156b57657f28c98e9047e5ee09d242efb47211d1dcf6bcf6a
|
| 3 |
+
size 387
|
runs/CartPole-v1_ppo_1_1776767609/events.out.tfevents.1776767609.BL-LFERNANDEZ.local.eurecat.org.55609.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:18697544a26000a565d90b8c95bec475ec6037305456213f73867826b558bb8f
|
| 3 |
+
size 387
|
runs/CartPole-v1_ppo_1_1776767664/events.out.tfevents.1776767664.BL-LFERNANDEZ.local.eurecat.org.55869.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:721804f39bacafaa16df395e991a020282a21377b0c4a85d5d65fabe2c146d49
|
| 3 |
+
size 387
|
runs/CartPole-v1_ppo_1_1776767691/events.out.tfevents.1776767691.BL-LFERNANDEZ.local.eurecat.org.55998.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5a8c1b7da7653289620f6915d27331744b2b5de6abde50d15052b9c0eecf3ea
|
| 3 |
+
size 400
|
runs/CartPole-v1_ppo_1_1776767743/events.out.tfevents.1776767743.BL-LFERNANDEZ.local.eurecat.org.56120.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7410d303a58d15dfa53fd49591350c0e6d746ab775dda23a7dbeb86d5ceac169
|
| 3 |
+
size 438
|
runs/CartPole-v1_ppo_1_1776767808/events.out.tfevents.1776767808.BL-LFERNANDEZ.local.eurecat.org.56378.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a6dea466e45090c43781b14cd8a849364bc234c6cf14714b8a4865bafed15c89
|
| 3 |
+
size 438
|
runs/CartPole-v1_ppo_1_1776767863/events.out.tfevents.1776767863.BL-LFERNANDEZ.local.eurecat.org.56483.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:daecbdc1192a33fc47cbdcf5cebe2a7c73bb9c27376705e35835a1bae7e59768
|
| 3 |
+
size 438
|
runs/CartPole-v1_ppo_1_1776768348/events.out.tfevents.1776768348.BL-LFERNANDEZ.local.eurecat.org.57900.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:42db5c74403b1efa7b85919ffcbbb9090373fe612df1fb272719837d4ad24d4a
|
| 3 |
+
size 438
|
runs/CartPole-v1_ppo_1_1776768490/events.out.tfevents.1776768490.BL-LFERNANDEZ.local.eurecat.org.58209.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8ed66b6a9f6ce678547c16073df52d6cfd5c5d1bd5f85f060df01e71ce223ed9
|
| 3 |
+
size 455
|
runs/CartPole-v1_ppo_1_1776768658/events.out.tfevents.1776768658.BL-LFERNANDEZ.local.eurecat.org.58694.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6705c8526854d2808b45af2747bff6840ef0998c53234d6eb99dd9e7329c79ef
|
| 3 |
+
size 455
|
runs/CartPole-v1_ppo_1_1776768770/events.out.tfevents.1776768770.BL-LFERNANDEZ.local.eurecat.org.59216.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3cba2f3486e16cd82a5c35b88c9392b0d0743d4ac84958dd37d1169855f57f1e
|
| 3 |
+
size 455
|
runs/CartPole-v1_ppo_1_1776768821/events.out.tfevents.1776768821.BL-LFERNANDEZ.local.eurecat.org.59474.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:812c752d707e26387dc00a0c9b9c4153b88d4a6665813068336e8d87c0c62fe7
|
| 3 |
+
size 455
|
runs/CartPole-v1_ppo_1_1776769606/events.out.tfevents.1776769606.BL-LFERNANDEZ.local.eurecat.org.62559.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:96c693f2d3db465b461568c35031823ae50200e5ebfd0905f7073bc4c73c5e58
|
| 3 |
+
size 472
|
runs/CartPole-v1_ppo_1_1776769691/events.out.tfevents.1776769691.BL-LFERNANDEZ.local.eurecat.org.62771.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:95ba12311c3d22a90a0a994044ac36b7604e9701a796d03ed7eeac3e358678f9
|
| 3 |
+
size 472
|
runs/CartPole-v1_ppo_1_1776770441/events.out.tfevents.1776770441.BL-LFERNANDEZ.local.eurecat.org.64367.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a34c5c92ebefdfe16a1abecd6fa0827772eed1522bdd5ed9eac3ec08033d199e
|
| 3 |
+
size 472
|
runs/CartPole-v1_ppo_1_1776770456/events.out.tfevents.1776770456.BL-LFERNANDEZ.local.eurecat.org.64462.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:95d7c0875736b536bf71649f27bf8af2205663198c362e7c09ba0db53d3a53e1
|
| 3 |
+
size 129276
|
runs/CartPole-v1_ppo_1_1776779529/events.out.tfevents.1776779529.BL-LFERNANDEZ.local.eurecat.org.84023.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d2c15bf74d464fba25c98d4c8b5e0db41a073f997d78f72526a002979274917a
|
| 3 |
+
size 80573
|
runs/CartPole-v1_ppo_1_1776779596/events.out.tfevents.1776779596.BL-LFERNANDEZ.local.eurecat.org.106839.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:42220ac9036b3328c2fde82e3e4115e5e6cc70475fe9ecfb7606272b7df61cc9
|
| 3 |
+
size 8246
|
runs/CartPole-v1_ppo_1_1776779670/events.out.tfevents.1776779672.BL-LFERNANDEZ.local.eurecat.org.209000.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2eaca50079d50ce8a79530f7267e983cbdd1e253749c8aee2e57230e6894ff9b
|
| 3 |
+
size 688
|
runs/CartPole-v1_ppo_1_1776779742/events.out.tfevents.1776779745.BL-LFERNANDEZ.local.eurecat.org.209486.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf0bd5e4e5ca016f6f0f843bcf53caec5ac7738c358312f90e988b79c6ff5c97
|
| 3 |
+
size 688
|
runs/CartPole-v1_ppo_1_1776779888/events.out.tfevents.1776779890.BL-LFERNANDEZ.local.eurecat.org.210132.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8388b360350a0ae8c2c01d265393be1f572f5a175c57d4ad14569208ba8869f1
|
| 3 |
+
size 688
|